diff --git a/protocol/test/protocol_test.cpp b/protocol/test/protocol_test.cpp
index 36565dd832d..d98f0e8378c 100644
--- a/protocol/test/protocol_test.cpp
+++ b/protocol/test/protocol_test.cpp
@@ -1916,7 +1916,7 @@ TEST_F(ProtocolTest, QueueTest) {
     item = 1;
     EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
     item = 2;
-    EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
+    EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, -1), 0);
     EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
     EXPECT_EQ(item, 1);
     item = 3;
@@ -1924,16 +1924,16 @@ TEST_F(ProtocolTest, QueueTest) {
     item = 4;
     EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
     item = 5;
-    EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
+    EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, 50), 0);
     item = 6;
     EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), 0);
     item = 7;
-    EXPECT_EQ(fifo_queue_enqueue_item(fifo_queue, &item), -1);
+    EXPECT_EQ(fifo_queue_enqueue_item_timeout(fifo_queue, &item, 50), -1);
     EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
     EXPECT_EQ(item, 2);
     EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, 100), 0);
     EXPECT_EQ(item, 3);
-    EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, 100), 0);
+    EXPECT_EQ(fifo_queue_dequeue_item_timeout(fifo_queue, &item, -1), 0);
     EXPECT_EQ(item, 4);
     EXPECT_EQ(fifo_queue_dequeue_item(fifo_queue, &item), 0);
     EXPECT_EQ(item, 5);
diff --git a/protocol/whist/network/tcp.c b/protocol/whist/network/tcp.c
index 4d3c8d4a583..324ac12580b 100644
--- a/protocol/whist/network/tcp.c
+++ b/protocol/whist/network/tcp.c
@@ -15,6 +15,7 @@ Includes
 #include
 #include
 #include "whist/core/features.h"
+#include "whist/utils/queue.h"

 #if !OS_IS(OS_WIN32)
 #include
@@ -30,6 +31,11 @@ Defines
 // Currently set to the "large enough" 1GB
 #define MAX_TCP_PAYLOAD_SIZE 1000000000

+// How many packets to allow to be queued up on
+// a single TCP sending thread before queueing
+// up the next packet will block.
+#define TCP_SEND_QUEUE_SIZE 16
+
 typedef enum {
     TCP_PING,
     TCP_PONG,
@@ -88,8 +94,20 @@ typedef struct {
     // Only recvp every RECV_INTERVAL_MS, to keep CPU usage low.
     // This is because a recvp takes ~8ms sometimes
     WhistTimer last_recvp;
+
+    // TCP send is not atomic, so we have to hold packets in a queue and send on a separate thread
+    WhistThread send_thread;
+    QueueContext* send_queue;
+    WhistSemaphore send_semaphore;
+    bool run_sender;
 } TCPContext;

+// Struct for holding packets on queue
+typedef struct TCPQueueItem {
+    TCPNetworkPacket* packet;
+    int packet_size;
+} TCPQueueItem;
+
 // Time between consecutive pings
#define TCP_PING_INTERVAL_SEC 2.0
 // Time before a ping to be considered "lost", and reconnection starts
@@ -160,6 +178,19 @@ int create_tcp_client_context(TCPContext* context, char* destination, int port,
  */
 int tcp_send_constructed_packet(TCPContext* context, TCPPacket* packet);

+/**
+ * @brief    Multithreaded function to asynchronously
+ *           send all TCP packets for one socket context
+ *           on the same thread.
+ *           This prevents garbled TCP messages from
+ *           being sent since large TCP sends are not atomic.
+ *
+ * @param opaque  Pointer to associated socket context
+ *
+ * @returns       0 on exit
+ */
+int multithreaded_tcp_send(void* opaque);
+
 /**
  * @brief Returns the size, in bytes, of the relevant part of
  *        the TCPPacket, that must be sent over the network
@@ -492,6 +523,16 @@ static void tcp_destroy_socket_context(void* raw_context) {
     FATAL_ASSERT(raw_context != NULL);
     TCPContext* context = raw_context;

+    // Destroy TCP send queue resources
+    context->run_sender = false;
+    // Wake the sender thread so it can observe run_sender == false and exit,
+    // instead of blocking forever on its semaphore
+    whist_post_semaphore(context->send_semaphore);
+    // Any pending TCP packets will be dropped
+    whist_wait_thread(context->send_thread, NULL);
+    fifo_queue_destroy(context->send_queue);
+    whist_destroy_semaphore(context->send_semaphore);
+
     closesocket(context->socket);
     closesocket(context->listen_socket);
     whist_destroy_mutex(context->mutex);
@@ -560,6 +601,9 @@ bool create_tcp_socket_context(SocketContext* network_context, char* destination
     context->last_pong_id = -1;
     start_timer(&context->last_ping_timer);
     context->connection_lost = false;
+    context->send_queue = NULL;
+    context->send_semaphore = NULL;
+    context->send_thread = NULL;
     start_timer(&context->last_recvp);

     int ret;
@@ -578,6 +622,22 @@ bool create_tcp_socket_context(SocketContext* network_context, char* destination
         return false;
     }

+    // Set up TCP send queue
+    context->run_sender = true;
+    if ((context->send_queue = fifo_queue_create(sizeof(TCPQueueItem), TCP_SEND_QUEUE_SIZE)) ==
+            NULL ||
+        (context->send_semaphore = whist_create_semaphore(0)) == NULL ||
+        (context->send_thread = whist_create_thread(multithreaded_tcp_send,
+                                                    "multithreaded_tcp_send", context)) == NULL) {
+        // If any of the created resources are NULL, there was a failure and we need to clean up and
+        // return false
+        if (context->send_queue) fifo_queue_destroy(context->send_queue);
+        if (context->send_semaphore) whist_destroy_semaphore(context->send_semaphore);
+        free(context);
+        network_context->context = NULL;
+        return false;
+    }
+
     // Restore the original timeout
     set_timeout(context->socket, context->timeout);

@@ -763,33 +823,75 @@ int tcp_send_constructed_packet(TCPContext* context, TCPPacket* packet) {
         memcpy(network_packet->payload, packet, packet_size);
     }

-    int tcp_packet_size = get_tcp_network_packet_size(network_packet);
+    // Add TCPNetworkPacket to the queue to be sent on the TCP send thread
+    TCPQueueItem queue_item;
+    queue_item.packet = network_packet;
+    queue_item.packet_size = packet_size;
+    if (fifo_queue_enqueue_item_timeout(context->send_queue, &queue_item, -1) < 0) return -1;
+    whist_post_semaphore(context->send_semaphore);
+    return 0;
+}

-    // For now, the TCP network throttler is NULL, so this is a no-op.
-    network_throttler_wait_byte_allocation(context->network_throttler, tcp_packet_size);
+int multithreaded_tcp_send(void* opaque) {
+    TCPQueueItem queue_item;
+    TCPNetworkPacket* network_packet = NULL;
+    TCPContext* context = (TCPContext*)opaque;
+    while (true) {
+        whist_wait_semaphore(context->send_semaphore);
+        // Check to see if the sender thread needs to stop running
+        if (!context->run_sender) break;
+        // If connection is lost, then wait for up to TCP_PING_MAX_RECONNECTION_TIME_SEC
+        // before continuing.
+        if (context->connection_lost) {
+            // Need to re-increment semaphore because wait_semaphore at the top of the loop
+            // will have decremented semaphore for a packet we are not sending yet.
+            whist_post_semaphore(context->send_semaphore);
+            // If the wait for another packet times out, then we return to the top of the loop
+            if (!whist_wait_timeout_semaphore(context->send_semaphore,
+                                              TCP_PING_MAX_RECONNECTION_TIME_SEC * 1000))
+                continue;
+        }

-    // This is useful enough to print, even outside of LOG_NETWORKING GUARDS
-    LOG_INFO("Sending a WhistPacket of size %d (Total %d bytes), over TCP", packet_size,
-             tcp_packet_size);
+        // If there is no item to be dequeued, continue
+        if (fifo_queue_dequeue_item(context->send_queue, &queue_item) < 0) continue;

-    // Send the packet
-    bool failed = false;
-    int ret = send(context->socket, (const char*)network_packet, tcp_packet_size, 0);
-    if (ret < 0) {
-        int error = get_last_network_error();
-        if (error == WHIST_ECONNRESET) {
-            LOG_WARNING("TCP Connection reset by peer");
-            context->connection_lost = true;
-        } else {
-            LOG_WARNING("Unexpected TCP Packet Error: %d", error);
+        network_packet = queue_item.packet;
+
+        int tcp_packet_size = get_tcp_network_packet_size(network_packet);
+
+        // For now, the TCP network throttler is NULL, so this is a no-op.
+        network_throttler_wait_byte_allocation(context->network_throttler, tcp_packet_size);
+
+        // This is useful enough to print, even outside of LOG_NETWORKING GUARDS
+        LOG_INFO("Sending a WhistPacket of size %d (Total %d bytes), over TCP",
+                 queue_item.packet_size, tcp_packet_size);
+
+        // Send the packet. If a partial packet is sent, keep sending until full packet has been
+        // sent.
+        int total_sent = 0;
+        while (total_sent < tcp_packet_size) {
+            int ret = send(context->socket, (const char*)network_packet + total_sent,
+                           tcp_packet_size - total_sent, 0);
+            if (ret < 0) {
+                int error = get_last_network_error();
+                if (error == WHIST_ECONNRESET) {
+                    LOG_WARNING("TCP Connection reset by peer");
+                    context->connection_lost = true;
+                } else {
+                    LOG_WARNING("Unexpected TCP Packet Error: %d", error);
+                }
+                // Don't attempt to send the rest of the packet if there was a failure
+                break;
+            } else {
+                total_sent += ret;
+            }
         }
-        failed = true;
-    }

-    // Free the encrypted allocation
-    deallocate_region(network_packet);
+        // Free the encrypted allocation
+        deallocate_region(network_packet);
+    }

-    return failed ? -1 : 0;
+    return 0;
 }

 int get_tcp_packet_size(TCPPacket* tcp_packet) {
diff --git a/protocol/whist/network/tcp.h b/protocol/whist/network/tcp.h
index a4aeb48f02b..10821c5e958 100644
--- a/protocol/whist/network/tcp.h
+++ b/protocol/whist/network/tcp.h
@@ -54,7 +54,7 @@ bool create_tcp_socket_context(SocketContext* context, char* destination, int po
                                char* binary_aes_private_key);

 /**
- * @brief Creates a tcp listen socket, that can be used in SocketContext
+ * @brief Creates a tcp listen socket that can be used in SocketContext
  *
  * @param sock The socket that will be initialized
 * @param port The port to listen on
diff --git a/protocol/whist/utils/queue.c b/protocol/whist/utils/queue.c
index 6d161b6e1c8..34f0338bc20 100644
--- a/protocol/whist/utils/queue.c
+++ b/protocol/whist/utils/queue.c
@@ -14,8 +14,10 @@ typedef struct QueueContext {
     int num_items;
     int max_items;
     WhistMutex mutex;
-    WhistCondition cond;
+    WhistCondition avail_items_cond;
+    WhistCondition avail_space_cond;
     void *data;
+    bool destroying;
 } QueueContext;

 static void increment_idx(QueueContext *context, int *idx) {
@@ -29,6 +31,15 @@ static void dequeue_item(QueueContext *context, void *item) {
     void *source_item = (uint8_t *)context->data + (context->item_size * context->read_idx);
     memcpy(item, source_item, context->item_size);
     increment_idx(context, &context->read_idx);
+    whist_broadcast_cond(context->avail_space_cond);
+}
+
+static void enqueue_item(QueueContext *context, const void *item) {
+    context->num_items++;
+    void *target_item = (uint8_t *)context->data + (context->item_size * context->write_idx);
+    memcpy(target_item, item, context->item_size);
+    increment_idx(context, &context->write_idx);
+    whist_broadcast_cond(context->avail_items_cond);
 }

 QueueContext *fifo_queue_create(size_t item_size, int max_items) {
@@ -49,14 +60,21 @@ QueueContext *fifo_queue_create(size_t item_size, int max_items) {
         return NULL;
     }

-    context->cond = whist_create_cond();
-    if (context->cond == NULL) {
+    context->avail_items_cond = whist_create_cond();
+    if (context->avail_items_cond == NULL) {
+        fifo_queue_destroy(context);
+        return NULL;
+    }
+
+    context->avail_space_cond = whist_create_cond();
+    if (context->avail_space_cond == NULL) {
         fifo_queue_destroy(context);
         return NULL;
     }

     context->item_size = item_size;
     context->max_items = max_items;
+    context->destroying = false;

     return context;
 }

@@ -69,11 +87,39 @@ int fifo_queue_enqueue_item(QueueContext *context, const void *item) {
         whist_unlock_mutex(context->mutex);
         return -1;
     }
-    context->num_items++;
-    void *target_item = (uint8_t *)context->data + (context->item_size * context->write_idx);
-    memcpy(target_item, item, context->item_size);
-    increment_idx(context, &context->write_idx);
-    whist_broadcast_cond(context->cond);
+    enqueue_item(context, item);
     whist_unlock_mutex(context->mutex);
     return 0;
 }

+int fifo_queue_enqueue_item_timeout(QueueContext *context, const void *item, int timeout_ms) {
+    if (context == NULL) {
+        return -1;
+    }
+    WhistTimer timer;
+    start_timer(&timer);
+    int current_timeout_ms = timeout_ms;
+    whist_lock_mutex(context->mutex);
+    while (context->num_items >= context->max_items) {
+        if (context->destroying) {
+            whist_unlock_mutex(context->mutex);
+            return -1;
+        }
+        if (timeout_ms >= 0) {
+            bool res =
+                whist_timedwait_cond(context->avail_space_cond, context->mutex, current_timeout_ms);
+            if (res == false) {  // In case of a timeout simply exit
+                whist_unlock_mutex(context->mutex);
+                return -1;
+            }
+            int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
+            current_timeout_ms = max(timeout_ms - elapsed_ms, 0);
+        } else {
+            // Negative timeout_ms indicates block until available, not timeout
+            whist_wait_cond(context->avail_space_cond, context->mutex);
+        }
+    }
+    enqueue_item(context, item);
+    whist_unlock_mutex(context->mutex);
+    return 0;
+}
@@ -101,13 +147,23 @@ int fifo_queue_dequeue_item_timeout(QueueContext *context, void *item, int timeo
     int current_timeout_ms = timeout_ms;
     whist_lock_mutex(context->mutex);
     while (context->num_items <= 0) {
-        bool res = whist_timedwait_cond(context->cond, context->mutex, current_timeout_ms);
-        if (res == false) {  // In case of a timeout simply exit
+        if (context->destroying) {
             whist_unlock_mutex(context->mutex);
             return -1;
         }
-        int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
-        current_timeout_ms = max(timeout_ms - elapsed_ms, 0);
+        if (timeout_ms >= 0) {
+            bool res =
+                whist_timedwait_cond(context->avail_items_cond, context->mutex, current_timeout_ms);
+            if (res == false) {  // In case of a timeout simply exit
+                whist_unlock_mutex(context->mutex);
+                return -1;
+            }
+            int elapsed_ms = (int)(get_timer(&timer) * MS_IN_SECOND);
+            current_timeout_ms = max(timeout_ms - elapsed_ms, 0);
+        } else {
+            // Negative timeout_ms indicates block until available, not timeout
+            whist_wait_cond(context->avail_items_cond, context->mutex);
+        }
     }
     dequeue_item(context, item);
     whist_unlock_mutex(context->mutex);
@@ -118,14 +174,27 @@ void fifo_queue_destroy(QueueContext *context) {
     if (context == NULL) {
         return;
     }
+
+    // Make sure that all blocking calls release
+    context->destroying = true;
+    if (context->avail_items_cond != NULL) {
+        whist_broadcast_cond(context->avail_items_cond);
+    }
+    if (context->avail_space_cond != NULL) {
+        whist_broadcast_cond(context->avail_space_cond);
+    }
+
     if (context->data != NULL) {
         free(context->data);
     }
     if (context->mutex != NULL) {
         whist_destroy_mutex(context->mutex);
     }
-    if (context->cond != NULL) {
-        whist_destroy_cond(context->cond);
+    if (context->avail_items_cond != NULL) {
+        whist_destroy_cond(context->avail_items_cond);
+    }
+    if (context->avail_space_cond != NULL) {
+        whist_destroy_cond(context->avail_space_cond);
     }
     free(context);
 }
diff --git a/protocol/whist/utils/queue.h b/protocol/whist/utils/queue.h
index 55c05c46daf..ce6d1d7494c 100644
--- a/protocol/whist/utils/queue.h
+++ b/protocol/whist/utils/queue.h
@@ -19,15 +19,29 @@ typedef struct QueueContext QueueContext;
 QueueContext *fifo_queue_create(size_t item_size, int max_items);

 /**
- * @brief Enqueue an item to the FIFO queue
+ * @brief Enqueue an item to the FIFO queue (nonblocking). If the queue is full,
+ *        then return immediately without any waiting.
  *
  * @param queue_context Queue's context pointer
  * @param item Pointer to the item that needs to be enqueued
  *
- * @returns 0 on success, -1 on failure
+ * @returns 0 on success, -1 if the queue is full or on another failure
  */
 int fifo_queue_enqueue_item(QueueContext *queue_context, const void *item);

+/**
+ * @brief Enqueue an item to the FIFO queue. If the queue is full, then wait
+ *        until the timeout expires.
+ *
+ * @param queue_context Queue's context pointer
+ * @param item Pointer to the item that needs to be enqueued
+ * @param timeout_ms The number of milliseconds to wait for. -1 to wait
+ *        without a timeout.
+ *
+ * @returns 0 on success, -1 on failure or if the wait times out
+ */
+int fifo_queue_enqueue_item_timeout(QueueContext *queue_context, const void *item, int timeout_ms);
+
 /**
  * @brief Dequeue an item from the FIFO queue. If an item is not available,
  *        then return immediately without any waiting.
@@ -35,7 +49,7 @@ int fifo_queue_enqueue_item(QueueContext *queue_context, const void *item);
  * @param queue_context Queue's context pointer
  * @param item Pointer to the memory where dequeued item will be stored
  *
- * @returns 0 on success, -1 on failure
+ * @returns 0 on success, -1 if the queue is empty or on another failure
  */
 int fifo_queue_dequeue_item(QueueContext *queue_context, void *item);

@@ -45,7 +59,8 @@ int fifo_queue_dequeue_item(QueueContext *queue_context, void *item);
  *
  * @param queue_context Queue's context pointer
  * @param item Pointer to the memory where dequeued item will be stored
- * @param timeout_ms The number of milliseconds to wait for.
+ * @param timeout_ms The number of milliseconds to wait for. -1 to wait
+ *        without a timeout.
  *
- * @returns 0 on success, -1 on failure
+ * @returns 0 on success, -1 on failure or if the wait times out
  */
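
For reviewers trying out the new API, here is a minimal usage sketch of the timeout variants exercised by QueueTest above. It assumes only the declarations in queue.h; the queue capacity of 2 and the payload value are arbitrary.

#include "whist/utils/queue.h"

static void timeout_api_example(void) {
    QueueContext *queue = fifo_queue_create(sizeof(int), 2);
    int item = 42;

    // A negative timeout blocks until space is available.
    fifo_queue_enqueue_item_timeout(queue, &item, -1);

    // Wait up to 100 ms for an item; returns -1 if the wait times out.
    int out;
    if (fifo_queue_dequeue_item_timeout(queue, &out, 100) == 0) {
        // out now holds 42
    }

    fifo_queue_destroy(queue);
}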
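The partial-send loop in multithreaded_tcp_send is the heart of the correctness fix. A standalone sketch of the same pattern against plain POSIX send(), with hypothetical names, shows why both the buffer cursor and the remaining length must be adjusted by the number of bytes already sent:

#include <sys/socket.h>

// Returns 0 once all len bytes are sent, -1 on a socket error.
static int send_all(int sockfd, const char *buf, int len) {
    int total_sent = 0;
    while (total_sent < len) {
        // send() may accept fewer bytes than requested, so resume from
        // buf + total_sent with the shrinking remainder.
        int ret = (int)send(sockfd, buf + total_sent, len - total_sent, 0);
        if (ret < 0) {
            return -1;  // caller inspects errno and handles reconnection
        }
        total_sent += ret;
    }
    return 0;
}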
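The sender thread pairs the bounded queue with a counting semaphore: every enqueue posts one token, so the semaphore count mirrors the number of queued packets, and shutdown only needs a single extra post after clearing the run flag. A condensed sketch of that handoff, using the whist primitives that appear in this diff but otherwise hypothetical names:

#include "whist/utils/queue.h"

typedef struct {
    QueueContext *queue;  // bounded FIFO of work items
    WhistSemaphore sem;   // counts queued items (plus one shutdown token)
    bool run_flag;        // cleared before the shutdown post
} Handoff;

static int consumer_thread(void *opaque) {
    Handoff *h = (Handoff *)opaque;
    int item;
    while (true) {
        whist_wait_semaphore(h->sem);  // one token per queued item
        if (!h->run_flag) break;       // woken by the shutdown post
        if (fifo_queue_dequeue_item(h->queue, &item) < 0) continue;
        // ... process item ...
    }
    return 0;
}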