colla/websocket.c
2024-11-29 16:10:48 +01:00

142 lines
3.6 KiB
C

#include "websocket.h"
#include "arena.h"
#include "str.h"
#include "server.h"
#include "socket.h"
#include "base64.h"
#include "strstream.h"
#include "sha1.h"
#if !COLLA_MSVC
extern uint16_t ntohs(uint16_t hostshort);
extern uint16_t htons(uint16_t hostshort);
extern uint32_t htonl(uint32_t hostlong);
uint64_t ntohll(uint64_t input) {
uint64_t rval = 0;
uint8_t *data = (uint8_t *)&rval;
data[0] = input >> 56;
data[1] = input >> 48;
data[2] = input >> 40;
data[3] = input >> 32;
data[4] = input >> 24;
data[5] = input >> 16;
data[6] = input >> 8;
data[7] = input >> 0;
return rval;
}
uint64_t htonll(uint64_t input) {
return ntohll(input);
}
#endif
bool wsInitialiseSocket(arena_t scratch, socket_t websocket, strview_t key) {
str_t full_key = strFmt(&scratch, "%v" WEBSOCKET_MAGIC, key, WEBSOCKET_MAGIC);
sha1_t sha1_ctx = sha1_init();
sha1_result_t sha1_data = sha1(&sha1_ctx, full_key.buf, full_key.len);
// convert to big endian for network communication
for (int i = 0; i < 5; ++i) {
sha1_data.digest[i] = htonl(sha1_data.digest[i]);
}
buffer_t encoded_key = base64Encode(&scratch, (buffer_t){ (uint8 *)sha1_data.digest, sizeof(sha1_data.digest) });
str_t response = strFmt(
&scratch,
"HTTP/1.1 101 Switching Protocols\r\n"
"Connection: Upgrade\r\n"
"Upgrade: websocket\r\n"
"Sec-WebSocket-Accept: %v\r\n"
"\r\n",
encoded_key
);
bool success = skSend(websocket, response.buf, response.len);
return success;
}
buffer_t wsEncodeMessage(arena_t *arena, strview_t message) {
int extra = 6;
if (message.len > UINT16_MAX) extra += sizeof(uint64);
else if (message.len > UINT8_MAX) extra += sizeof(uint16);
uint8 *bytes = alloc(arena, uint8, message.len + extra);
bytes[0] = 0b10000001;
bytes[1] = 0b10000000;
int offset = 2;
if (message.len > UINT16_MAX) {
bytes[1] |= 0b01111111;
uint64 len = htonll(message.len);
memcpy(bytes + 2, &len, sizeof(len));
offset += sizeof(uint64);
}
else if (message.len > UINT8_MAX) {
bytes[1] |= 0b01111110;
uint16 len = htons((uint16)message.len);
memcpy(bytes + 2, &len, sizeof(len));
offset += sizeof(uint16);
}
else {
bytes[1] |= (uint8)message.len;
}
uint32 mask = 0;
memcpy(bytes + offset, &mask, sizeof(mask));
offset += sizeof(mask);
memcpy(bytes + offset, message.buf, message.len);
return (buffer_t){ bytes, message.len + extra };
}
str_t wsDecodeMessage(arena_t *arena, buffer_t message) {
str_t out = STR_EMPTY;
uint8 *bytes = message.data;
bool mask = bytes[1] & 0b10000000;
int offset = 2;
uint64 msglen = bytes[1] & 0b01111111;
// 16bit msg len
if (msglen == 126) {
uint64 be_len = 0;
memcpy(&be_len, bytes + 2, sizeof(be_len));
msglen = ntohs(be_len);
offset += sizeof(uint16);
}
// 64bit msg len
else if (msglen == 127) {
uint64 be_len = 0;
memcpy(&be_len, bytes + 2, sizeof(be_len));
msglen = ntohll(be_len);
offset += sizeof(uint64);
}
if (msglen == 0) {
warn("message length = 0");
}
else if (mask) {
uint8 *decoded = alloc(arena, uint8, msglen + 1);
uint8 masks[4] = {0};
memcpy(masks, bytes + offset, sizeof(masks));
offset += 4;
for (uint64 i = 0; i < msglen; ++i) {
decoded[i] = bytes[offset + i] ^ masks[i % 4];
}
out = (str_t){ (char *)decoded, msglen };
}
else {
warn("mask bit not set!");
}
return out;
}