aboutsummaryrefslogtreecommitdiffstats
path: root/util/net.c
diff options
context:
space:
mode:
Diffstat (limited to 'util/net.c')
-rw-r--r--util/net.c383
1 files changed, 383 insertions, 0 deletions
diff --git a/util/net.c b/util/net.c
new file mode 100644
index 0000000..98f371a
--- /dev/null
+++ b/util/net.c
@@ -0,0 +1,383 @@
+#include <errno.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <pthread.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/signalfd.h>
+#include <unistd.h>
+
+#include <cJSON.h>
+#include <curl/curl.h>
+
+#include <dbs/api.h>
+#include <dbs/init.h>
+#include <dbs/log.h>
+#include <dbs/subsys.h>
+
+/* functions */
+int http_request(HTTPMethod method, char *url,
+ struct curl_slist *headers, char *writebuf, size_t bufsiz);
+int api_request(HTTPMethod method, char *url,
+ struct curl_slist *headers, char *writebuf, size_t bufsiz);
+static void setup_token_header();
+
+static void ws_send_heartbeat();
+static void ws_handle_event(cJSON *event);
+
+int net_subsystem();
+void net_get_gateway_url();
+
+/* variables */
+static CURL *ws_handle;
+static char *gateway_url;
+static char *token_header;
+
+static long last_sequence = -1;
+static struct timeval heartbeat_time;
+
+int http_request(HTTPMethod method, char *url,
+ struct curl_slist *headers, char *writebuf, size_t bufsiz)
+{
+ int inputpipe[2];
+ int outputpipe[2];
+
+ if(pipe(inputpipe) < 0)
+ return -(errno << 8);
+ if(pipe(outputpipe) < 0)
+ return -(errno << 8);
+
+ if(writebuf && bufsiz > 0)
+ write(inputpipe[1], writebuf, bufsiz);
+ close(inputpipe[1]);
+
+ FILE *input_read = fdopen(inputpipe[0], "r");
+ FILE *output_write = fdopen(outputpipe[1], "w");
+
+ int ret = outputpipe[0];
+
+ CURL *job = curl_easy_init();
+ if(job == NULL)
+ panic("api: curl_easy_init failed");
+
+ curl_easy_setopt(job, CURLOPT_URL, url);
+ curl_easy_setopt(job, CURLOPT_READDATA, input_read);
+ curl_easy_setopt(job, CURLOPT_WRITEDATA, output_write);
+ char *requestmethod = "GET";
+ switch(method) {
+ case HTTP_PATCH:
+ requestmethod = "PATCH";
+ break;
+ case HTTP_DELETE:
+ requestmethod = "DELETE";
+ break;
+ case HTTP_PUT:
+ requestmethod = "PUT";
+ break;
+ case HTTP_POST:
+ requestmethod = "POST";
+ break;
+ case HTTP_GET: /* fallthrough */
+ default:
+ break;
+ }
+ curl_easy_setopt(job, CURLOPT_CUSTOMREQUEST, requestmethod);
+ if(headers)
+ curl_easy_setopt(job, CURLOPT_HTTPHEADER, headers);
+ CURLcode res = curl_easy_perform(job);
+
+ if(res > 0) {
+ close(outputpipe[0]);
+ ret = -res;
+ }
+
+ curl_easy_cleanup(job);
+ fclose(input_read);
+ fclose(output_write);
+ return ret;
+}
+
+static void setup_token_header()
+{
+ if(token_header != NULL)
+ return;
+ char *token = getenv("TOKEN");
+ if(!token)
+ panic("api: cannot find TOKEN in env");
+ token_header = calloc(strlen(token) + strlen("Authorization: Bot ") + 1, sizeof(char));
+ strcpy(token_header, "Authorization: Bot ");
+ strcat(token_header, token);
+}
+l1_initcall(setup_token_header);
+
+int api_request(HTTPMethod method, char *url,
+ struct curl_slist *headers, char *writebuf, size_t bufsiz)
+{
+ char *new_url = calloc((strlen("https://discord.com/api") + strlen(url) + 1),
+ sizeof(char));
+ strcpy(new_url, "https://discord.com/api");
+ strcat(new_url, url);
+ if(token_header == NULL)
+ setup_token_header();
+ struct curl_slist *headers_auth = curl_slist_append(headers, token_header);
+ int ret = http_request(method, new_url, headers_auth, writebuf, bufsiz);
+ free(new_url);
+ curl_slist_free_all(headers_auth);
+ return ret;
+}
+
+static void ws_send_heartbeat()
+{
+ char buf[128] = "{\"op\":1,\"d\":null}";
+ if(last_sequence > 0)
+ snprintf(buf, 128, "{\"op\":1,\"d\":%ld}", last_sequence);
+ size_t sent;
+ curl_ws_send(ws_handle, buf, strnlen(buf, 128), &sent, 0, CURLWS_TEXT);
+
+ /* if we receive a heartbeat request from discord, we need to fix
+ the itimer so we don't send another one before the desired
+ heartbeat interval. if our itimer is off more than 2 seconds
+ then we fix it up and reset it */
+ struct itimerval itimer;
+ getitimer(ITIMER_REAL, &itimer);
+ if(itimer.it_value.tv_sec < heartbeat_time.tv_sec - 2) {
+ itimer.it_value = heartbeat_time;
+ setitimer(ITIMER_REAL, &itimer, NULL);
+ }
+}
+
+static void ws_handle_event(cJSON *event)
+{
+ int op = cJSON_GetObjectItem(event, "op")->valueint;
+ cJSON *data = cJSON_GetObjectItem(event, "d");
+ switch(op) {
+ case 0: /* Event dispatch */
+ break;
+ case 1: /* Heartbeat request */
+ ws_send_heartbeat();
+ break;
+ case 9: /* Invalid Session */
+ if(!cJSON_IsTrue(data)) {
+ /* discord sets data to true if we can reconnect,
+ but in this statement it is false, so we just die */
+ /* note: discord closes the websocket after sending this,
+ so we let our ws code accept and handle the error */
+ break;
+ }
+ /* FALLTHROUGH */
+ case 7: /* Reconnect */
+ /* TODO */
+ panic("ws: cannot reconnect to ws after failure");
+ break;
+ case 10: ; /* Hello */
+ int heartbeat_wait = cJSON_GetObjectItem(data,
+ "heartbeat_interval")->valueint;
+ float jitter = (float)rand() / (RAND_MAX * 1.0f);
+
+ heartbeat_time.tv_sec = heartbeat_wait / 1000;
+ heartbeat_time.tv_usec = (heartbeat_wait % 1000) * 1000;
+ struct timeval jitter_time = {
+ .tv_sec = heartbeat_time.tv_sec * jitter,
+ .tv_usec = heartbeat_time.tv_usec * jitter,
+ };
+ struct itimerval new_itimer = {
+ .it_interval = heartbeat_time,
+ .it_value = jitter_time
+ };
+ setitimer(ITIMER_REAL, &new_itimer, NULL);
+ break;
+ case 11: /* Heartbeat ACK */
+ print(LOG_DEBUG "ws: heartbeat ACK");
+ break;
+ default:
+ print(LOG_ERR "ws: received unknown WS opcode %d", op);
+ break;
+ }
+}
+
+int net_subsystem(void)
+{
+ if(!gateway_url)
+ panic("net: gateway url invalid");
+
+ /* Initialise CURL */
+ ws_handle = curl_easy_init();
+
+ curl_easy_setopt(ws_handle, CURLOPT_URL, gateway_url);
+ curl_easy_setopt(ws_handle, CURLOPT_CONNECT_ONLY, 2L);
+
+ CURLcode ret = curl_easy_perform(ws_handle);
+
+ if(ret > 0) {
+ panic("net: cannot open websocket: %s", curl_easy_strerror(ret));
+ }
+
+ int ws_sockfd;
+ if((ret = curl_easy_getinfo(ws_handle,
+ CURLINFO_ACTIVESOCKET, &ws_sockfd)) != CURLE_OK)
+ panic("net: curl cannot get active socket: "
+ "%s", curl_easy_strerror(ret));
+
+
+ /* Block ALRM */
+ sigset_t *set = malloc(sizeof(sigset_t));
+ sigemptyset(set);
+ sigaddset(set, SIGALRM);
+ sigprocmask(SIG_BLOCK, set, NULL);
+ int alrmfd = signalfd(-1, set, 0);
+ free(set);
+
+ /* Prepare poll */
+ struct pollfd pollarray[2] = {
+ {
+ .fd = ws_sockfd,
+ .events = POLLIN,
+ .revents = POLLIN
+ },
+ {
+ .fd = alrmfd,
+ .events = POLLIN,
+ .revents = 0
+ }
+ };
+
+ struct pollfd *sockpoll = &(pollarray[0]);
+ struct pollfd *alrmpoll = &(pollarray[1]);
+
+ /* Misc. variables */
+ char *inbuf = malloc(1<<16 * sizeof(char));
+ size_t rlen;
+ const struct curl_ws_frame *meta;
+
+ errno = 0;
+ do {
+ if((sockpoll->revents & POLLIN) == POLLIN) {
+ ret = curl_ws_recv(ws_handle, inbuf, 1<<16, &rlen, &meta);
+ /* sometimes only SSL information gets sent through, so no actual
+ data is received. curl uses NONBLOCK internally so it lets us
+ know if there is no more data remaining */
+ if(ret == CURLE_AGAIN)
+ goto sockpoll_continue;
+ if(ret != CURLE_OK) {
+ print(LOG_ERR "net: encountered error while reading socket: "
+ "%s", curl_easy_strerror(ret));
+ break;
+ }
+
+ /* TODO: partial frames */
+ if((meta->offset | meta->bytesleft) > 0) {
+ print(LOG_ERR "net: dropped partial frame");
+ goto sockpoll_continue;
+ }
+
+ switch(meta->flags) {
+ case(CURLWS_PING):
+ curl_ws_send(ws_handle, NULL, 0, NULL, 0, CURLWS_PONG);
+ goto sockpoll_continue;
+ case(CURLWS_CLOSE):
+ default:
+ break;
+ }
+
+ cJSON *event = cJSON_ParseWithLength(inbuf, rlen);
+ if(!event) {
+ print(LOG_ERR "net: dropped malformed frame");
+ goto sockpoll_continue;
+ }
+ ws_handle_event(event);
+ cJSON_Delete(event);
+ } else if((sockpoll->revents &
+ (POLLRDHUP | POLLERR | POLLHUP | POLLNVAL)) > 0) {
+ break;
+ }
+sockpoll_continue:
+
+ if((alrmpoll->revents & POLLIN) == POLLIN) {
+ struct signalfd_siginfo siginfo;
+ read(alrmfd, &siginfo, sizeof(struct signalfd_siginfo));
+ ws_send_heartbeat();
+ }
+ } while(poll(pollarray, 2, -1) >= 0);
+
+ if(errno > 0) {
+ print(LOG_ERR "net: poll: %s", strerror(errno));
+ }
+
+ free(inbuf);
+
+ curl_easy_cleanup(ws_handle);
+
+ panic("net: websocket closed unexpectedly");
+
+ return 0;
+} /* net_subsystem */
+declare_subsystem(net_subsystem);
+
+void net_get_gateway_url()
+{
+ /* determine if websockets are supported */
+ curl_version_info_data *curl_version =
+ curl_version_info(CURLVERSION_NOW);
+ const char * const* curl_protocols = curl_version->protocols;
+ int wss_supported = 0;
+ for(int i = 0; curl_protocols[i]; ++i) {
+ if(strcmp(curl_protocols[i], "wss") == 0) {
+ wss_supported = 1;
+ break;
+ }
+ }
+
+ if(!wss_supported)
+ panic("net: wss not supported by libcurl");
+
+ /* fetch preferred url from discord */
+ int fd = api_get("/gateway/bot", NULL, NULL, 0);
+ if(fd < 0) {
+ print(LOG_ERR "net: cannot get gateway url: %s", curl_easy_strerror(-fd));
+ goto assume;
+ }
+
+ char buf[512];
+ int buf_length = read(fd, buf, 512);
+ close(fd);
+
+ cJSON *gateway_info = cJSON_ParseWithLength(buf, buf_length);
+ cJSON *gateway_url_json =
+ cJSON_GetObjectItemCaseSensitive(gateway_info, "url");
+ if(!cJSON_IsString(gateway_url_json) ||
+ gateway_url_json->valuestring == NULL) {
+
+ cJSON *gateway_message =
+ cJSON_GetObjectItemCaseSensitive(gateway_info, "message");
+
+ if(cJSON_IsString(gateway_message)) {
+ print(LOG_ERR "net: cannot get gateway url from api: "
+ "%s: assuming url", cJSON_GetStringValue(gateway_message));
+ } else {
+ print(LOG_ERR "net: cannot get gateway url from api "
+ "(unknown error): assuming url");
+ }
+ cJSON_Delete(gateway_info);
+ goto assume;
+ }
+
+ /* curl requires websocket secure URLs to begin with WSS instead
+ of wss, so we fix up the received url for curl */
+ gateway_url = calloc(strlen(gateway_url_json->valuestring) + 1,
+ sizeof(char));
+ strcpy(gateway_url, gateway_url_json->valuestring);
+ gateway_url[0] = 'W';
+ gateway_url[1] = 'S';
+ gateway_url[2] = 'S';
+
+ cJSON_Delete(gateway_info);
+ return;
+
+assume:
+ gateway_url = calloc(strlen("WSS://gateway.discord.gg") + 1,
+ sizeof(char));
+ strcpy(gateway_url, "WSS://gateway.discord.gg");
+ return;
+}
+l1_initcall(net_get_gateway_url);