Skip to content

Commit

Permalink
Refactor send/receive part 1
Browse files Browse the repository at this point in the history
The send/receive code is complex and difficult to follow. This is a first small attempt to refactor the code so that it's a little easier to read.

Signed-off-by: Steven Bellock <sbellock@nvidia.com>
  • Loading branch information
steven-bellock authored and jyao1 committed Nov 4, 2024
1 parent 5dc95d2 commit 803849f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 96 deletions.
115 changes: 47 additions & 68 deletions library/spdm_requester_lib/libspdm_req_send_receive.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,39 +27,37 @@ libspdm_return_t libspdm_send_request(void *spdm_context, const uint32_t *sessio
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO,
"libspdm_send_spdm_request[%x] msg %s(0x%x), size (0x%zx): \n",
(session_id != NULL) ? *session_id : 0x0,
libspdm_get_code_str(((spdm_message_header_t *)request)->
request_response_code),
((spdm_message_header_t *)request)->request_response_code,
request_size));
libspdm_get_code_str(((spdm_message_header_t *)request)->request_response_code),
((spdm_message_header_t *)request)->request_response_code, request_size));
LIBSPDM_INTERNAL_DUMP_HEX(request, request_size);

transport_header_size = context->local_context.capability.transport_header_size;
libspdm_get_scratch_buffer(context, (void**) &scratch_buffer, &scratch_buffer_size);
libspdm_get_sender_buffer(context, (void**) &sender_buffer, &sender_buffer_size);
libspdm_get_scratch_buffer(context, (void **)&scratch_buffer, &scratch_buffer_size);
libspdm_get_sender_buffer(context, (void **)&sender_buffer, &sender_buffer_size);

