diff --git a/tests/testlib/s2n_ktls_test_utils.c b/tests/testlib/s2n_ktls_test_utils.c index 08c1a060d41..3dc84dbc363 100644 --- a/tests/testlib/s2n_ktls_test_utils.c +++ b/tests/testlib/s2n_ktls_test_utils.c @@ -203,8 +203,8 @@ S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_pair_free(struct s2n_test_ktls_io_stuffer return S2N_RESULT_OK; } -S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, uint8_t *expected_data, - uint16_t expected_len) +S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, + const uint8_t *expected_data, uint16_t expected_len) { RESULT_ENSURE_REF(ktls_io); RESULT_ENSURE_REF(expected_data); diff --git a/tests/testlib/s2n_ktls_test_utils.h b/tests/testlib/s2n_ktls_test_utils.h index 255d7efc873..9c69b4c85df 100644 --- a/tests/testlib/s2n_ktls_test_utils.h +++ b/tests/testlib/s2n_ktls_test_utils.h @@ -68,8 +68,8 @@ S2N_RESULT s2n_test_init_ktls_io_stuffer(struct s2n_connection *server, struct s2n_connection *client, struct s2n_test_ktls_io_stuffer_pair *io_pair); S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_free(struct s2n_test_ktls_io_stuffer *io); S2N_CLEANUP_RESULT s2n_ktls_io_stuffer_pair_free(struct s2n_test_ktls_io_stuffer_pair *pair); -S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, uint8_t *expected_data, - uint16_t expected_len); +S2N_RESULT s2n_test_validate_data(struct s2n_test_ktls_io_stuffer *ktls_io, + const uint8_t *expected_data, uint16_t expected_len); S2N_RESULT s2n_test_validate_ancillary(struct s2n_test_ktls_io_stuffer *ktls_io, uint8_t expected_record_type, uint16_t expected_len); S2N_RESULT s2n_test_records_in_ancillary(struct s2n_test_ktls_io_stuffer *ktls_io, diff --git a/tests/unit/s2n_ktls_io_test.c b/tests/unit/s2n_ktls_io_test.c index 57a29896474..3f437d92e4d 100644 --- a/tests/unit/s2n_ktls_io_test.c +++ b/tests/unit/s2n_ktls_io_test.c @@ -212,23 +212,22 @@ int main(int argc, char **argv) { /* Safety */ { - DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER), - s2n_connection_ptr_free); + struct s2n_test_ktls_io_stuffer ctx = { 0 }; struct iovec msg_iov_valid = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; EXPECT_ERROR_WITH_ERRNO( s2n_ktls_sendmsg(NULL, test_record_type, &msg_iov_valid, 1, &blocked, &bytes_written), - S2N_ERR_NULL); + S2N_ERR_IO); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, NULL, 1, &blocked, &bytes_written), + s2n_ktls_sendmsg(&ctx, test_record_type, NULL, 1, &blocked, &bytes_written), S2N_ERR_NULL); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov_valid, 1, NULL, &bytes_written), + s2n_ktls_sendmsg(&ctx, test_record_type, &msg_iov_valid, 1, NULL, &bytes_written), S2N_ERR_NULL); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov_valid, 1, &blocked, NULL), + s2n_ktls_sendmsg(&ctx, test_record_type, &msg_iov_valid, 1, &blocked, NULL), S2N_ERR_NULL); }; @@ -243,7 +242,8 @@ int main(int argc, char **argv) struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); @@ -272,7 +272,8 @@ int main(int argc, char **argv) s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, msg_iov, S2N_TEST_MSG_IOVLEN, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + msg_iov, S2N_TEST_MSG_IOVLEN, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, total_sent); EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); @@ -302,7 +303,8 @@ int main(int argc, char **argv) size_t bytes_written = 0; for (size_t i = 0; i < blocked_invoked_count; i++) { EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written), + s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_WRITE); } @@ -310,7 +312,8 @@ int main(int argc, char **argv) /* enable growable to unblock write */ /* cppcheck-suppress redundantAssignment */ client_in.data_buffer.growable = true; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); /* confirm sent data */ @@ -333,13 +336,15 @@ int main(int argc, char **argv) io_ctx.errno_code = EWOULDBLOCK; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written), + s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written), S2N_ERR_IO_BLOCKED); /* cppcheck-suppress redundantAssignment */ io_ctx.errno_code = EAGAIN; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written), + s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(io_ctx.invoked_count, 2); @@ -358,7 +363,8 @@ int main(int argc, char **argv) s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written), + s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written), S2N_ERR_IO); /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_WRITE); @@ -379,12 +385,14 @@ int main(int argc, char **argv) size_t bytes_written = 0; size_t iovlen_zero = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, iovlen_zero, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, iovlen_zero, &blocked, &bytes_written)); EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); EXPECT_EQUAL(bytes_written, 0); struct iovec msg_iov_len_zero = { .iov_base = test_data, .iov_len = 0 }; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov_len_zero, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov_len_zero, 1, &blocked, &bytes_written)); EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); EXPECT_EQUAL(bytes_written, 0); @@ -396,8 +404,7 @@ int main(int argc, char **argv) { /* Safety */ { - DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT), - s2n_connection_ptr_free); + struct s2n_test_ktls_io_stuffer ctx = { 0 }; uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; uint8_t recv_record_type = 0; @@ -405,23 +412,23 @@ int main(int argc, char **argv) EXPECT_ERROR_WITH_ERRNO( s2n_ktls_recvmsg(NULL, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), - S2N_ERR_NULL); + S2N_ERR_IO); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, NULL, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(&ctx, NULL, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_NULL); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, NULL, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(&ctx, &recv_record_type, NULL, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_NULL); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, NULL, &bytes_read), + s2n_ktls_recvmsg(&ctx, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, NULL, &bytes_read), S2N_ERR_NULL); EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, NULL), + s2n_ktls_recvmsg(&ctx, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, NULL), S2N_ERR_NULL); size_t to_recv_zero = 0; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, to_recv_zero, &blocked, &bytes_read), + s2n_ktls_recvmsg(&ctx, &recv_record_type, recv_buf, to_recv_zero, &blocked, &bytes_read), S2N_ERR_SAFETY); }; @@ -438,13 +445,15 @@ int main(int argc, char **argv) struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; uint8_t recv_record_type = 0; size_t bytes_read = 0; - EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); + EXPECT_OK(s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read); EXPECT_EQUAL(bytes_read, bytes_written); @@ -470,7 +479,8 @@ int main(int argc, char **argv) /* recv should block since there is no data */ for (size_t i = 0; i < blocked_invoked_count; i++) { EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); } @@ -478,17 +488,20 @@ int main(int argc, char **argv) /* send data to unblock */ struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; size_t bytes_written = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); - EXPECT_OK(s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); + EXPECT_OK(s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read)); EXPECT_BYTEARRAY_EQUAL(test_data, recv_buf, bytes_read); EXPECT_EQUAL(bytes_read, bytes_written); /* recv should block again since we have read all the data */ for (size_t i = 0; i < blocked_invoked_count; i++) { EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); } @@ -511,14 +524,16 @@ int main(int argc, char **argv) io_ctx.errno_code = EWOULDBLOCK; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); /* cppcheck-suppress redundantAssignment */ io_ctx.errno_code = EAGAIN; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_IO_BLOCKED); EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); @@ -539,7 +554,8 @@ int main(int argc, char **argv) uint8_t recv_record_type = 0; size_t bytes_read = 0; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_IO); /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); @@ -559,7 +575,8 @@ int main(int argc, char **argv) uint8_t recv_record_type = 0; size_t bytes_read = 0; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_CLOSED); /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); @@ -582,14 +599,16 @@ int main(int argc, char **argv) struct iovec msg_iov = { .iov_base = test_data, .iov_len = S2N_TEST_TO_SEND }; s2n_blocked_status blocked = S2N_NOT_BLOCKED; size_t bytes_written = 0; - EXPECT_OK(s2n_ktls_sendmsg(server, test_record_type, &msg_iov, 1, &blocked, &bytes_written)); + EXPECT_OK(s2n_ktls_sendmsg(server->send_io_context, test_record_type, + &msg_iov, 1, &blocked, &bytes_written)); EXPECT_EQUAL(bytes_written, S2N_TEST_TO_SEND); uint8_t recv_buf[S2N_TLS_MAXIMUM_FRAGMENT_LENGTH] = { 0 }; uint8_t recv_record_type = 0; size_t bytes_read = 0; EXPECT_ERROR_WITH_ERRNO( - s2n_ktls_recvmsg(client, &recv_record_type, recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), + s2n_ktls_recvmsg(client->recv_io_context, &recv_record_type, + recv_buf, S2N_TEST_TO_SEND, &blocked, &bytes_read), S2N_ERR_KTLS_BAD_CMSG); /* Blocked status intentionally not reset to preserve legacy s2n_send behavior */ EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_READ); @@ -858,5 +877,106 @@ int main(int argc, char **argv) }; }; + /* Test: s2n_ktls_send_cb */ + { + /* It's safe to reuse a connection across tests because the connection + * isn't actually used by s2n_ktls_send_cb. It's just required for test + * setup methods. + */ + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + + /* Safety */ + { + struct s2n_test_ktls_io_stuffer ctx = { 0 }; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(NULL, test_data, 1), S2N_ERR_IO); + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(&ctx, NULL, 1), S2N_ERR_IO); + }; + + /* Test: Basic write succeeds */ + { + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer ctx = { 0 }, + s2n_ktls_io_stuffer_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer_send(conn, &ctx)); + + EXPECT_SUCCESS(s2n_ktls_send_cb(&ctx, test_data, sizeof(test_data))); + EXPECT_EQUAL(ctx.sendmsg_invoked_count, 1); + EXPECT_OK(s2n_test_validate_ancillary(&ctx, TLS_ALERT, sizeof(test_data))); + EXPECT_OK(s2n_test_validate_data(&ctx, test_data, sizeof(test_data))); + }; + + /* Test: Errors passed on to caller */ + { + struct s2n_test_ktls_io_fail_ctx ctx = { 0 }; + EXPECT_OK(s2n_ktls_set_sendmsg_cb(conn, s2n_test_ktls_sendmsg_fail, &ctx)); + + ctx.errno_code = 1; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(&ctx, test_data, sizeof(test_data)), + S2N_ERR_IO); + + ctx.errno_code = EINVAL; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(&ctx, test_data, sizeof(test_data)), + S2N_ERR_IO); + + ctx.errno_code = EAGAIN; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(&ctx, test_data, sizeof(test_data)), + S2N_ERR_IO_BLOCKED); + + ctx.errno_code = EWOULDBLOCK; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_send_cb(&ctx, test_data, sizeof(test_data)), + S2N_ERR_IO_BLOCKED); + }; + }; + + /* Test: s2n_ktls_record_writev */ + { + const size_t to_write = 10; + + /* Safety */ + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + struct iovec iov = { 0 }; + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_record_writev(NULL, 0, &iov, 1, 1, 1), + S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_record_writev(conn, 0, NULL, 1, 1, 1), + S2N_ERR_NULL); + EXPECT_FAILURE_WITH_ERRNO(s2n_ktls_record_writev(conn, 0, &iov, -1, 1, 1), + S2N_ERR_INVALID_ARGUMENT); + }; + + /* Test: Basic write succeeds */ + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + struct iovec iov = { + .iov_base = test_data, + .iov_len = sizeof(test_data), + }; + EXPECT_EQUAL(s2n_ktls_record_writev(conn, TLS_ALERT, &iov, 1, 0, to_write), to_write); + EXPECT_EQUAL(conn->out.blob.allocated, to_write); + EXPECT_EQUAL(s2n_stuffer_data_available(&conn->out), to_write); + uint8_t *in_out = s2n_stuffer_raw_read(&conn->out, to_write); + EXPECT_BYTEARRAY_EQUAL(in_out, test_data, to_write); + }; + + /* Test: Only alerts currently supported */ + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + struct iovec iov = { + .iov_base = test_data, + .iov_len = sizeof(test_data), + }; + EXPECT_FAILURE_WITH_ERRNO( + s2n_ktls_record_writev(conn, TLS_HANDSHAKE, &iov, 1, 0, to_write), + S2N_ERR_UNIMPLEMENTED); + }; + }; + END_TEST(); } diff --git a/tests/unit/s2n_shutdown_test.c b/tests/unit/s2n_shutdown_test.c index 91efa34020a..f5e16a37897 100644 --- a/tests/unit/s2n_shutdown_test.c +++ b/tests/unit/s2n_shutdown_test.c @@ -16,8 +16,10 @@ #include "tls/s2n_shutdown.c" #include "s2n_test.h" +#include "testlib/s2n_ktls_test_utils.h" #include "testlib/s2n_testlib.h" #include "tls/s2n_alerts.h" +#include "utils/s2n_socket.h" #define ALERT_LEN (sizeof(uint16_t)) @@ -614,7 +616,93 @@ int main(int argc, char **argv) EXPECT_TRUE(s2n_connection_check_io_status(conn, S2N_IO_CLOSED)); EXPECT_FALSE(s2n_atomic_flag_test(&conn->close_notify_received)); }; - } + + /* Test: kTLS enabled */ + { + /* Test: Successfully send alert */ + { + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + EXPECT_OK(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND)); + + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer out = { 0 }, + s2n_ktls_io_stuffer_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer_send(conn, &out)); + + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked)); + EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); + EXPECT_TRUE(conn->alert_sent); + EXPECT_EQUAL(out.sendmsg_invoked_count, 1); + EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, S2N_ALERT_LENGTH)); + EXPECT_OK(s2n_test_validate_data(&out, + close_notify_alert, sizeof(close_notify_alert))); + + /* Repeating the shutdown does not resend the alert */ + for (size_t i = 0; i < 5; i++) { + EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked)); + EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); + EXPECT_TRUE(conn->alert_sent); + EXPECT_EQUAL(out.sendmsg_invoked_count, 1); + } + }; + + /* Test: Successfully send alert after blocking */ + { + /* One call does the partial write, the second blocks */ + const size_t partial_write = 1; + const size_t second_write = sizeof(close_notify_alert) - partial_write; + EXPECT_TRUE(second_write > 0); + + DEFER_CLEANUP(struct s2n_connection *conn = s2n_connection_new(S2N_SERVER), + s2n_connection_ptr_free); + EXPECT_NOT_NULL(conn); + EXPECT_OK(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND)); + + DEFER_CLEANUP(struct s2n_test_ktls_io_stuffer out = { 0 }, + s2n_ktls_io_stuffer_free); + EXPECT_OK(s2n_test_init_ktls_io_stuffer_send(conn, &out)); + EXPECT_SUCCESS(s2n_stuffer_free(&out.data_buffer)); + EXPECT_SUCCESS(s2n_stuffer_alloc(&out.data_buffer, partial_write)); + + /* One call does the partial write, the second blocks */ + size_t expected_calls = 2; + + /* Initial shutdown blocks */ + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + EXPECT_FAILURE_WITH_ERRNO(s2n_shutdown_send(conn, &blocked), + S2N_ERR_IO_BLOCKED); + EXPECT_EQUAL(blocked, S2N_BLOCKED_ON_WRITE); + EXPECT_TRUE(conn->alert_sent); + EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls); + EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, partial_write)); + EXPECT_OK(s2n_test_validate_data(&out, close_notify_alert, partial_write)); + + /* Unblock the output stuffer */ + out.data_buffer.growable = true; + expected_calls++; + EXPECT_SUCCESS(s2n_stuffer_wipe(&out.ancillary_buffer)); + + /* Second shutdown succeeds */ + EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked)); + EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); + EXPECT_TRUE(conn->alert_sent); + EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls); + EXPECT_OK(s2n_test_validate_ancillary(&out, TLS_ALERT, second_write)); + EXPECT_OK(s2n_test_validate_data(&out, close_notify_alert, + sizeof(close_notify_alert))); + + /* Repeating the shutdown does not resend the alert */ + for (size_t i = 0; i < 5; i++) { + EXPECT_SUCCESS(s2n_shutdown_send(conn, &blocked)); + EXPECT_EQUAL(blocked, S2N_NOT_BLOCKED); + EXPECT_TRUE(conn->alert_sent); + EXPECT_EQUAL(out.sendmsg_invoked_count, expected_calls); + } + }; + }; + }; END_TEST(); } diff --git a/tls/s2n_connection.c b/tls/s2n_connection.c index 787d1a42280..ac56df34cea 100644 --- a/tls/s2n_connection.c +++ b/tls/s2n_connection.c @@ -853,11 +853,17 @@ int s2n_connection_use_corked_io(struct s2n_connection *conn) uint64_t s2n_connection_get_wire_bytes_in(struct s2n_connection *conn) { + if (conn->ktls_recv_enabled) { + return 0; + } return conn->wire_bytes_in; } uint64_t s2n_connection_get_wire_bytes_out(struct s2n_connection *conn) { + if (conn->ktls_send_enabled) { + return 0; + } return conn->wire_bytes_out; } diff --git a/tls/s2n_ktls.c b/tls/s2n_ktls.c index a1d775d65f5..7f2a06d138f 100644 --- a/tls/s2n_ktls.c +++ b/tls/s2n_ktls.c @@ -41,11 +41,6 @@ static int s2n_ktls_disabled_read(void *io_context, uint8_t *buf, uint32_t len) POSIX_BAIL(S2N_ERR_IO); } -static int s2n_ktls_disabled_write(void *io_context, const uint8_t *buf, uint32_t len) -{ - POSIX_BAIL(S2N_ERR_IO); -} - static S2N_RESULT s2n_ktls_validate(struct s2n_connection *conn, s2n_ktls_mode ktls_mode) { RESULT_ENSURE_REF(conn); @@ -244,6 +239,18 @@ static S2N_RESULT s2n_ktls_configure_socket(struct s2n_connection *conn, s2n_ktl return S2N_RESULT_OK; } +S2N_RESULT s2n_ktls_configure_connection(struct s2n_connection *conn, s2n_ktls_mode ktls_mode) +{ + if (ktls_mode == S2N_KTLS_MODE_SEND) { + conn->ktls_send_enabled = true; + conn->send = s2n_ktls_send_cb; + } else { + conn->ktls_recv_enabled = true; + conn->recv = s2n_ktls_disabled_read; + } + return S2N_RESULT_OK; +} + /* * Since kTLS is an optimization, it is possible to continue operation * by using userspace TLS if kTLS is not supported. @@ -265,10 +272,7 @@ int s2n_connection_ktls_enable_send(struct s2n_connection *conn) } POSIX_GUARD_RESULT(s2n_ktls_configure_socket(conn, S2N_KTLS_MODE_SEND)); - - conn->ktls_send_enabled = true; - /* kTLS now handles I/O for the connection */ - conn->send = s2n_ktls_disabled_write; + POSIX_GUARD_RESULT(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_SEND)); return S2N_SUCCESS; } @@ -287,10 +291,7 @@ int s2n_connection_ktls_enable_recv(struct s2n_connection *conn) } POSIX_GUARD_RESULT(s2n_ktls_configure_socket(conn, S2N_KTLS_MODE_RECV)); - - conn->ktls_recv_enabled = true; - /* kTLS now handles I/O for the connection */ - conn->recv = s2n_ktls_disabled_read; + POSIX_GUARD_RESULT(s2n_ktls_configure_connection(conn, S2N_KTLS_MODE_RECV)); return S2N_SUCCESS; } diff --git a/tls/s2n_ktls.h b/tls/s2n_ktls.h index 6bf4e90dfb3..342681f78a3 100644 --- a/tls/s2n_ktls.h +++ b/tls/s2n_ktls.h @@ -39,13 +39,16 @@ typedef enum { bool s2n_ktls_is_supported_on_platform(); S2N_RESULT s2n_ktls_get_file_descriptor(struct s2n_connection *conn, s2n_ktls_mode ktls_mode, int *fd); -S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov, +int s2n_ktls_send_cb(void *io_context, const uint8_t *buf, uint32_t len); +S2N_RESULT s2n_ktls_sendmsg(void *io_context, uint8_t record_type, const struct iovec *msg_iov, size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written); -S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf, +S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf, size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read); ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iovec *bufs, ssize_t count, ssize_t offs, s2n_blocked_status *blocked); +int s2n_ktls_record_writev(struct s2n_connection *conn, uint8_t content_type, + const struct iovec *in, int in_count, size_t offs, size_t to_write); /* These functions will be part of the public API. */ int s2n_connection_ktls_enable_send(struct s2n_connection *conn); @@ -61,3 +64,4 @@ S2N_RESULT s2n_ktls_set_sendmsg_cb(struct s2n_connection *conn, s2n_ktls_sendmsg void *send_ctx); S2N_RESULT s2n_ktls_set_recvmsg_cb(struct s2n_connection *conn, s2n_ktls_recvmsg_fn recv_cb, void *recv_ctx); +S2N_RESULT s2n_ktls_configure_connection(struct s2n_connection *conn, s2n_ktls_mode ktls_mode); diff --git a/tls/s2n_ktls_io.c b/tls/s2n_ktls_io.c index 3b0603030a6..a7272d68596 100644 --- a/tls/s2n_ktls_io.c +++ b/tls/s2n_ktls_io.c @@ -183,12 +183,11 @@ S2N_RESULT s2n_ktls_get_control_data(struct msghdr *msg, int cmsg_type, uint8_t return S2N_RESULT_OK; } -S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, const struct iovec *msg_iov, +S2N_RESULT s2n_ktls_sendmsg(void *io_context, uint8_t record_type, const struct iovec *msg_iov, size_t msg_iovlen, s2n_blocked_status *blocked, size_t *bytes_written) { RESULT_ENSURE_REF(bytes_written); RESULT_ENSURE_REF(blocked); - RESULT_ENSURE_REF(conn); RESULT_ENSURE(msg_iov != NULL || msg_iovlen == 0, S2N_ERR_NULL); *blocked = S2N_BLOCKED_ON_WRITE; @@ -206,7 +205,7 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co RESULT_GUARD(s2n_ktls_set_control_data(&msg, control_data, sizeof(control_data), S2N_TLS_SET_RECORD_TYPE, record_type)); - ssize_t result = s2n_sendmsg_fn(conn->send_io_context, &msg); + ssize_t result = s2n_sendmsg_fn(io_context, &msg); if (result < 0) { if (errno == EWOULDBLOCK || errno == EAGAIN) { RESULT_BAIL(S2N_ERR_IO_BLOCKED); @@ -219,13 +218,12 @@ S2N_RESULT s2n_ktls_sendmsg(struct s2n_connection *conn, uint8_t record_type, co return S2N_RESULT_OK; } -S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, uint8_t *buf, +S2N_RESULT s2n_ktls_recvmsg(void *io_context, uint8_t *record_type, uint8_t *buf, size_t buf_len, s2n_blocked_status *blocked, size_t *bytes_read) { RESULT_ENSURE_REF(record_type); RESULT_ENSURE_REF(bytes_read); RESULT_ENSURE_REF(blocked); - RESULT_ENSURE_REF(conn); RESULT_ENSURE_REF(buf); /* Ensure that buf_len is > 0 since trying to receive 0 bytes does not * make sense and a return value of `0` from recvmsg is treated as EOF. @@ -254,7 +252,7 @@ S2N_RESULT s2n_ktls_recvmsg(struct s2n_connection *conn, uint8_t *record_type, u msg.msg_controllen = sizeof(control_data); msg.msg_control = control_data; - ssize_t result = s2n_recvmsg_fn(conn->recv_io_context, &msg); + ssize_t result = s2n_recvmsg_fn(io_context, &msg); if (result < 0) { if (errno == EWOULDBLOCK || errno == EAGAIN) { RESULT_BAIL(S2N_ERR_IO_BLOCKED); @@ -304,6 +302,7 @@ static S2N_RESULT s2n_ktls_new_iovecs_with_offset(const struct iovec *bufs, ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iovec *bufs, ssize_t count_in, ssize_t offs_in, s2n_blocked_status *blocked) { + POSIX_ENSURE_REF(conn); POSIX_ENSURE(count_in >= 0, S2N_ERR_INVALID_ARGUMENT); size_t count = count_in; POSIX_ENSURE(offs_in >= 0, S2N_ERR_INVALID_ARGUMENT); @@ -319,7 +318,55 @@ ssize_t s2n_ktls_sendv_with_offset(struct s2n_connection *conn, const struct iov } size_t bytes_written = 0; - POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn, TLS_APPLICATION_DATA, bufs, count, - blocked, &bytes_written)); + POSIX_GUARD_RESULT(s2n_ktls_sendmsg(conn->send_io_context, TLS_APPLICATION_DATA, + bufs, count, blocked, &bytes_written)); + return bytes_written; +} + +int s2n_ktls_send_cb(void *io_context, const uint8_t *buf, uint32_t len) +{ + /* For now, all control records are assumed to be alerts. + * We can set the record_type on the io_context in the future. + */ + const uint8_t record_type = TLS_ALERT; + + const struct iovec iov = { + .iov_base = (void *) (uintptr_t) buf, + .iov_len = len, + }; + s2n_blocked_status blocked = S2N_NOT_BLOCKED; + size_t bytes_written = 0; + + POSIX_GUARD_RESULT(s2n_ktls_sendmsg(io_context, record_type, &iov, 1, + &blocked, &bytes_written)); + + POSIX_ENSURE_LTE(bytes_written, len); return bytes_written; } + +int s2n_ktls_record_writev(struct s2n_connection *conn, uint8_t content_type, + const struct iovec *in, int in_count, size_t offs, size_t to_write) +{ + POSIX_ENSURE_REF(conn); + POSIX_ENSURE(in_count > 0, S2N_ERR_INVALID_ARGUMENT); + size_t count = in_count; + POSIX_ENSURE_REF(in); + + /* Currently, ktls only supports sending alerts. + * To also support handshake messages, we would need a way to track record_type. + * We could add a field to the send io context. + */ + POSIX_ENSURE(content_type == TLS_ALERT, S2N_ERR_UNIMPLEMENTED); + + /* When stuffers automatically resize, they allocate a potentially large + * chunk of memory to avoid repeated resizes. + * Since ktls only uses conn->out for control messages (alerts and eventually + * handshake messages), we expect infrequent small writes with conn->out + * freed in between. Since we're therefore more concerned with the size of + * the allocation than the frequency, use a more accurate size for each write. + */ + POSIX_GUARD(s2n_stuffer_resize_if_empty(&conn->out, to_write)); + + POSIX_GUARD(s2n_stuffer_writev_bytes(&conn->out, in, count, offs, to_write)); + return to_write; +} diff --git a/tls/s2n_record_write.c b/tls/s2n_record_write.c index 9a3ed93fd3e..dd115a3ca2b 100644 --- a/tls/s2n_record_write.c +++ b/tls/s2n_record_write.c @@ -24,6 +24,7 @@ #include "tls/s2n_cipher_suites.h" #include "tls/s2n_connection.h" #include "tls/s2n_crypto.h" +#include "tls/s2n_ktls.h" #include "tls/s2n_record.h" #include "utils/s2n_blob.h" #include "utils/s2n_random.h" @@ -247,6 +248,10 @@ static inline int s2n_record_encrypt( int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const struct iovec *in, int in_count, size_t offs, size_t to_write) { + if (conn->ktls_send_enabled) { + return s2n_ktls_record_writev(conn, content_type, in, in_count, offs, to_write); + } + struct s2n_blob iv = { 0 }; uint8_t padding = 0; uint16_t block_size = 0; diff --git a/tls/s2n_send.c b/tls/s2n_send.c index fdd40023fa3..3cf071ac194 100644 --- a/tls/s2n_send.c +++ b/tls/s2n_send.c @@ -115,13 +115,13 @@ ssize_t s2n_sendv_with_offset_impl(struct s2n_connection *conn, const struct iov POSIX_ENSURE(s2n_connection_check_io_status(conn, S2N_IO_WRITABLE), S2N_ERR_CLOSED); POSIX_ENSURE(!s2n_connection_is_quic_enabled(conn), S2N_ERR_UNSUPPORTED_WITH_QUIC); + /* Flush any pending I/O */ + POSIX_GUARD(s2n_flush(conn, blocked)); + if (conn->ktls_send_enabled) { return s2n_ktls_sendv_with_offset(conn, bufs, count, offs, blocked); } - /* Flush any pending I/O */ - POSIX_GUARD(s2n_flush(conn, blocked)); - /* Acknowledge consumed and flushed user data as sent */ user_data_sent = conn->current_user_data_consumed;