From c11dcf3c1d7d34012aeee8e29f2e3120b9816d53 Mon Sep 17 00:00:00 2001 From: Andrey Semashev Date: Fri, 25 Sep 2020 19:33:26 +0300 Subject: [PATCH] Changed stream list to arrays and optimized stream lookup with SSE2. The stream list in the SRTP context is now implemented with two arrays: an array of SSRCs and an array of pointers to the streams corresponding to the SSRCs. The streams no longer form a single linked list. Stream lookup by SSRC is now performed over the array of SSRCs, which is considerably faster because it is more cache-friendly. Additionally, the lookup is optimized for SSE2, which provides an additional massive speedup with many streams in the list. Although the lookup still has linear complexity, its absolute times are reduced and with tens to hundreds elements are lower or comparable with a typical rb-tree equivalent. Expected speedup of SSE2 version over the previous implementation: SSRCs speedup (scalar) speedup (SSE2) 1 0.39x 0.22x 3 0.57x 0.23x 5 0.69x 0.62x 10 0.77x 1.43x 20 0.86x 2.38x 30 0.87x 3.44x 50 1.13x 6.21x 100 1.25x 8.51x 200 1.30x 9.83x At small numbers of SSRCs the new algorithm is somewhat slower, but given that the absolute and relative times of the lookup are very small, that slowdown is not very significant. --- include/srtp_priv.h | 17 +- srtp/srtp.c | 474 ++++++++++++++++++++++++++++++++------------ test/srtp_driver.c | 7 +- 3 files changed, 366 insertions(+), 132 deletions(-) diff --git a/include/srtp_priv.h b/include/srtp_priv.h index 48dc65c7d..f30f5bbac 100644 --- a/include/srtp_priv.h +++ b/include/srtp_priv.h @@ -149,14 +149,27 @@ typedef struct srtp_stream_ctx_t_ { int *enc_xtn_hdr; int enc_xtn_hdr_count; uint32_t pending_roc; - struct srtp_stream_ctx_t_ *next; /* linked list of streams */ } strp_stream_ctx_t_; +/* + * An srtp_stream_list_t is a list of streams searchable by SSRC. + * + * Pointers to streams and their respective SSRCs are stored in two arrays, + * where the stream pointer and the SSRC at the same index correspond to each + * other. + */ +typedef struct srtp_stream_list_t { + srtp_stream_ctx_t **streams; + uint32_t *ssrcs; + uint32_t size; + uint32_t capacity; +} srtp_stream_list_t; + /* * an srtp_ctx_t holds a stream list and a service description */ typedef struct srtp_ctx_t_ { - struct srtp_stream_ctx_t_ *stream_list; /* linked list of streams */ + srtp_stream_list_t stream_list; /* list of streams */ struct srtp_stream_ctx_t_ *stream_template; /* act as template for other */ /* streams */ void *user_data; /* user custom data */ diff --git a/srtp/srtp.c b/srtp/srtp.c index dbb099095..81f29429a 100644 --- a/srtp/srtp.c +++ b/srtp/srtp.c @@ -59,6 +59,7 @@ #include "aes_icm_ext.h" #endif +#include #include #ifdef HAVE_NETINET_IN_H #include @@ -66,6 +67,13 @@ #include #endif +#if defined(__SSE2__) +#include +#if defined(_MSC_VER) +#include +#endif +#endif + /* the debug module for srtp */ srtp_debug_module_t mod_srtp = { 0, /* debugging is off by default */ @@ -549,11 +557,246 @@ srtp_err_status_t srtp_stream_clone(const srtp_stream_ctx_t *stream_template, str->enc_xtn_hdr = stream_template->enc_xtn_hdr; str->enc_xtn_hdr_count = stream_template->enc_xtn_hdr_count; - /* defensive coding */ - str->next = NULL; return srtp_err_status_ok; } +/* + * Initializes an empty list of streams + */ +static inline void srtp_stream_list_init(srtp_stream_list_t *streams) +{ + memset(streams, 0, sizeof(*streams)); +} + +/* + * Returns an index of the stream corresponding to ssrc, + * or >= streams->size if no stream exists for that ssrc. + */ +static uint32_t srtp_stream_list_find(const srtp_stream_list_t *streams, + uint32_t ssrc) +{ +#if defined(__SSE2__) + const uint32_t *const ssrcs = streams->ssrcs; + const __m128i mm_ssrc = _mm_set1_epi32(ssrc); + uint32_t pos = 0u, n = (streams->size + 7u) & ~(uint32_t)(7u); + for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + __m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u)); + __m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc); + mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + mm3 = _mm_packs_epi32(mm3, mm4); + mm1 = _mm_packs_epi16(mm1, mm3); + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos; +#else + pos += __builtin_ctz(mask); +#endif + + goto done; + } + } + + if (pos < n) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos / 2u; +#else + pos += __builtin_ctz(mask) / 2u; +#endif + goto done; + } + + pos += 8u; + } + +done: + return pos; +#else + /* walk down list until ssrc is found */ + uint32_t pos = 0, n = streams->size; + for (; pos < n; ++pos) { + if (streams->ssrcs[pos] == ssrc) + break; + } + + return pos; +#endif +} + +/* + * Reserves storage to be able to store at least the specified number + * of elements. + */ +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t *streams, + uint32_t new_capacity) +{ + if (new_capacity > streams->capacity) { + uint32_t *ssrcs; + srtp_stream_ctx_t **stream_ptrs; + + if (new_capacity > (UINT32_MAX - 15u)) + return srtp_err_status_alloc_fail; + + new_capacity = (new_capacity + 15u) & ~((uint32_t)15u); + + ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity * + sizeof(uint32_t)); + if (!ssrcs) + return srtp_err_status_alloc_fail; + stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc( + (size_t)new_capacity * sizeof(srtp_stream_ctx_t *)); + if (!stream_ptrs) { + srtp_crypto_free(ssrcs); + return srtp_err_status_alloc_fail; + } + + if (streams->size > 0u) { + memcpy(ssrcs, streams->ssrcs, + (size_t)streams->size * sizeof(uint32_t)); + memcpy(stream_ptrs, streams->streams, + (size_t)streams->size * sizeof(srtp_stream_ctx_t *)); + } + + srtp_crypto_free(streams->ssrcs); + srtp_crypto_free(streams->streams); + streams->streams = stream_ptrs; + streams->ssrcs = ssrcs; + + streams->capacity = new_capacity; + } + + return srtp_err_status_ok; +} + +/* + * Inserts a new stream at the end of the list. The newly added stream + * and its SSRC must not be already present in the list. + */ +static srtp_err_status_t srtp_stream_list_push_back(srtp_stream_list_t *streams, + srtp_stream_ctx_t *stream) +{ + uint32_t pos; + srtp_err_status_t status = + srtp_stream_list_reserve(streams, streams->size + 1u); + if (status) + return status; + pos = streams->size++; + streams->ssrcs[pos] = stream->ssrc; + streams->streams[pos] = stream; + + return srtp_err_status_ok; +} + +/* + * Erases an element at the given position and deallocates the stream. + */ +static srtp_err_status_t srtp_stream_list_erase( + srtp_stream_list_t *streams, + uint32_t pos, + const srtp_stream_ctx_t *stream_template) +{ + uint32_t tail_size, last_pos; + srtp_err_status_t status = + srtp_stream_dealloc(streams->streams[pos], stream_template); + if (status) + return status; + + last_pos = --streams->size; + tail_size = last_pos - pos; + if (tail_size > 0u) { + memmove(streams->streams + pos, streams->streams + pos + 1, + (size_t)tail_size * sizeof(*streams->streams)); + memmove(streams->ssrcs + pos, streams->ssrcs + pos + 1, + (size_t)tail_size * sizeof(*streams->ssrcs)); + } + + streams->streams[last_pos] = NULL; + streams->ssrcs[last_pos] = 0u; + + return srtp_err_status_ok; +} + +/* + * Clears the list of streams + */ +static srtp_err_status_t srtp_stream_list_clear( + srtp_stream_list_t *streams, + const srtp_stream_ctx_t *stream_template) +{ + srtp_err_status_t status = srtp_err_status_ok; + uint32_t pos = 0u, n = streams->size; + uint32_t count_left; + + /* + * we take a conservative deallocation strategy - if we encounter an + * error deallocating a stream, then we stop trying to deallocate + * memory and just return an error + */ + for (; pos < n; ++pos) { + status = srtp_stream_dealloc(streams->streams[pos], stream_template); + if (status) + break; + } + + count_left = n - pos; + if (count_left) { + /* move the elements we failed to deallocate to the beginning of the + * list */ + memmove(streams->streams, streams->streams + pos, + (size_t)count_left * sizeof(*streams->streams)); + memmove(streams->ssrcs, streams->ssrcs + pos, + (size_t)count_left * sizeof(*streams->ssrcs)); + } + + memset(streams->streams + count_left, 0, + (size_t)pos * sizeof(*streams->streams)); + memset(streams->ssrcs + count_left, 0, + (size_t)pos * sizeof(*streams->ssrcs)); + streams->size = count_left; + + return status; +} + +/* + * Clears and deallocates memory allocated for the list of streams. + * Does not deallocate the list structure itself. + */ +static srtp_err_status_t srtp_stream_list_destroy( + srtp_stream_list_t *streams, + const srtp_stream_ctx_t *stream_template) +{ + srtp_err_status_t status = srtp_stream_list_clear(streams, stream_template); + + if (status == srtp_err_status_ok && streams->capacity > 0u) { + srtp_crypto_free(streams->streams); + streams->streams = NULL; + srtp_crypto_free(streams->ssrcs); + streams->ssrcs = NULL; + streams->size = 0u; + streams->capacity = 0u; + } + + return status; +} + /* * key derivation functions, internal to libSRTP * @@ -2045,9 +2288,12 @@ static srtp_err_status_t srtp_unprotect_aead(srtp_ctx_t *ctx, return status; } - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -2136,9 +2382,12 @@ srtp_err_status_t srtp_protect_mki(srtp_ctx_t *ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set direction to outbound */ new_stream->direction = dir_srtp_sender; @@ -2735,9 +2984,12 @@ srtp_err_status_t srtp_unprotect_mki(srtp_ctx_t *ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -2804,40 +3056,19 @@ srtp_err_status_t srtp_shutdown() srtp_stream_ctx_t *srtp_get_stream(srtp_t srtp, uint32_t ssrc) { - srtp_stream_ctx_t *stream; - - /* walk down list until ssrc is found */ - stream = srtp->stream_list; - while (stream != NULL) { - if (stream->ssrc == ssrc) - return stream; - stream = stream->next; - } - - /* we haven't found our ssrc, so return a null */ - return NULL; + srtp_stream_ctx_t *stream = NULL; + uint32_t pos = srtp_stream_list_find(&srtp->stream_list, ssrc); + if (pos < srtp->stream_list.size) + stream = srtp->stream_list.streams[pos]; + return stream; } srtp_err_status_t srtp_dealloc(srtp_t session) { - srtp_stream_ctx_t *stream; - srtp_err_status_t status; - - /* - * we take a conservative deallocation strategy - if we encounter an - * error deallocating a stream, then we stop trying to deallocate - * memory and just return an error - */ - - /* walk list of streams, deallocating as we go */ - stream = session->stream_list; - while (stream != NULL) { - srtp_stream_t next = stream->next; - status = srtp_stream_dealloc(stream, session->stream_template); - if (status) - return status; - stream = next; - } + srtp_err_status_t status = srtp_stream_list_destroy( + &session->stream_list, session->stream_template); + if (status) + return status; /* deallocate stream template, if there is one */ if (session->stream_template != NULL) { @@ -2906,8 +3137,11 @@ srtp_err_status_t srtp_add_stream(srtp_t session, const srtp_policy_t *policy) session->stream_template->direction = dir_srtp_receiver; break; case (ssrc_specific): - tmp->next = session->stream_list; - session->stream_list = tmp; + status = srtp_stream_list_push_back(&session->stream_list, tmp); + if (status) { + srtp_stream_dealloc(tmp, NULL); + return status; + } break; case (ssrc_undefined): default: @@ -2943,8 +3177,8 @@ srtp_err_status_t srtp_create(srtp_t *session, /* handle for session */ * loop over elements in the policy list, allocating and * initializing a stream for each element */ + srtp_stream_list_init(&ctx->stream_list); ctx->stream_template = NULL; - ctx->stream_list = NULL; ctx->user_data = NULL; while (policy != NULL) { stat = srtp_add_stream(ctx, policy); @@ -2964,35 +3198,20 @@ srtp_err_status_t srtp_create(srtp_t *session, /* handle for session */ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc) { - srtp_stream_ctx_t *stream, *last_stream; - srtp_err_status_t status; + uint32_t pos; /* sanity check arguments */ if (session == NULL) return srtp_err_status_bad_param; /* find stream in list; complain if not found */ - last_stream = stream = session->stream_list; - while ((stream != NULL) && (ssrc != stream->ssrc)) { - last_stream = stream; - stream = stream->next; - } - if (stream == NULL) + pos = srtp_stream_list_find(&session->stream_list, ssrc); + if (pos >= session->stream_list.size) return srtp_err_status_no_ctx; /* remove stream from the list */ - if (last_stream == stream) - /* stream was first in list */ - session->stream_list = stream->next; - else - last_stream->next = stream->next; - - /* deallocate the stream */ - status = srtp_stream_dealloc(stream, session->stream_template); - if (status) - return status; - - return srtp_err_status_ok; + return srtp_stream_list_erase(&session->stream_list, pos, + session->stream_template); } srtp_err_status_t srtp_update(srtp_t session, const srtp_policy_t *policy) @@ -3027,7 +3246,7 @@ static srtp_err_status_t update_template_streams(srtp_t session, { srtp_err_status_t status; srtp_stream_t new_stream_template; - srtp_stream_t new_stream_list = NULL; + srtp_stream_list_t new_stream_list; status = srtp_valid_policy(policy); if (status != srtp_err_status_ok) { @@ -3051,76 +3270,74 @@ static srtp_err_status_t update_template_streams(srtp_t session, return status; } + srtp_stream_list_init(&new_stream_list); + /* for all old templated streams */ - for (;;) { - srtp_stream_t stream; + for (uint32_t pos = 0u, n = session->stream_list.size; pos < n; ++pos) { + srtp_stream_t old_stream, new_stream; uint32_t ssrc; srtp_xtd_seq_num_t old_index; srtp_rdb_t old_rtcp_rdb; - stream = session->stream_list; - while ((stream != NULL) && - (stream->session_keys[0].rtp_auth != - session->stream_template->session_keys[0].rtp_auth)) { - stream = stream->next; - } - if (stream == NULL) { - /* no more templated streams */ - break; + old_stream = session->stream_list.streams[pos]; + if (old_stream->session_keys[0].rtp_auth != + session->stream_template->session_keys[0].rtp_auth) { + continue; } - /* save old extendard seq */ - ssrc = stream->ssrc; - old_index = stream->rtp_rdbx.index; - old_rtcp_rdb = stream->rtcp_rdb; + /* save old extended seq */ + ssrc = old_stream->ssrc; + old_index = old_stream->rtp_rdbx.index; + old_rtcp_rdb = old_stream->rtcp_rdb; - /* remove stream */ - status = srtp_remove_stream(session, ssrc); + /* allocate and initialize a new stream */ + status = srtp_stream_clone(new_stream_template, ssrc, &new_stream); if (status) { - /* free new allocations */ - while (new_stream_list != NULL) { - srtp_stream_t next = new_stream_list->next; - srtp_stream_dealloc(new_stream_list, new_stream_template); - new_stream_list = next; - } + srtp_stream_list_destroy(&new_stream_list, new_stream_template); srtp_stream_dealloc(new_stream_template, NULL); return status; } - /* allocate and initialize a new stream */ - status = srtp_stream_clone(new_stream_template, ssrc, &stream); + /* restore old extended seq */ + new_stream->rtp_rdbx.index = old_index; + new_stream->rtcp_rdb = old_rtcp_rdb; + + /* insert into the new stream list */ + status = srtp_stream_list_push_back(&new_stream_list, new_stream); if (status) { - /* free new allocations */ - while (new_stream_list != NULL) { - srtp_stream_t next = new_stream_list->next; - srtp_stream_dealloc(new_stream_list, new_stream_template); - new_stream_list = next; - } + srtp_stream_list_destroy(&new_stream_list, new_stream_template); srtp_stream_dealloc(new_stream_template, NULL); return status; } - /* add new stream to the head of the new_stream_list */ - stream->next = new_stream_list; - new_stream_list = stream; + /* + * Repurpose the ssrc array to store positions of the stream + * in the original list to speed up merging the updated streams + * into the original list. Otherwise we would have to perform + * lookup by ssrc for each updated stream. + */ + new_stream_list.ssrcs[new_stream_list.size - 1u] = pos; + } - /* restore old extended seq */ - stream->rtp_rdbx.index = old_index; - stream->rtcp_rdb = old_rtcp_rdb; + /* swap updated streams with the ones in the original list */ + for (uint32_t pos = 0u, n = new_stream_list.size; pos < n; ++pos) { + srtp_stream_t old_stream, new_stream; + uint32_t old_pos = new_stream_list.ssrcs[pos]; + + new_stream = new_stream_list.streams[pos]; + old_stream = session->stream_list.streams[old_pos]; + session->stream_list.streams[old_pos] = new_stream; + new_stream_list.streams[pos] = old_stream; + new_stream_list.ssrcs[pos] = old_stream->ssrc; } + + /* dealloc old streams */ + srtp_stream_list_destroy(&new_stream_list, session->stream_template); /* dealloc old template */ srtp_stream_dealloc(session->stream_template, NULL); /* set new template */ session->stream_template = new_stream_template; - /* add new list */ - if (new_stream_list) { - srtp_stream_t tail = new_stream_list; - while (tail->next) { - tail = tail->next; - } - tail->next = session->stream_list; - session->stream_list = new_stream_list; - } + return status; } @@ -3888,9 +4105,12 @@ static srtp_err_status_t srtp_unprotect_rtcp_aead( return status; } - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -3954,9 +4174,12 @@ srtp_err_status_t srtp_protect_rtcp_mki(srtp_t ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -4401,9 +4624,12 @@ srtp_err_status_t srtp_unprotect_rtcp_mki(srtp_t ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -4596,7 +4822,7 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, return srtp_err_status_bad_param; } - if (session->stream_template == NULL && session->stream_list == NULL) { + if (session->stream_template == NULL && session->stream_list.size == 0u) { return srtp_err_status_bad_param; } @@ -4609,10 +4835,9 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, length); } - stream = session->stream_list; - - while (stream != NULL) { + for (uint32_t pos = 0u, n = session->stream_list.size; pos < n; ++pos) { uint32_t temp_length; + stream = session->stream_list.streams[pos]; if (stream_get_protect_trailer_length(stream, is_rtp, use_mki, mki_index, &temp_length) == srtp_err_status_ok) { @@ -4620,7 +4845,6 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, *length = temp_length; } } - stream = stream->next; } return srtp_err_status_ok; diff --git a/test/srtp_driver.c b/test/srtp_driver.c index db6d3f48e..030446761 100644 --- a/test/srtp_driver.c +++ b/test/srtp_driver.c @@ -1517,8 +1517,8 @@ srtp_err_status_t srtp_session_print_policy(srtp_t srtp) } /* loop over streams in session, printing the policy of each */ - stream = srtp->stream_list; - while (stream != NULL) { + for (uint32_t pos = 0u, n = srtp->stream_list.size; pos < n; ++pos) { + stream = srtp->stream_list.streams[pos]; if (stream->rtp_services > sec_serv_conf_and_auth) { return srtp_err_status_bad_param; } @@ -1555,9 +1555,6 @@ srtp_err_status_t srtp_session_print_policy(srtp_t srtp) } else { printf("none\n"); } - - /* advance to next stream in the list */ - stream = stream->next; } return srtp_err_status_ok; }