diff --git a/Makefile b/Makefile index dc91f36..fef8238 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ LIB_WS = libws.a INCLUDE = include CFLAGS += -Wall -Wextra -O2 CFLAGS += -I $(INCLUDE) -std=c99 -pedantic -LDLIBS = $(LIB_WS) -pthread +LDLIBS = $(LIB_WS) -pthread -lssl -lcrypto ARFLAGS = cru MCSS_DIR ?= /usr/bin/ MANPAGES = doc/man/man3 diff --git a/cert/cert.pem b/cert/cert.pem new file mode 100644 index 0000000..e69de29 diff --git a/cert/key.pem b/cert/key.pem new file mode 100644 index 0000000..e69de29 diff --git a/examples/echo/echo.c b/examples/echo/echo.c index 058c3cc..8db7d49 100644 --- a/examples/echo/echo.c +++ b/examples/echo/echo.c @@ -105,7 +105,11 @@ void onmessage(ws_cli_conn_t client, * ws_sendframe_bin() * ws_sendframe_bin_bcast() */ + #ifdef ENABLE_OPENSSL + ws_sendframe_txt(client, (char*)msg); + #else ws_sendframe_bcast(8080, (char *)msg, size, type); + #endif } /** @@ -116,6 +120,23 @@ void onmessage(ws_cli_conn_t client, */ int main(void) { + #ifdef ENABLE_OPENSSL + ws_socket(&(struct ws_server){ + /* + * Bind host: + * localhost -> localhost/127.0.0.1 + * 0.0.0.0 -> global IPv4 + * :: -> global IPv4+IPv6 (DualStack) + */ + .host = "0.0.0.0", + .port = 8443, + .thread_loop = 0, + .timeout_ms = 1000, + .evs.onopen = &onopen, + .evs.onclose = &onclose, + .evs.onmessage = &onmessage + }); + #else ws_socket(&(struct ws_server){ /* * Bind host: @@ -124,13 +145,16 @@ int main(void) * :: -> global IPv4+IPv6 (DualStack) */ .host = "0.0.0.0", - .port = 8080, + .port = 8443, .thread_loop = 0, .timeout_ms = 1000, .evs.onopen = &onopen, .evs.onclose = &onclose, .evs.onmessage = &onmessage }); + #endif + + /* * If you want to execute code past ws_socket(), set diff --git a/extra/toyws/toyws.c b/extra/toyws/toyws.c index bb91e92..31f326d 100644 --- a/extra/toyws/toyws.c +++ b/extra/toyws/toyws.c @@ -397,7 +397,7 @@ int tws_receiveframe(struct tws_ctx *ctx, char **buff, { cur_byte = next_byte(ctx, &ret); if (cur_byte < 0) - return (ret == 0 ? frame_length : ret); + return (ret == 0 ? frame_length : (uint64_t)ret); *buf = cur_byte; } @@ -406,5 +406,5 @@ int tws_receiveframe(struct tws_ctx *ctx, char **buff, /* Fill other infos. */ *frm_type = opcode; - return (ret == 0 ? frame_length : ret); + return (ret == 0 ? frame_length : (uint64_t)ret); } diff --git a/include/ws.h b/include/ws.h index 900d03b..16fe124 100644 --- a/include/ws.h +++ b/include/ws.h @@ -217,8 +217,13 @@ extern "C" { /**@}*/ #ifndef AFL_FUZZ + #ifdef ENABLE_OPENSSL + #define SEND(client,buf,len) SSL_write((client)->ssl, (buf), (len)) + #define RECV(client,buf,len) SSL_read((client)->ssl, (buf), (len)) + #else #define SEND(client,buf,len) send_all((client), (buf), (len), MSG_NOSIGNAL) #define RECV(fd,buf,len) recv((fd)->client_sock, (buf), (len), 0) + #endif #else #define SEND(client,buf,len) write(fileno(stdout), (buf), (len)) #define RECV(fd,buf,len) read((fd)->client_sock, (buf), (len)) diff --git a/src/ws.c b/src/ws.c index 7ee76a5..623c401 100644 --- a/src/ws.c +++ b/src/ws.c @@ -24,8 +24,9 @@ #include #include #include -#include #include +#include +#include /* clang-format off */ #ifndef _WIN32 @@ -51,6 +52,8 @@ typedef int socklen_t; #include #include +SSL_CTX *ctx; + /** * @dir src/ * @brief wsServer source code @@ -90,6 +93,9 @@ struct ws_connection /* Connection context */ void *connection_context; +#ifdef ENABLE_OPENSSL + SSL *ssl; +#endif ws_cli_conn_t client_id; }; @@ -111,9 +117,7 @@ static uint32_t timeout; */ #define CLIENT_VALID(cli) \ ((cli) != NULL && (cli) >= &client_socks[0] && \ - (cli) <= &client_socks[MAX_CLIENTS - 1] && \ - (cli)->client_sock > -1) - + (cli) <= &client_socks[MAX_CLIENTS - 1] && (cli)->client_sock > -1) /** * @brief Get server context. @@ -241,7 +245,6 @@ static void close_socket(int fd) #endif } - static uint64_t cid_generator = 1; static pthread_mutex_t cid_mutex = PTHREAD_MUTEX_INITIALIZER; @@ -341,7 +344,11 @@ static ssize_t send_all( pthread_mutex_lock(&client->mtx_snd); while (len) { - r = send(client->client_sock, p, len, flags); +#ifdef ENABLE_OPENSSL + r = SSL_write(client->ssl, p, len); +#else + r = send(client->client_sock, p, len, flags); +#endif if (r == -1) { pthread_mutex_unlock(&client->mtx_snd); @@ -489,16 +496,16 @@ static void set_client_address(struct ws_connection *client) if (!CLIENT_VALID(client)) return; - memset(client->ip, 0, sizeof(client->ip)); + memset(client->ip, 0, sizeof(client->ip)); memset(client->port, 0, sizeof(client->port)); if (getpeername(client->client_sock, (struct sockaddr *)&addr, &hlen) < 0) return; getnameinfo((struct sockaddr *)&addr, hlen, - client->ip, sizeof(client->ip), + client->ip, sizeof(client->ip), client->port, sizeof(client->port), - NI_NUMERICHOST|NI_NUMERICSERV); + NI_NUMERICHOST|NI_NUMERICSERV); } /** @@ -561,7 +568,7 @@ char *ws_getport(ws_cli_conn_t client) * for completeness. */ static int ws_sendframe_internal(struct ws_connection *client, const char *msg, - uint64_t size, int type, uint16_t port) + uint64_t size, int type, uint16_t port) { unsigned char *response; /* Response data. */ unsigned char frame[10]; /* Frame. */ @@ -757,8 +764,8 @@ static inline void int32_to_ping_msg(int32_t ping_id, uint8_t *msg) /* Encodes as big-endian. */ msg[0] = (ping_id >> 24); msg[1] = (ping_id >> 16); - msg[2] = (ping_id >> 8); - msg[3] = (ping_id >> 0); + msg[2] = (ping_id >> 8); + msg[3] = (ping_id >> 0); } /** @@ -1003,11 +1010,9 @@ static inline int is_control_frame(int frame) */ static inline int is_valid_frame(int opcode) { - return ( - opcode == WS_FR_OP_TXT || opcode == WS_FR_OP_BIN || - opcode == WS_FR_OP_CONT || opcode == WS_FR_OP_PING || - opcode == WS_FR_OP_PONG || opcode == WS_FR_OP_CLSE - ); + return (opcode == WS_FR_OP_TXT || opcode == WS_FR_OP_BIN || + opcode == WS_FR_OP_CONT || opcode == WS_FR_OP_PING || + opcode == WS_FR_OP_PONG || opcode == WS_FR_OP_CLSE); } /** @@ -1116,7 +1121,7 @@ static int do_close(struct ws_frame_data *wfd, int close_code) wfd->msg_ctrl[1] = (cc & 0xFF); if (ws_sendframe(wfd->client->client_id, (const char *)wfd->msg_ctrl, sizeof(char) * 2, - WS_FR_OP_CLSE) < 0) + WS_FR_OP_CLSE) < 0) { DEBUG("An error has occurred while sending closing frame!\n"); return (-1); @@ -1126,8 +1131,8 @@ static int do_close(struct ws_frame_data *wfd, int close_code) /* Send the data inside wfd->msg_ctrl. */ send: - if (ws_sendframe(wfd->client->client_id, (const char *)wfd->msg_ctrl, wfd->frame_size, - WS_FR_OP_CLSE) < 0) + if (ws_sendframe(wfd->client->client_id, (const char *)wfd->msg_ctrl, + wfd->frame_size, WS_FR_OP_CLSE) < 0) { DEBUG("An error has occurred while sending closing frame!\n"); return (-1); @@ -1155,7 +1160,7 @@ static int do_close(struct ws_frame_data *wfd, int close_code) static int do_pong(struct ws_frame_data *wfd, uint64_t frame_size) { if (ws_sendframe( - wfd->client->client_id, (const char *)wfd->msg_ctrl, frame_size, WS_FR_OP_PONG) < 0) + wfd->client->client_id, (const char *)wfd->msg_ctrl, frame_size, WS_FR_OP_PONG) < 0) { wfd->error = 1; DEBUG("An error has occurred while ponging!\n"); @@ -1237,13 +1242,13 @@ struct frame_state_data uint64_t frame_length; /* Frame length. */ uint64_t frame_size; /* Current frame size. */ #ifdef VALIDATE_UTF8 - uint32_t utf8_state; /* Current UTF-8 state. */ + uint32_t utf8_state; /* Current UTF-8 state. */ #endif - int32_t pong_id; /* Current PONG id. */ - uint8_t opcode; /* Frame opcode. */ - uint8_t is_fin; /* Is FIN frame flag. */ - uint8_t mask; /* Mask. */ - int cur_byte; /* Current frame byte. */ + int32_t pong_id; /* Current PONG id. */ + uint8_t opcode; /* Frame opcode. */ + uint8_t is_fin; /* Is FIN frame flag. */ + uint8_t mask; /* Mask. */ + int cur_byte; /* Current frame byte. */ }; /** @@ -1258,7 +1263,7 @@ struct frame_state_data * @attention This is part of the internal API and is documented just * for completeness. */ -static int validate_utf8_txt(struct ws_frame_data *wfd, +static int validate_utf8_txt(struct ws_frame_data *wfd, struct frame_state_data *fsd) { #ifdef VALIDATE_UTF8 @@ -1269,8 +1274,8 @@ static int validate_utf8_txt(struct ws_frame_data *wfd, if (fsd->is_fin) { if (is_utf8_len_state( - fsd->msg_data + (fsd->msg_idx_data - fsd->frame_length), - fsd->frame_length, fsd->utf8_state) != UTF8_ACCEPT) + fsd->msg_data + (fsd->msg_idx_data - fsd->frame_length), + fsd->frame_length, fsd->utf8_state) != UTF8_ACCEPT) { DEBUG("Dropping invalid complete message!\n"); wfd->error = 1; @@ -1282,8 +1287,7 @@ static int validate_utf8_txt(struct ws_frame_data *wfd, /* Check current state for a CONT or initial TXT frame. */ fsd->utf8_state = - is_utf8_len_state(fsd->msg_data + - (fsd->msg_idx_data - fsd->frame_length), + is_utf8_len_state(fsd->msg_data + (fsd->msg_idx_data - fsd->frame_length), fsd->frame_length, fsd->utf8_state); /* We can be in any state, except reject. */ @@ -1309,8 +1313,7 @@ static int validate_utf8_txt(struct ws_frame_data *wfd, * @attention This is part of the internal API and is documented just * for completeness. */ -static int handle_pong_frame(struct ws_frame_data *wfd, - struct frame_state_data *fsd) +static int handle_pong_frame(struct ws_frame_data *wfd, struct frame_state_data *fsd) { fsd->is_fin = 0; @@ -1348,8 +1351,7 @@ static int handle_pong_frame(struct ws_frame_data *wfd, * @attention This is part of the internal API and is documented just * for completeness. */ -static int handle_ping_frame(struct ws_frame_data *wfd, - struct frame_state_data *fsd) +static int handle_ping_frame(struct ws_frame_data *wfd, struct frame_state_data *fsd) { if (do_pong(wfd, fsd->frame_size) < 0) return (-1); @@ -1371,13 +1373,12 @@ static int handle_ping_frame(struct ws_frame_data *wfd, * @attention This is part of the internal API and is documented just * for completeness. */ -static int handle_close_frame(struct ws_frame_data *wfd, - struct frame_state_data *fsd) +static int handle_close_frame( + struct ws_frame_data *wfd, struct frame_state_data *fsd) { #ifdef VALIDATE_UTF8 /* If there is a close reason, check if it is UTF-8 valid. */ - if (fsd->frame_size > 2 && - !is_utf8_len(fsd->msg_ctrl + 2, fsd->frame_size - 2)) + if (fsd->frame_size > 2 && !is_utf8_len(fsd->msg_ctrl + 2, fsd->frame_size - 2)) { DEBUG("Invalid close frame payload reason! (not UTF-8)\n"); wfd->error = 1; @@ -1407,29 +1408,30 @@ static int handle_close_frame(struct ws_frame_data *wfd, * @attention This is part of the internal API and is documented just * for completeness. */ -static int read_single_frame(struct ws_frame_data *wfd, - struct frame_state_data *fsd) +static int read_single_frame(struct ws_frame_data *wfd, struct frame_state_data *fsd) { uint64_t *frame_size; /* Curr frame size. */ - unsigned char *tmp; /* Tmp message. */ - unsigned char *msg; /* Current message. */ - uint64_t *msg_idx; /* Message index. */ - uint8_t *masks; /* Current mask. */ - int cur_byte; /* Curr byte read. */ - uint64_t i; /* Loop index. */ + unsigned char *tmp; /* Tmp message. */ + unsigned char *msg; /* Current message. */ + uint64_t *msg_idx; /* Message index. */ + uint8_t *masks; /* Current mask. */ + int cur_byte; /* Curr byte read. */ + uint64_t i; /* Loop index. */ /* Decide which mask and msg to use. */ - if (is_control_frame(fsd->opcode)) { + if (is_control_frame(fsd->opcode)) + { frame_size = &fsd->frame_size; msg_idx = &fsd->msg_idx_ctrl; - masks = fsd->masks_ctrl; - msg = fsd->msg_ctrl; + masks = fsd->masks_ctrl; + msg = fsd->msg_ctrl; } - else { + else + { frame_size = &wfd->frame_size; msg_idx = &fsd->msg_idx_data; - masks = fsd->masks_data; - msg = fsd->msg_data; + masks = fsd->masks_data; + msg = fsd->msg_data; } /* Decode masks and length for 16-bit messages. */ @@ -1442,11 +1444,9 @@ static int read_single_frame(struct ws_frame_data *wfd, fsd->frame_length = (((uint64_t)next_byte(wfd)) << 56) | /* frame[2]. */ (((uint64_t)next_byte(wfd)) << 48) | /* frame[3]. */ - (((uint64_t)next_byte(wfd)) << 40) | - (((uint64_t)next_byte(wfd)) << 32) | - (((uint64_t)next_byte(wfd)) << 24) | - (((uint64_t)next_byte(wfd)) << 16) | - (((uint64_t)next_byte(wfd)) << 8) | + (((uint64_t)next_byte(wfd)) << 40) | (((uint64_t)next_byte(wfd)) << 32) | + (((uint64_t)next_byte(wfd)) << 24) | (((uint64_t)next_byte(wfd)) << 16) | + (((uint64_t)next_byte(wfd)) << 8) | (((uint64_t)next_byte(wfd))); /* frame[9]. */ } @@ -1572,7 +1572,7 @@ static int next_complete_frame(struct ws_frame_data *wfd) fsd.utf8_state = UTF8_ACCEPT; #endif - wfd->frame_size = 0; + wfd->frame_size = 0; wfd->frame_type = -1; wfd->msg = NULL; @@ -1643,17 +1643,16 @@ static int next_complete_frame(struct ws_frame_data *wfd) if (fsd.opcode != WS_FR_OP_CONT && !is_control_frame(fsd.opcode)) wfd->frame_type = fsd.opcode; - fsd.mask = next_byte(wfd); + fsd.mask = next_byte(wfd); fsd.frame_length = fsd.mask & 0x7F; - fsd.frame_size = 0; + fsd.frame_size = 0; fsd.msg_idx_ctrl = 0; /* * We should deny non-FIN control frames or that have * more than 125 octets. */ - if (is_control_frame(fsd.opcode) && - (!fsd.is_fin || fsd.frame_length > 125)) + if (is_control_frame(fsd.opcode) && (!fsd.is_fin || fsd.frame_length > 125)) { DEBUG("Control frame bigger than 125 octets or not a FIN " "frame!\n"); @@ -1669,7 +1668,8 @@ static int next_complete_frame(struct ws_frame_data *wfd) * Obs: If BIN, nothing should be done unless we got * a FIN-frame. */ - switch (fsd.opcode) { + switch (fsd.opcode) + { /* UTF-8 Validate partial (or not) frame. */ case WS_FR_OP_CONT: case WS_FR_OP_TXT: { @@ -1702,7 +1702,7 @@ static int next_complete_frame(struct ws_frame_data *wfd) } } -next_it:; + next_it:; } while (!fsd.is_fin && !wfd->error); @@ -1734,9 +1734,9 @@ next_it:; */ static void *ws_establishconnection(void *vclient) { - struct ws_frame_data wfd; /* WebSocket frame data. */ - struct ws_connection *client; /* Client structure. */ - int clse_thrd; /* Time-out close thread. */ + struct ws_frame_data wfd; /* WebSocket frame data. */ + struct ws_connection *client; /* Client structure. */ + int clse_thrd; /* Time-out close thread. */ client = vclient; @@ -1752,11 +1752,11 @@ static void *ws_establishconnection(void *vclient) while (next_complete_frame(&wfd) >= 0) { /* Text/binary event. */ - if ((wfd.frame_type == WS_FR_OP_TXT || - wfd.frame_type == WS_FR_OP_BIN) && !wfd.error) + if ((wfd.frame_type == WS_FR_OP_TXT || wfd.frame_type == WS_FR_OP_BIN) && + !wfd.error) { - client->ws_srv.evs.onmessage(client->client_id, wfd.msg, wfd.frame_size, - wfd.frame_type); + client->ws_srv.evs.onmessage( + client->client_id, wfd.msg, wfd.frame_size, wfd.frame_type); } /* Close event. */ @@ -1799,7 +1799,8 @@ static void *ws_establishconnection(void *vclient) } /* Close connectin properly. */ - if (get_client_state(client) != WS_STATE_CLOSED) { + if (get_client_state(client) != WS_STATE_CLOSED) + { DEBUG("Closing: normal close\n"); close_client(client, 1); } @@ -1816,6 +1817,42 @@ struct ws_accept_params struct ws_server ws_srv; }; +#ifdef ENABLE_OPENSSL +SSL_CTX *create_context() +{ + const SSL_METHOD *method; + SSL_CTX *ctx; + + method = TLS_server_method(); + + ctx = SSL_CTX_new(method); + if (!ctx) + { + perror("Unable to create SSL context"); + ERR_print_errors_fp(stderr); + exit(EXIT_FAILURE); + } + + return ctx; +} + +void configure_context(SSL_CTX *ctx) +{ + /* Set the key and cert */ + if (SSL_CTX_use_certificate_file(ctx, "cert/cert.pem", SSL_FILETYPE_PEM) <= 0) + { + ERR_print_errors_fp(stderr); + exit(EXIT_FAILURE); + } + + if (SSL_CTX_use_PrivateKey_file(ctx, "cert/key.pem", SSL_FILETYPE_PEM) <= 0) + { + ERR_print_errors_fp(stderr); + exit(EXIT_FAILURE); + } +} +#endif + /** * @brief Main loop that keeps accepting new connections. * @@ -1831,17 +1868,17 @@ struct ws_accept_params static void *ws_accept(void *data) { struct ws_accept_params *ws_prm; /* wsServer parameters. */ - struct sockaddr_storage sa; /* Client. */ - pthread_t client_thread; /* Client thread. */ - struct timeval time; /* Client socket timeout. */ - socklen_t salen; /* Length of sockaddr. */ - int new_sock; /* New opened connection. */ - int sock; /* Server sock. */ - int i; /* Loop index. */ + struct sockaddr_storage sa; /* Client. */ + pthread_t client_thread; /* Client thread. */ + struct timeval time; /* Client socket timeout. */ + socklen_t salen; /* Length of sockaddr. */ + int new_sock; /* New opened connection. */ + int sock; /* Server sock. */ + int i; /* Loop index. */ ws_prm = data; - sock = ws_prm->sock; - salen = sizeof(sa); + sock = ws_prm->sock; + salen = sizeof(sa); while (1) { @@ -1863,7 +1900,7 @@ static void *ws_accept(void *data) * See: * https://linux.die.net/man/3/setsockopt */ - setsockopt(new_sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&time, + setsockopt(new_sock, SOL_SOCKET, SO_SNDTIMEO, (const char *)&time, sizeof(struct timeval)); } @@ -1876,14 +1913,24 @@ static void *ws_accept(void *data) memcpy(&client_socks[i].ws_srv, &ws_prm->ws_srv, sizeof(struct ws_server)); - client_socks[i].client_sock = new_sock; - client_socks[i].state = WS_STATE_CONNECTING; - client_socks[i].close_thrd = false; + client_socks[i].client_sock = new_sock; + client_socks[i].state = WS_STATE_CONNECTING; + client_socks[i].close_thrd = false; client_socks[i].last_pong_id = -1; client_socks[i].current_ping_id = -1; client_socks[i].client_id = get_next_cid(); set_client_address(&client_socks[i]); +#ifdef ENABLE_OPENSSL + client_socks[i].ssl = SSL_new(ctx); + SSL_set_fd(client_socks[i].ssl, new_sock); + int rc = SSL_accept(client_socks[i].ssl); + if (rc == -1) + { + printf("[e] SSL\n"); + } +#endif + if (pthread_mutex_init(&client_socks[i].mtx_state, NULL)) panic("Error on allocating close mutex"); if (pthread_cond_init(&client_socks[i].cnd_state_close, NULL)) @@ -1948,15 +1995,14 @@ static int do_bind_socket(struct ws_server *ws_srv) for (try = results; try != NULL; try = try->ai_next) { /* try to make a socket with this setup */ - if ((sock = socket(try->ai_family, try->ai_socktype, - try->ai_protocol)) < 0) + if ((sock = socket(try->ai_family, try->ai_socktype, try->ai_protocol)) < 0) { continue; } /* Reuse previous address. */ if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char *)&reuse, - sizeof(reuse)) < 0) + sizeof(reuse)) < 0) { panic("setsockopt(SO_REUSEADDR) failed"); } @@ -1989,8 +2035,8 @@ static int do_bind_socket(struct ws_server *ws_srv) int ws_socket(struct ws_server *ws_srv) { struct ws_accept_params *ws_prm; /* Accept parameters. */ - pthread_t accept_thread; /* Accept thread. */ - int sock; /* Client sock. */ + pthread_t accept_thread; /* Accept thread. */ + int sock; /* Client sock. */ timeout = ws_srv->timeout_ms; @@ -2025,6 +2071,11 @@ int ws_socket(struct ws_server *ws_srv) /* Create socket and bind. */ sock = do_bind_socket(ws_srv); +#ifdef ENABLE_OPENSSL + ctx = create_context(); + configure_context(ctx); +#endif + /* Listen. */ if (listen(sock, MAX_CLIENTS) < 0) panic("Unable to listen!\n"); @@ -2045,6 +2096,10 @@ int ws_socket(struct ws_server *ws_srv) pthread_detach(accept_thread); } +#ifdef ENABLE_OPENSSL + SSL_CTX_free(ctx); +#endif + return (0); }