/* This is a problem because original code assumes request is in the sender buffer,
* when it can really be using the scratch space for chunking.
* Did not want to modify ally request handlers to pass this information,
* Did not want to modify all request handlers to pass this information,
* so just making the determination here by examining scratch/sender buffers.
* This may be something that should be refactored in the future. */
#if LIBSPDM_ENABLE_CAPABILITY_CHUNK_CAP
if ((uint8_t*) request >= sender_buffer &&
(uint8_t*)request < sender_buffer + sender_buffer_size) {
if ((uint8_t *)request >= sender_buffer &&
(uint8_t *)request < sender_buffer + sender_buffer_size) {
message = sender_buffer;
message_size = sender_buffer_size;
} else {
if ((uint8_t*)request >=
if ((uint8_t *)request >=
scratch_buffer + libspdm_get_scratch_buffer_sender_receiver_offset(spdm_context)
&& (uint8_t*)request <
&& (uint8_t *)request <
scratch_buffer + libspdm_get_scratch_buffer_sender_receiver_offset(spdm_context)
+ libspdm_get_scratch_buffer_sender_receiver_capacity(spdm_context)) {
message = scratch_buffer +
libspdm_get_scratch_buffer_sender_receiver_offset(spdm_context);
message_size = libspdm_get_scratch_buffer_sender_receiver_capacity(spdm_context);
} else if ((uint8_t*)request >=
} else if ((uint8_t *)request >=
scratch_buffer +
libspdm_get_scratch_buffer_large_sender_receiver_offset(spdm_context)
&& (uint8_t*)request <
&& (uint8_t *)request <
scratch_buffer +
libspdm_get_scratch_buffer_large_sender_receiver_offset(spdm_context) +
libspdm_get_scratch_buffer_large_sender_receiver_capacity(spdm_context)) {
Expand All @@ -86,21 +84,19 @@ libspdm_return_t libspdm_send_request(void *spdm_context, const uint32_t *sessio
/* backup it to last_spdm_request, because the caller wants to compare it with response */
if (((const spdm_message_header_t *)request)->request_response_code != SPDM_RESPOND_IF_READY
&& ((const spdm_message_header_t *)request)->request_response_code != SPDM_CHUNK_GET
&& ((const spdm_message_header_t*) request)->request_response_code != SPDM_CHUNK_SEND) {
&& ((const spdm_message_header_t *)request)->request_response_code != SPDM_CHUNK_SEND) {
libspdm_copy_mem (context->last_spdm_request,
libspdm_get_scratch_buffer_last_spdm_request_capacity(context),
request,
request_size
);
request_size);
context->last_spdm_request_size = request_size;
}

status = context->transport_encode_message(
context, session_id, is_app_message, true, request_size,
request, &message_size, (void **)&message);
if (LIBSPDM_STATUS_IS_ERROR(status)) {
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO, "transport_encode_message status - %xu\n",
status));
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO, "transport_encode_message status - %xu\n", status));
if ((session_id != NULL) &&
((status == LIBSPDM_STATUS_SEQUENCE_NUMBER_OVERFLOW) ||
(status == LIBSPDM_STATUS_CRYPTO_ERROR))) {
Expand All @@ -110,9 +106,8 @@ libspdm_return_t libspdm_send_request(void *spdm_context, const uint32_t *sessio
}

timeout = context->local_context.capability.rtt;
status = context->send_message(context, message_size, message, timeout);

status = context->send_message(context, message_size, message,
timeout);
if (LIBSPDM_STATUS_IS_ERROR(status)) {
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO, "libspdm_send_spdm_request[%x] status - %xu\n",
(session_id != NULL) ? *session_id : 0x0, status));
Expand Down Expand Up @@ -148,14 +143,12 @@ libspdm_return_t libspdm_receive_response(void *spdm_context, const uint32_t *se
timeout = context->local_context.capability.rtt +
((uint64_t)1 << context->connection_info.capability.ct_exponent);
} else {
timeout = context->local_context.capability.rtt +
context->local_context.capability.st1;
timeout = context->local_context.capability.rtt + context->local_context.capability.st1;
}

message = *response;
message_size = *response_size;
status = context->receive_message(context, &message_size,
(void **)&message, timeout);
status = context->receive_message(context, &message_size, (void **)&message, timeout);
if (LIBSPDM_STATUS_IS_ERROR(status)) {
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO,
"libspdm_receive_spdm_response[%x] status - %xu\n",
Expand Down Expand Up @@ -232,16 +225,14 @@ libspdm_return_t libspdm_receive_response(void *spdm_context, const uint32_t *se
if (*message_session_id != *session_id) {
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO,
"libspdm_receive_spdm_response[%x] GetSessionId - %x\n",
(session_id != NULL) ? *session_id : 0x0,
*message_session_id));
(session_id != NULL) ? *session_id : 0x0, *message_session_id));
goto error;
}
} else {
if (message_session_id != NULL) {
LIBSPDM_DEBUG((LIBSPDM_DEBUG_INFO,
"libspdm_receive_spdm_response[%x] GetSessionId - %x\n",
(session_id != NULL) ? *session_id : 0x0,
*message_session_id));
(session_id != NULL) ? *session_id : 0x0, *message_session_id));
goto error;
}
}
Expand Down Expand Up @@ -342,7 +333,7 @@ libspdm_return_t libspdm_handle_large_request(
/* now we can get sender buffer */
transport_header_size = spdm_context->local_context.capability.transport_header_size;

libspdm_get_scratch_buffer(spdm_context, (void**) &scratch_buffer, &scratch_buffer_size);
libspdm_get_scratch_buffer(spdm_context, (void **)&scratch_buffer, &scratch_buffer_size);

/* Temporary send/receive buffers for chunking are in the scratch space */
message = scratch_buffer + libspdm_get_scratch_buffer_sender_receiver_offset(spdm_context);
Expand Down Expand Up @@ -374,7 +365,7 @@ libspdm_return_t libspdm_handle_large_request(

do {
LIBSPDM_ASSERT(send_info->large_message_capacity >= transport_header_size);
spdm_request = (spdm_chunk_send_request_t*) ((uint8_t*) message + transport_header_size);
spdm_request = (spdm_chunk_send_request_t *)((uint8_t *)message + transport_header_size);
spdm_request_size = message_size - transport_header_size;

spdm_request->header.spdm_version = libspdm_get_connection_version(spdm_context);
Expand All @@ -383,36 +374,33 @@ libspdm_return_t libspdm_handle_large_request(
spdm_request->header.param2 = send_info->chunk_handle;
spdm_request->chunk_seq_no = send_info->chunk_seq_no;
spdm_request->reserved = 0;
chunk_ptr = (uint8_t*) (spdm_request + 1);

if (min_data_transfer_size
- sizeof(spdm_chunk_send_request_t)
< (send_info->large_message_size - send_info->chunk_bytes_transferred)) {
chunk_ptr = (uint8_t *)(spdm_request + 1);

copy_size = min_data_transfer_size
- sizeof(spdm_chunk_send_request_t);
if ((min_data_transfer_size - sizeof(spdm_chunk_send_request_t)) <
(send_info->large_message_size - send_info->chunk_bytes_transferred)) {
copy_size = min_data_transfer_size - sizeof(spdm_chunk_send_request_t);
} else {
copy_size = (send_info->large_message_size - send_info->chunk_bytes_transferred);
}

if (send_info->chunk_seq_no == 0) {
*(uint32_t*) (spdm_request + 1) = (uint32_t) send_info->large_message_size;
*(uint32_t *)(spdm_request + 1) = (uint32_t)send_info->large_message_size;
chunk_ptr += sizeof(uint32_t);
copy_size -= sizeof(uint32_t);
}

spdm_request->chunk_size = (uint32_t) copy_size;
spdm_request->chunk_size = (uint32_t)copy_size;

libspdm_copy_mem(
chunk_ptr, spdm_request_size - ((uint8_t*) spdm_request - (uint8_t*) message),
(uint8_t*)send_info->large_message + send_info->chunk_bytes_transferred, copy_size);
chunk_ptr, spdm_request_size - ((uint8_t *)spdm_request - (uint8_t *)message),
(uint8_t *)send_info->large_message + send_info->chunk_bytes_transferred, copy_size);

send_info->chunk_bytes_transferred += copy_size;
if (send_info->chunk_bytes_transferred >= send_info->large_message_size) {
spdm_request->header.param1 |= SPDM_CHUNK_SEND_REQUEST_ATTRIBUTE_LAST_CHUNK;
}

spdm_request_size = (chunk_ptr + copy_size) - (uint8_t*)spdm_request;
spdm_request_size = (chunk_ptr + copy_size) - (uint8_t *)spdm_request;
status = libspdm_send_request(
spdm_context, session_id, false,
spdm_request_size, spdm_request);
Expand All @@ -433,7 +421,7 @@ libspdm_return_t libspdm_handle_large_request(
if (LIBSPDM_STATUS_IS_ERROR(status)) {
break;
}
spdm_response = (void*) (response);
spdm_response = (void *)(response);

if (response_size < sizeof(spdm_message_header_t)) {
status = LIBSPDM_STATUS_INVALID_MSG_SIZE;
Expand Down Expand Up @@ -477,7 +465,7 @@ libspdm_return_t libspdm_handle_large_request(
if (spdm_response->header.param1
& SPDM_CHUNK_SEND_ACK_RESPONSE_ATTRIBUTE_EARLY_ERROR_DETECTED) {

spdm_error = (spdm_error_response_t *) (spdm_response + 1);
spdm_error = (spdm_error_response_t *)(spdm_response + 1);
if (response_size < (sizeof(spdm_chunk_send_ack_response_t) +
sizeof(spdm_error_response_t))) {
status = LIBSPDM_STATUS_INVALID_MSG_SIZE;
Expand All @@ -500,7 +488,7 @@ libspdm_return_t libspdm_handle_large_request(
libspdm_copy_mem(
send_info->large_message,
send_info->large_message_capacity,
(uint8_t*) (spdm_response + 1),
(uint8_t *)(spdm_response + 1),
response_size - sizeof(spdm_chunk_send_ack_response_t));

send_info->large_message_size =
Expand All @@ -518,11 +506,10 @@ libspdm_return_t libspdm_handle_large_request(
break;
}

chunk_ptr = (uint8_t*) (spdm_response + 1);
chunk_ptr = (uint8_t *)(spdm_response + 1);
send_info->chunk_seq_no++;

if (send_info->chunk_bytes_transferred >= send_info->large_message_size) {

/* All bytes have been transferred. Store response in scratch buffer
* to be read by libspdm_receive_spdm_response */
libspdm_copy_mem(
Expand All @@ -538,7 +525,6 @@ libspdm_return_t libspdm_handle_large_request(
&& send_info->chunk_bytes_transferred < send_info->large_message_size);

if (LIBSPDM_STATUS_IS_ERROR(status)) {

send_info->chunk_in_use = false;
send_info->chunk_handle++; /* Implicit wrap-around*/
send_info->chunk_seq_no = 0;
Expand Down Expand Up @@ -583,16 +569,14 @@ libspdm_return_t libspdm_send_spdm_request(libspdm_context_t *spdm_context,
spdm_context, true,
SPDM_GET_CAPABILITIES_REQUEST_FLAGS_HANDSHAKE_IN_THE_CLEAR_CAP,
SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_HANDSHAKE_IN_THE_CLEAR_CAP)) {
session_info = libspdm_get_session_info_via_session_id(
spdm_context, *session_id);
session_info = libspdm_get_session_info_via_session_id(spdm_context, *session_id);
LIBSPDM_ASSERT(session_info != NULL);
if (session_info == NULL) {
return LIBSPDM_STATUS_INVALID_STATE_LOCAL;
}
session_state = libspdm_secured_message_get_session_state(
session_info->secured_message_context);
if ((session_state == LIBSPDM_SESSION_STATE_HANDSHAKING) &&
!session_info->use_psk) {
if ((session_state == LIBSPDM_SESSION_STATE_HANDSHAKING) && !session_info->use_psk) {
session_id = NULL;
}
}
Expand All @@ -612,8 +596,8 @@ libspdm_return_t libspdm_send_spdm_request(libspdm_context_t *spdm_context,

/* large SPDM message is the SPDM message whose size is greater than the DataTransferSize of the receiving
* SPDM endpoint or greater than the transmit buffer size of the sending SPDM endpoint */
if (((const spdm_message_header_t*) request)->request_response_code != SPDM_GET_VERSION
&& ((const spdm_message_header_t*) request)->request_response_code != SPDM_GET_CAPABILITIES
if (((const spdm_message_header_t *)request)->request_response_code != SPDM_GET_VERSION
&& ((const spdm_message_header_t *)request)->request_response_code != SPDM_GET_CAPABILITIES
&& ((spdm_context->connection_info.capability.data_transfer_size != 0 &&
request_size > spdm_context->connection_info.capability.data_transfer_size) ||
(spdm_context->local_context.capability.sender_data_transfer_size != 0 &&
Expand All @@ -626,9 +610,9 @@ libspdm_return_t libspdm_send_spdm_request(libspdm_context_t *spdm_context,
* so that it can compare last_spdm_request's fields with response fields
* Therefore the request must be copied to last_spdm_request here. */

if (((const spdm_message_header_t*) request)->request_response_code != SPDM_RESPOND_IF_READY
&& ((const spdm_message_header_t*) request)->request_response_code != SPDM_CHUNK_GET
&& ((const spdm_message_header_t*) request)->request_response_code != SPDM_CHUNK_SEND) {
if (((const spdm_message_header_t *)request)->request_response_code != SPDM_RESPOND_IF_READY
&& ((const spdm_message_header_t *)request)->request_response_code != SPDM_CHUNK_GET
&& ((const spdm_message_header_t *)request)->request_response_code != SPDM_CHUNK_SEND) {
libspdm_copy_mem(
spdm_context->last_spdm_request,
libspdm_get_scratch_buffer_last_spdm_request_capacity(spdm_context),
Expand Down Expand Up @@ -675,16 +659,14 @@ libspdm_return_t libspdm_receive_spdm_response(libspdm_context_t *spdm_context,
spdm_context, true,
SPDM_GET_CAPABILITIES_REQUEST_FLAGS_HANDSHAKE_IN_THE_CLEAR_CAP,
SPDM_GET_CAPABILITIES_RESPONSE_FLAGS_HANDSHAKE_IN_THE_CLEAR_CAP)) {
session_info = libspdm_get_session_info_via_session_id(
spdm_context, *session_id);
session_info = libspdm_get_session_info_via_session_id(spdm_context, *session_id);
LIBSPDM_ASSERT(session_info != NULL);
if (session_info == NULL) {
return LIBSPDM_STATUS_INVALID_STATE_LOCAL;
}
session_state = libspdm_secured_message_get_session_state(
session_info->secured_message_context);
if ((session_state == LIBSPDM_SESSION_STATE_HANDSHAKING) &&
!session_info->use_psk) {
if ((session_state == LIBSPDM_SESSION_STATE_HANDSHAKING) && !session_info->use_psk) {
session_id = NULL;
}
}
Expand All @@ -701,7 +683,6 @@ libspdm_return_t libspdm_receive_spdm_response(libspdm_context_t *spdm_context,

/* This response may either be an actual response or ERROR_LARGE_RESPONSE,
* the latter which should be handled in the large response handler. */

send_info->chunk_in_use = false;
send_info->chunk_handle++; /* Implicit wrap-around*/
send_info->chunk_seq_no = 0;
Expand All @@ -712,14 +693,13 @@ libspdm_return_t libspdm_receive_spdm_response(libspdm_context_t *spdm_context,
status = LIBSPDM_STATUS_SUCCESS;
} else {
response_capacity = *response_size;
status = libspdm_receive_response(spdm_context, session_id, false,
response_size, response);
status = libspdm_receive_response(spdm_context, session_id, false, response_size, response);
if (LIBSPDM_STATUS_IS_ERROR(status)) {
goto receive_done;
}
}

spdm_response = (spdm_message_header_t*) (*response);
spdm_response = (spdm_message_header_t *)(*response);

if (*response_size < sizeof(spdm_message_header_t)) {
status = LIBSPDM_STATUS_INVALID_MSG_SIZE;
Expand All @@ -730,7 +710,7 @@ libspdm_return_t libspdm_receive_spdm_response(libspdm_context_t *spdm_context,
&& spdm_response->param1 == SPDM_ERROR_CODE_LARGE_RESPONSE) {
status = libspdm_handle_error_large_response(
spdm_context, session_id,
response_size, (void*) spdm_response, response_capacity);
response_size, (void *)spdm_response, response_capacity);

if (LIBSPDM_STATUS_IS_ERROR(status)) {
goto receive_done;
Expand All @@ -744,8 +724,7 @@ libspdm_return_t libspdm_receive_spdm_response(libspdm_context_t *spdm_context,
/* Per the spec, SPDM_VERSION and SPDM_CAPABILITIES shall not be chunked
* and should be an unexpected error. */
if (spdm_response->request_response_code == SPDM_VERSION ||
spdm_response->request_response_code == SPDM_CAPABILITIES
) {
spdm_response->request_response_code == SPDM_CAPABILITIES) {
status = LIBSPDM_STATUS_INVALID_MSG_FIELD;
goto receive_done;
}
Expand Down
Loading

0 comments on commit 803849f

Please sign in to comment.