Skip to content

Commit

Permalink
Adds a client builder option to disable the default MD5 checksum vali… (
Browse files Browse the repository at this point in the history
#4729)

* Adds a client builder option to disable the default MD5 checksum validation

* Changing name of validation parameter and inverting the boolean
  • Loading branch information
cenedhryn authored Dec 2, 2023
1 parent 0e3b14d commit 018582b
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AmazonSQS-79a6b63.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "Amazon SQS",
"contributor": "",
"description": "Adds a client builder option to disable the default MD5 checksum validation for SendMessage, ReceiveMessage and SendMessageBatch"
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ private boolean shouldGenerateClientEndpointTests() {

private boolean hasClientContextParams() {
Map<String, ClientContextParam> clientContextParams = model.getClientContextParams();
return clientContextParams != null && !clientContextParams.isEmpty();
Map<String, ClientContextParam> customClientContextParams = model.getCustomizationConfig().getCustomClientContextParams();
return (clientContextParams != null && !clientContextParams.isEmpty()) ||
(customClientContextParams != null && !customClientContextParams.isEmpty());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ public TypeSpec poetSpec() {

b.addMethod(ctor());

model.getClientContextParams().forEach((n, m) -> {
b.addField(paramDeclaration(n, m));
});
if (model.getClientContextParams() != null) {
model.getClientContextParams().forEach((n, m) -> {
b.addField(paramDeclaration(n, m));
});
}

if (model.getCustomizationConfig() != null && model.getCustomizationConfig().getCustomClientContextParams() != null) {
model.getCustomizationConfig().getCustomClientContextParams().forEach((n, m) -> {
Expand Down
6 changes: 6 additions & 0 deletions services/sqs/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,11 @@
<artifactId>http-auth-aws</artifactId>
<version>${awsjavasdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>service-test-utils</artifactId>
<version>${awsjavasdk.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.services.sqs.endpoints.SqsClientContextParams;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
Expand All @@ -41,6 +43,7 @@
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.BinaryUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Md5Utils;
Expand Down Expand Up @@ -77,7 +80,8 @@ public final class MessageMD5ChecksumInterceptor implements ExecutionInterceptor
public void afterExecution(Context.AfterExecution context, ExecutionAttributes executionAttributes) {
SdkResponse response = context.response();
SdkRequest originalRequest = context.request();
if (response != null) {

if (response != null && validateMessageMD5Enabled(executionAttributes)) {
if (originalRequest instanceof SendMessageRequest) {
SendMessageRequest sendMessageRequest = (SendMessageRequest) originalRequest;
SendMessageResponse sendMessageResult = (SendMessageResponse) response;
Expand All @@ -95,6 +99,12 @@ public void afterExecution(Context.AfterExecution context, ExecutionAttributes e
}
}

private static boolean validateMessageMD5Enabled(ExecutionAttributes executionAttributes) {
AttributeMap clientContextParams = executionAttributes.getAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS);
Boolean enableMd5Validation = clientContextParams.get(SqsClientContextParams.CHECKSUM_VALIDATION_ENABLED);
return enableMd5Validation == null || enableMd5Validation;
}

/**
* Throw an exception if the MD5 checksums returned in the SendMessageResponse do not match the
* client-side calculation based on the original message in the SendMessageRequest.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,11 @@
],
"interceptors": [
"software.amazon.awssdk.services.sqs.internal.MessageMD5ChecksumInterceptor"
]
],
"customClientContextParams":{
"checksumValidationEnabled":{
"documentation":"Enable message MD5 checksum validation.<p>Checksum validation for messages defaults to true. Only set to false if required, for instance if your cryptographic library does not support MD5.<p>Supported operations are SendMessage, ReceiveMessage and SendMessageBatch.",
"type":"boolean"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.InterceptorContext;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.services.sqs.endpoints.SqsClientContextParams;
import software.amazon.awssdk.services.sqs.internal.MessageMD5ChecksumInterceptor;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
Expand All @@ -37,6 +39,7 @@
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResultEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
import software.amazon.awssdk.utils.AttributeMap;

/**
* Verifies the functionality of {@link MessageMD5ChecksumInterceptor}.
Expand Down Expand Up @@ -236,11 +239,13 @@ private void assertFailure(SdkRequest request, SdkResponse response) {
}

private void callInterceptor(SdkRequest request, SdkResponse response) {
ExecutionAttributes executionAttributes = new ExecutionAttributes();
executionAttributes.putAttribute(SdkInternalExecutionAttribute.CLIENT_CONTEXT_PARAMS, AttributeMap.builder().build());
new MessageMD5ChecksumInterceptor().afterExecution(InterceptorContext.builder()
.request(request)
.response(response)
.build(),
new ExecutionAttributes());
executionAttributes);
}

private String messageBody() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.sqs;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.sqs.internal.MessageMD5ChecksumInterceptor;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
import software.amazon.awssdk.testutils.service.http.MockAsyncHttpClient;
import software.amazon.awssdk.testutils.service.http.MockSyncHttpClient;
import software.amazon.awssdk.utils.StringInputStream;

/**
* Verifies that the logic in {@link MessageMD5ChecksumInterceptor} can be disabled, which is needed for use cases like
* FIPS cryptography libraries that don't have MD5 support. SendMessage is used as the test API, but the flow is the
* same for ReceiveMessage and SendMessageBatch.
*/
public class MessageMD5ChecksumValidationDisableTest {
private static final AwsBasicCredentials CLIENT_CREDENTIALS = AwsBasicCredentials.create("ca", "cs");
private static final String MESSAGE_ID = "0f433476-621e-4638-811a-112d2c2e41d7";

private MockAsyncHttpClient asyncHttpClient;
private MockSyncHttpClient syncHttpClient;

@BeforeEach
public void setupClient() {
asyncHttpClient = new MockAsyncHttpClient();
syncHttpClient = new MockSyncHttpClient();
}

@AfterEach
public void cleanup() {
asyncHttpClient.reset();
syncHttpClient.reset();
}

@Test
public void md5ValidationEnabled_default_md5InResponse_Works() {
asyncHttpClient.stubResponses(responseWithMd5());
SqsAsyncClient client = SqsAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(CLIENT_CREDENTIALS))
.httpClient(asyncHttpClient)
.build();

SendMessageResponse sendMessageResponse =
client.sendMessage(r -> r.messageBody(messageBody()).messageAttributes(createAttributeValues())).join();

assertThat(sendMessageResponse.messageId()).isEqualTo(MESSAGE_ID);
}

@Test
public void md5ValidationEnabled_default_noMd5InResponse_throwsException() {
asyncHttpClient.stubResponses(responseWithoutMd5());
SqsAsyncClient client = SqsAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(CLIENT_CREDENTIALS))
.httpClient(asyncHttpClient)
.build();

assertThatThrownBy(() -> client.sendMessage(r -> r.messageBody(messageBody())
.messageAttributes(createAttributeValues()))
.join())
.hasMessageContaining("MD5 returned by SQS does not match the calculation on the original request");
}

@Test
public void md5ValidationDisabled_md5InResponse_Works() {
asyncHttpClient.stubResponses(responseWithMd5());
SqsAsyncClient client = SqsAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(CLIENT_CREDENTIALS))
.httpClient(asyncHttpClient)
.checksumValidationEnabled(false)
.build();

SendMessageResponse sendMessageResponse =
client.sendMessage(r -> r.messageBody(messageBody()).messageAttributes(createAttributeValues())).join();

assertThat(sendMessageResponse.messageId()).isEqualTo(MESSAGE_ID);
}

@Test
public void md5ValidationDisabled_noMd5InResponse_Works() {
asyncHttpClient.stubResponses(responseWithoutMd5());
SqsAsyncClient client = SqsAsyncClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(CLIENT_CREDENTIALS))
.httpClient(asyncHttpClient)
.checksumValidationEnabled(false)
.build();

SendMessageResponse sendMessageResponse =
client.sendMessage(r -> r.messageBody(messageBody()).messageAttributes(createAttributeValues())).join();

assertThat(sendMessageResponse.messageId()).isEqualTo(MESSAGE_ID);
}

@Test
public void sync_md5ValidationDisabled_noMd5InResponse_Works() {
syncHttpClient.stubResponses(responseWithoutMd5());
SqsClient client = SqsClient.builder()
.credentialsProvider(StaticCredentialsProvider.create(CLIENT_CREDENTIALS))
.httpClient(syncHttpClient)
.checksumValidationEnabled(false)
.build();

SendMessageResponse sendMessageResponse =
client.sendMessage(r -> r.messageBody(messageBody()).messageAttributes(createAttributeValues()));

assertThat(sendMessageResponse.messageId()).isEqualTo(MESSAGE_ID);
}

private static String messageBody() {
return "Body";
}

private static HttpExecuteResponse responseWithMd5() {
return HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build()).responseBody(
AbortableInputStream.create(new StringInputStream(
"{\"MD5OfMessageAttributes\":\"43eeb333d10515533e317490584ea243\","
+ "\"MD5OfMessageBody\":\"ac101b32dda4448cf13a93fe283dddd8\","
+ "\"MessageId\":\"" + MESSAGE_ID + "\"} ")))
.build();
}

private static HttpExecuteResponse responseWithoutMd5() {
return HttpExecuteResponse.builder().response(SdkHttpResponse.builder().statusCode(200).build())
.responseBody(AbortableInputStream
.create(new StringInputStream("{\"MessageId\":\"" + MESSAGE_ID + "\"} ")))
.build();
}

protected static Map<String, MessageAttributeValue> createAttributeValues() {
Map<String, MessageAttributeValue> attrs = new HashMap();
attrs.put("attribute-1", MessageAttributeValue.builder().dataType("String").stringValue("tmp").build());
return Collections.unmodifiableMap(attrs);
}
}

0 comments on commit 018582b

Please sign in to comment.