diff --git a/include/srtp_priv.h b/include/srtp_priv.h index 48dc65c7d..f30f5bbac 100644 --- a/include/srtp_priv.h +++ b/include/srtp_priv.h @@ -149,14 +149,27 @@ typedef struct srtp_stream_ctx_t_ { int *enc_xtn_hdr; int enc_xtn_hdr_count; uint32_t pending_roc; - struct srtp_stream_ctx_t_ *next; /* linked list of streams */ } strp_stream_ctx_t_; +/* + * An srtp_stream_list_t is a list of streams searchable by SSRC. + * + * Pointers to streams and their respective SSRCs are stored in two arrays, + * where the stream pointer and the SSRC at the same index correspond to each + * other. + */ +typedef struct srtp_stream_list_t { + srtp_stream_ctx_t **streams; + uint32_t *ssrcs; + uint32_t size; + uint32_t capacity; +} srtp_stream_list_t; + /* * an srtp_ctx_t holds a stream list and a service description */ typedef struct srtp_ctx_t_ { - struct srtp_stream_ctx_t_ *stream_list; /* linked list of streams */ + srtp_stream_list_t stream_list; /* list of streams */ struct srtp_stream_ctx_t_ *stream_template; /* act as template for other */ /* streams */ void *user_data; /* user custom data */ diff --git a/srtp/srtp.c b/srtp/srtp.c index dbb099095..81f29429a 100644 --- a/srtp/srtp.c +++ b/srtp/srtp.c @@ -59,6 +59,7 @@ #include "aes_icm_ext.h" #endif +#include #include #ifdef HAVE_NETINET_IN_H #include @@ -66,6 +67,13 @@ #include #endif +#if defined(__SSE2__) +#include +#if defined(_MSC_VER) +#include +#endif +#endif + /* the debug module for srtp */ srtp_debug_module_t mod_srtp = { 0, /* debugging is off by default */ @@ -549,11 +557,246 @@ srtp_err_status_t srtp_stream_clone(const srtp_stream_ctx_t *stream_template, str->enc_xtn_hdr = stream_template->enc_xtn_hdr; str->enc_xtn_hdr_count = stream_template->enc_xtn_hdr_count; - /* defensive coding */ - str->next = NULL; return srtp_err_status_ok; } +/* + * Initializes an empty list of streams + */ +static inline void srtp_stream_list_init(srtp_stream_list_t *streams) +{ + memset(streams, 0, sizeof(*streams)); +} + +/* + * Returns an index of the stream corresponding to ssrc, + * or >= streams->size if no stream exists for that ssrc. + */ +static uint32_t srtp_stream_list_find(const srtp_stream_list_t *streams, + uint32_t ssrc) +{ +#if defined(__SSE2__) + const uint32_t *const ssrcs = streams->ssrcs; + const __m128i mm_ssrc = _mm_set1_epi32(ssrc); + uint32_t pos = 0u, n = (streams->size + 7u) & ~(uint32_t)(7u); + for (uint32_t m = n & ~(uint32_t)(15u); pos < m; pos += 16u) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + __m128i mm3 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 8u)); + __m128i mm4 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 12u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm3 = _mm_cmpeq_epi32(mm3, mm_ssrc); + mm4 = _mm_cmpeq_epi32(mm4, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + mm3 = _mm_packs_epi32(mm3, mm4); + mm1 = _mm_packs_epi16(mm1, mm3); + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos; +#else + pos += __builtin_ctz(mask); +#endif + + goto done; + } + } + + if (pos < n) { + __m128i mm1 = _mm_loadu_si128((const __m128i *)(ssrcs + pos)); + __m128i mm2 = _mm_loadu_si128((const __m128i *)(ssrcs + pos + 4u)); + mm1 = _mm_cmpeq_epi32(mm1, mm_ssrc); + mm2 = _mm_cmpeq_epi32(mm2, mm_ssrc); + mm1 = _mm_packs_epi32(mm1, mm2); + + uint32_t mask = _mm_movemask_epi8(mm1); + if (mask) { +#if defined(_MSC_VER) + unsigned long bit_pos; + _BitScanForward(&bit_pos, mask); + pos += bit_pos / 2u; +#else + pos += __builtin_ctz(mask) / 2u; +#endif + goto done; + } + + pos += 8u; + } + +done: + return pos; +#else + /* walk down list until ssrc is found */ + uint32_t pos = 0, n = streams->size; + for (; pos < n; ++pos) { + if (streams->ssrcs[pos] == ssrc) + break; + } + + return pos; +#endif +} + +/* + * Reserves storage to be able to store at least the specified number + * of elements. + */ +static srtp_err_status_t srtp_stream_list_reserve(srtp_stream_list_t *streams, + uint32_t new_capacity) +{ + if (new_capacity > streams->capacity) { + uint32_t *ssrcs; + srtp_stream_ctx_t **stream_ptrs; + + if (new_capacity > (UINT32_MAX - 15u)) + return srtp_err_status_alloc_fail; + + new_capacity = (new_capacity + 15u) & ~((uint32_t)15u); + + ssrcs = (uint32_t *)srtp_crypto_alloc((size_t)new_capacity * + sizeof(uint32_t)); + if (!ssrcs) + return srtp_err_status_alloc_fail; + stream_ptrs = (srtp_stream_ctx_t **)srtp_crypto_alloc( + (size_t)new_capacity * sizeof(srtp_stream_ctx_t *)); + if (!stream_ptrs) { + srtp_crypto_free(ssrcs); + return srtp_err_status_alloc_fail; + } + + if (streams->size > 0u) { + memcpy(ssrcs, streams->ssrcs, + (size_t)streams->size * sizeof(uint32_t)); + memcpy(stream_ptrs, streams->streams, + (size_t)streams->size * sizeof(srtp_stream_ctx_t *)); + } + + srtp_crypto_free(streams->ssrcs); + srtp_crypto_free(streams->streams); + streams->streams = stream_ptrs; + streams->ssrcs = ssrcs; + + streams->capacity = new_capacity; + } + + return srtp_err_status_ok; +} + +/* + * Inserts a new stream at the end of the list. The newly added stream + * and its SSRC must not be already present in the list. + */ +static srtp_err_status_t srtp_stream_list_push_back(srtp_stream_list_t *streams, + srtp_stream_ctx_t *stream) +{ + uint32_t pos; + srtp_err_status_t status = + srtp_stream_list_reserve(streams, streams->size + 1u); + if (status) + return status; + pos = streams->size++; + streams->ssrcs[pos] = stream->ssrc; + streams->streams[pos] = stream; + + return srtp_err_status_ok; +} + +/* + * Erases an element at the given position and deallocates the stream. + */ +static srtp_err_status_t srtp_stream_list_erase( + srtp_stream_list_t *streams, + uint32_t pos, + const srtp_stream_ctx_t *stream_template) +{ + uint32_t tail_size, last_pos; + srtp_err_status_t status = + srtp_stream_dealloc(streams->streams[pos], stream_template); + if (status) + return status; + + last_pos = --streams->size; + tail_size = last_pos - pos; + if (tail_size > 0u) { + memmove(streams->streams + pos, streams->streams + pos + 1, + (size_t)tail_size * sizeof(*streams->streams)); + memmove(streams->ssrcs + pos, streams->ssrcs + pos + 1, + (size_t)tail_size * sizeof(*streams->ssrcs)); + } + + streams->streams[last_pos] = NULL; + streams->ssrcs[last_pos] = 0u; + + return srtp_err_status_ok; +} + +/* + * Clears the list of streams + */ +static srtp_err_status_t srtp_stream_list_clear( + srtp_stream_list_t *streams, + const srtp_stream_ctx_t *stream_template) +{ + srtp_err_status_t status = srtp_err_status_ok; + uint32_t pos = 0u, n = streams->size; + uint32_t count_left; + + /* + * we take a conservative deallocation strategy - if we encounter an + * error deallocating a stream, then we stop trying to deallocate + * memory and just return an error + */ + for (; pos < n; ++pos) { + status = srtp_stream_dealloc(streams->streams[pos], stream_template); + if (status) + break; + } + + count_left = n - pos; + if (count_left) { + /* move the elements we failed to deallocate to the beginning of the + * list */ + memmove(streams->streams, streams->streams + pos, + (size_t)count_left * sizeof(*streams->streams)); + memmove(streams->ssrcs, streams->ssrcs + pos, + (size_t)count_left * sizeof(*streams->ssrcs)); + } + + memset(streams->streams + count_left, 0, + (size_t)pos * sizeof(*streams->streams)); + memset(streams->ssrcs + count_left, 0, + (size_t)pos * sizeof(*streams->ssrcs)); + streams->size = count_left; + + return status; +} + +/* + * Clears and deallocates memory allocated for the list of streams. + * Does not deallocate the list structure itself. + */ +static srtp_err_status_t srtp_stream_list_destroy( + srtp_stream_list_t *streams, + const srtp_stream_ctx_t *stream_template) +{ + srtp_err_status_t status = srtp_stream_list_clear(streams, stream_template); + + if (status == srtp_err_status_ok && streams->capacity > 0u) { + srtp_crypto_free(streams->streams); + streams->streams = NULL; + srtp_crypto_free(streams->ssrcs); + streams->ssrcs = NULL; + streams->size = 0u; + streams->capacity = 0u; + } + + return status; +} + /* * key derivation functions, internal to libSRTP * @@ -2045,9 +2288,12 @@ static srtp_err_status_t srtp_unprotect_aead(srtp_ctx_t *ctx, return status; } - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -2136,9 +2382,12 @@ srtp_err_status_t srtp_protect_mki(srtp_ctx_t *ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set direction to outbound */ new_stream->direction = dir_srtp_sender; @@ -2735,9 +2984,12 @@ srtp_err_status_t srtp_unprotect_mki(srtp_ctx_t *ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -2804,40 +3056,19 @@ srtp_err_status_t srtp_shutdown() srtp_stream_ctx_t *srtp_get_stream(srtp_t srtp, uint32_t ssrc) { - srtp_stream_ctx_t *stream; - - /* walk down list until ssrc is found */ - stream = srtp->stream_list; - while (stream != NULL) { - if (stream->ssrc == ssrc) - return stream; - stream = stream->next; - } - - /* we haven't found our ssrc, so return a null */ - return NULL; + srtp_stream_ctx_t *stream = NULL; + uint32_t pos = srtp_stream_list_find(&srtp->stream_list, ssrc); + if (pos < srtp->stream_list.size) + stream = srtp->stream_list.streams[pos]; + return stream; } srtp_err_status_t srtp_dealloc(srtp_t session) { - srtp_stream_ctx_t *stream; - srtp_err_status_t status; - - /* - * we take a conservative deallocation strategy - if we encounter an - * error deallocating a stream, then we stop trying to deallocate - * memory and just return an error - */ - - /* walk list of streams, deallocating as we go */ - stream = session->stream_list; - while (stream != NULL) { - srtp_stream_t next = stream->next; - status = srtp_stream_dealloc(stream, session->stream_template); - if (status) - return status; - stream = next; - } + srtp_err_status_t status = srtp_stream_list_destroy( + &session->stream_list, session->stream_template); + if (status) + return status; /* deallocate stream template, if there is one */ if (session->stream_template != NULL) { @@ -2906,8 +3137,11 @@ srtp_err_status_t srtp_add_stream(srtp_t session, const srtp_policy_t *policy) session->stream_template->direction = dir_srtp_receiver; break; case (ssrc_specific): - tmp->next = session->stream_list; - session->stream_list = tmp; + status = srtp_stream_list_push_back(&session->stream_list, tmp); + if (status) { + srtp_stream_dealloc(tmp, NULL); + return status; + } break; case (ssrc_undefined): default: @@ -2943,8 +3177,8 @@ srtp_err_status_t srtp_create(srtp_t *session, /* handle for session */ * loop over elements in the policy list, allocating and * initializing a stream for each element */ + srtp_stream_list_init(&ctx->stream_list); ctx->stream_template = NULL; - ctx->stream_list = NULL; ctx->user_data = NULL; while (policy != NULL) { stat = srtp_add_stream(ctx, policy); @@ -2964,35 +3198,20 @@ srtp_err_status_t srtp_create(srtp_t *session, /* handle for session */ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc) { - srtp_stream_ctx_t *stream, *last_stream; - srtp_err_status_t status; + uint32_t pos; /* sanity check arguments */ if (session == NULL) return srtp_err_status_bad_param; /* find stream in list; complain if not found */ - last_stream = stream = session->stream_list; - while ((stream != NULL) && (ssrc != stream->ssrc)) { - last_stream = stream; - stream = stream->next; - } - if (stream == NULL) + pos = srtp_stream_list_find(&session->stream_list, ssrc); + if (pos >= session->stream_list.size) return srtp_err_status_no_ctx; /* remove stream from the list */ - if (last_stream == stream) - /* stream was first in list */ - session->stream_list = stream->next; - else - last_stream->next = stream->next; - - /* deallocate the stream */ - status = srtp_stream_dealloc(stream, session->stream_template); - if (status) - return status; - - return srtp_err_status_ok; + return srtp_stream_list_erase(&session->stream_list, pos, + session->stream_template); } srtp_err_status_t srtp_update(srtp_t session, const srtp_policy_t *policy) @@ -3027,7 +3246,7 @@ static srtp_err_status_t update_template_streams(srtp_t session, { srtp_err_status_t status; srtp_stream_t new_stream_template; - srtp_stream_t new_stream_list = NULL; + srtp_stream_list_t new_stream_list; status = srtp_valid_policy(policy); if (status != srtp_err_status_ok) { @@ -3051,76 +3270,74 @@ static srtp_err_status_t update_template_streams(srtp_t session, return status; } + srtp_stream_list_init(&new_stream_list); + /* for all old templated streams */ - for (;;) { - srtp_stream_t stream; + for (uint32_t pos = 0u, n = session->stream_list.size; pos < n; ++pos) { + srtp_stream_t old_stream, new_stream; uint32_t ssrc; srtp_xtd_seq_num_t old_index; srtp_rdb_t old_rtcp_rdb; - stream = session->stream_list; - while ((stream != NULL) && - (stream->session_keys[0].rtp_auth != - session->stream_template->session_keys[0].rtp_auth)) { - stream = stream->next; - } - if (stream == NULL) { - /* no more templated streams */ - break; + old_stream = session->stream_list.streams[pos]; + if (old_stream->session_keys[0].rtp_auth != + session->stream_template->session_keys[0].rtp_auth) { + continue; } - /* save old extendard seq */ - ssrc = stream->ssrc; - old_index = stream->rtp_rdbx.index; - old_rtcp_rdb = stream->rtcp_rdb; + /* save old extended seq */ + ssrc = old_stream->ssrc; + old_index = old_stream->rtp_rdbx.index; + old_rtcp_rdb = old_stream->rtcp_rdb; - /* remove stream */ - status = srtp_remove_stream(session, ssrc); + /* allocate and initialize a new stream */ + status = srtp_stream_clone(new_stream_template, ssrc, &new_stream); if (status) { - /* free new allocations */ - while (new_stream_list != NULL) { - srtp_stream_t next = new_stream_list->next; - srtp_stream_dealloc(new_stream_list, new_stream_template); - new_stream_list = next; - } + srtp_stream_list_destroy(&new_stream_list, new_stream_template); srtp_stream_dealloc(new_stream_template, NULL); return status; } - /* allocate and initialize a new stream */ - status = srtp_stream_clone(new_stream_template, ssrc, &stream); + /* restore old extended seq */ + new_stream->rtp_rdbx.index = old_index; + new_stream->rtcp_rdb = old_rtcp_rdb; + + /* insert into the new stream list */ + status = srtp_stream_list_push_back(&new_stream_list, new_stream); if (status) { - /* free new allocations */ - while (new_stream_list != NULL) { - srtp_stream_t next = new_stream_list->next; - srtp_stream_dealloc(new_stream_list, new_stream_template); - new_stream_list = next; - } + srtp_stream_list_destroy(&new_stream_list, new_stream_template); srtp_stream_dealloc(new_stream_template, NULL); return status; } - /* add new stream to the head of the new_stream_list */ - stream->next = new_stream_list; - new_stream_list = stream; + /* + * Repurpose the ssrc array to store positions of the stream + * in the original list to speed up merging the updated streams + * into the original list. Otherwise we would have to perform + * lookup by ssrc for each updated stream. + */ + new_stream_list.ssrcs[new_stream_list.size - 1u] = pos; + } - /* restore old extended seq */ - stream->rtp_rdbx.index = old_index; - stream->rtcp_rdb = old_rtcp_rdb; + /* swap updated streams with the ones in the original list */ + for (uint32_t pos = 0u, n = new_stream_list.size; pos < n; ++pos) { + srtp_stream_t old_stream, new_stream; + uint32_t old_pos = new_stream_list.ssrcs[pos]; + + new_stream = new_stream_list.streams[pos]; + old_stream = session->stream_list.streams[old_pos]; + session->stream_list.streams[old_pos] = new_stream; + new_stream_list.streams[pos] = old_stream; + new_stream_list.ssrcs[pos] = old_stream->ssrc; } + + /* dealloc old streams */ + srtp_stream_list_destroy(&new_stream_list, session->stream_template); /* dealloc old template */ srtp_stream_dealloc(session->stream_template, NULL); /* set new template */ session->stream_template = new_stream_template; - /* add new list */ - if (new_stream_list) { - srtp_stream_t tail = new_stream_list; - while (tail->next) { - tail = tail->next; - } - tail->next = session->stream_list; - session->stream_list = new_stream_list; - } + return status; } @@ -3888,9 +4105,12 @@ static srtp_err_status_t srtp_unprotect_rtcp_aead( return status; } - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -3954,9 +4174,12 @@ srtp_err_status_t srtp_protect_rtcp_mki(srtp_t ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -4401,9 +4624,12 @@ srtp_err_status_t srtp_unprotect_rtcp_mki(srtp_t ctx, if (status) return status; - /* add new stream to the head of the stream_list */ - new_stream->next = ctx->stream_list; - ctx->stream_list = new_stream; + /* add new stream to the stream_list */ + status = srtp_stream_list_push_back(&ctx->stream_list, new_stream); + if (status) { + srtp_stream_dealloc(new_stream, ctx->stream_template); + return status; + } /* set stream (the pointer used in this function) */ stream = new_stream; @@ -4596,7 +4822,7 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, return srtp_err_status_bad_param; } - if (session->stream_template == NULL && session->stream_list == NULL) { + if (session->stream_template == NULL && session->stream_list.size == 0u) { return srtp_err_status_bad_param; } @@ -4609,10 +4835,9 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, length); } - stream = session->stream_list; - - while (stream != NULL) { + for (uint32_t pos = 0u, n = session->stream_list.size; pos < n; ++pos) { uint32_t temp_length; + stream = session->stream_list.streams[pos]; if (stream_get_protect_trailer_length(stream, is_rtp, use_mki, mki_index, &temp_length) == srtp_err_status_ok) { @@ -4620,7 +4845,6 @@ srtp_err_status_t get_protect_trailer_length(srtp_t session, *length = temp_length; } } - stream = stream->next; } return srtp_err_status_ok; diff --git a/test/srtp_driver.c b/test/srtp_driver.c index db6d3f48e..030446761 100644 --- a/test/srtp_driver.c +++ b/test/srtp_driver.c @@ -1517,8 +1517,8 @@ srtp_err_status_t srtp_session_print_policy(srtp_t srtp) } /* loop over streams in session, printing the policy of each */ - stream = srtp->stream_list; - while (stream != NULL) { + for (uint32_t pos = 0u, n = srtp->stream_list.size; pos < n; ++pos) { + stream = srtp->stream_list.streams[pos]; if (stream->rtp_services > sec_serv_conf_and_auth) { return srtp_err_status_bad_param; } @@ -1555,9 +1555,6 @@ srtp_err_status_t srtp_session_print_policy(srtp_t srtp) } else { printf("none\n"); } - - /* advance to next stream in the list */ - stream = stream->next; } return srtp_err_status_ok; }