diff --git a/core/auth/pom.xml b/core/auth/pom.xml
index 64d46909f109..57e87fea7b68 100644
--- a/core/auth/pom.xml
+++ b/core/auth/pom.xml
@@ -93,6 +93,11 @@
http-auth-spi
${awsjavasdk.version}
+
+ software.amazon.awssdk
+ checksums-spi
+ ${awsjavasdk.version}
+
software.amazon.eventstream
eventstream
diff --git a/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java b/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java
index 3174eb7c6caa..a7772fe9f54e 100644
--- a/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java
+++ b/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java
@@ -19,11 +19,13 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig;
import software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream;
+import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore;
import software.amazon.awssdk.utils.BinaryUtils;
/**
@@ -60,12 +62,15 @@ public final class AwsSignedChunkedEncodingInputStream extends AwsChunkedEncodin
* @param config The configuration allows the user to customize chunk size and buffer size.
* See {@link AwsChunkedEncodingConfig} for default values.
*/
- private AwsSignedChunkedEncodingInputStream(InputStream in, SdkChecksum sdkChecksum,
+ private AwsSignedChunkedEncodingInputStream(InputStream in,
+ ChecksumAlgorithm checksumAlgorithm,
+ SdkChecksum sdkChecksum,
+ PayloadChecksumStore checksumStore,
String checksumHeaderForTrailer,
String headerSignature,
AwsChunkSigner chunkSigner,
AwsChunkedEncodingConfig config) {
- super(in, sdkChecksum, checksumHeaderForTrailer, config);
+ super(in, checksumAlgorithm, sdkChecksum, checksumStore, checksumHeaderForTrailer, config);
this.chunkSigner = chunkSigner;
this.previousChunkSignature = headerSignature;
this.headerSignature = headerSignature;
@@ -103,9 +108,14 @@ public Builder awsChunkSigner(AwsChunkSigner awsChunkSigner) {
public AwsSignedChunkedEncodingInputStream build() {
- return new AwsSignedChunkedEncodingInputStream(this.inputStream, this.sdkChecksum, this.checksumHeaderForTrailer,
+ return new AwsSignedChunkedEncodingInputStream(this.inputStream,
+ this.checksumAlgorithm,
+ this.sdkChecksum,
+ this.checksumStore,
+ this.checksumHeaderForTrailer,
this.headerSignature,
- this.awsChunkSigner, this.awsChunkedEncodingConfig);
+ this.awsChunkSigner,
+ this.awsChunkedEncodingConfig);
}
}
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java
deleted file mode 100644
index 7f3f5ed93b4b..000000000000
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License").
- * You may not use this file except in compliance with the License.
- * A copy of the License is located at
- *
- * http://aws.amazon.com/apache2.0
- *
- * or in the "license" file accompanying this file. This file is distributed
- * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- * express or implied. See the License for the specific language governing
- * permissions and limitations under the License.
- */
-
-package software.amazon.awssdk.core.internal.checksums;
-
-import java.util.concurrent.ConcurrentHashMap;
-import software.amazon.awssdk.annotations.SdkInternalApi;
-import software.amazon.awssdk.core.checksums.Algorithm;
-
-/**
- * Cache for storing computed payload checksums. Only to be used in the legacy signing paths.
- */
-@SdkInternalApi
-@SuppressWarnings("deprecation")
-public class LegacyPayloadChecksumCache {
- private final ConcurrentHashMap cache = new ConcurrentHashMap<>();
-
- public byte[] putChecksumValue(Algorithm algorithm, byte[] checksumValue) {
- return cache.put(algorithm, checksumValue);
- }
-
- public byte[] getChecksumValue(Algorithm algorithm) {
- return cache.get(algorithm);
- }
-}
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java
index b2925f55c3d3..ca4b4d8f7f2c 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java
@@ -77,8 +77,9 @@ public HttpChecksumStage(ClientType clientType) {
public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context)
throws Exception {
+ ensurePayloadChecksumStorePresent(context.executionAttributes());
+
if (sraSigningEnabled(context)) {
- ensurePayloadChecksumStorePresent(context.executionAttributes());
return sraChecksum(request, context);
}
@@ -87,9 +88,10 @@ public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, Re
private SdkHttpFullRequest.Builder legacyChecksum(SdkHttpFullRequest.Builder request, RequestExecutionContext context) {
ChecksumSpecs resolvedChecksumSpecs = getResolvedChecksumSpecs(context.executionAttributes());
+ PayloadChecksumStore checksumStore = getPayloadChecksumStore(context.executionAttributes());
if (md5ChecksumRequired(request, context)) {
- addMd5ChecksumInHeader(request);
+ addMd5ChecksumInHeader(request, checksumStore);
return request;
}
@@ -99,7 +101,7 @@ private SdkHttpFullRequest.Builder legacyChecksum(SdkHttpFullRequest.Builder req
}
if (flexibleChecksumInHeaderRequired(context, resolvedChecksumSpecs)) {
- addFlexibleChecksumInHeader(request, context, resolvedChecksumSpecs);
+ addFlexibleChecksumInHeader(request, context, resolvedChecksumSpecs, checksumStore);
return request;
}
@@ -174,10 +176,14 @@ private boolean md5ChecksumRequired(SdkHttpFullRequest.Builder request, RequestE
* request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do
* that yet.
*/
- private void addMd5ChecksumInHeader(SdkHttpFullRequest.Builder request) {
+ private void addMd5ChecksumInHeader(SdkHttpFullRequest.Builder request, PayloadChecksumStore checksumStore) {
try {
- String payloadMd5 = Md5Utils.md5AsBase64(request.contentStreamProvider().newStream());
- request.putHeader(Header.CONTENT_MD5, payloadMd5);
+ byte[] payloadMd5 = checksumStore.getChecksumValue(DefaultChecksumAlgorithm.MD5);
+ if (payloadMd5 == null) {
+ payloadMd5 = Md5Utils.computeMD5Hash(request.contentStreamProvider().newStream());
+ checksumStore.putChecksumValue(DefaultChecksumAlgorithm.MD5, payloadMd5);
+ }
+ request.putHeader(Header.CONTENT_MD5, BinaryUtils.toBase64(payloadMd5));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
@@ -237,7 +243,11 @@ private void addFlexibleChecksumInTrailer(SdkHttpFullRequest.Builder request, Re
int chunkSize = 0;
if (clientType == ClientType.SYNC) {
- request.contentStreamProvider(new ChecksumCalculatingStreamProvider(request.contentStreamProvider(), checksumSpecs));
+ request.contentStreamProvider(
+ new ChecksumCalculatingStreamProvider(request.contentStreamProvider(),
+ checksumSpecs,
+ getPayloadChecksumStore(context.executionAttributes())
+ ));
originalContentLength =
context.executionContext().interceptorContext().requestBody().get().optionalContentLength().orElse(0L);
chunkSize = DEFAULT_CHUNK_SIZE;
@@ -311,13 +321,19 @@ private boolean flexibleChecksumInHeaderRequired(RequestExecutionContext context
* that yet.
*/
private void addFlexibleChecksumInHeader(SdkHttpFullRequest.Builder request, RequestExecutionContext context,
- ChecksumSpecs checksumSpecs) {
+ ChecksumSpecs checksumSpecs, PayloadChecksumStore checksumStore) {
try {
Algorithm legacyAlgorithm = checksumSpecs.algorithm();
- String payloadChecksum = BinaryUtils.toBase64(HttpChecksumUtils.computeChecksum(
- context.executionContext().interceptorContext().requestBody().get().contentStreamProvider().newStream(),
- legacyAlgorithm));
- request.putHeader(checksumSpecs.headerName(), payloadChecksum);
+ ChecksumAlgorithm newAlgorithm = HttpChecksumUtils.toNewChecksumAlgorithm(legacyAlgorithm);
+ byte[] payloadChecksum = checksumStore.getChecksumValue(newAlgorithm);
+ if (payloadChecksum == null) {
+ payloadChecksum = HttpChecksumUtils.computeChecksum(
+ context.executionContext().interceptorContext().requestBody().get().contentStreamProvider().newStream(),
+ legacyAlgorithm);
+ checksumStore.putChecksumValue(newAlgorithm, payloadChecksum);
+ }
+ String headerValue = BinaryUtils.toBase64(payloadChecksum);
+ request.putHeader(checksumSpecs.headerName(), headerValue);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
@@ -339,16 +355,21 @@ static final class ChecksumCalculatingStreamProvider implements ContentStreamPro
private final ContentStreamProvider underlyingInputStreamProvider;
private final String checksumHeaderForTrailer;
private final ChecksumSpecs checksumSpecs;
+ private final PayloadChecksumStore checksumStore;
private InputStream currentStream;
+ private final ChecksumAlgorithm checksumAlgorithm;
private software.amazon.awssdk.core.checksums.SdkChecksum sdkChecksum;
ChecksumCalculatingStreamProvider(ContentStreamProvider underlyingInputStreamProvider,
- ChecksumSpecs checksumSpecs) {
+ ChecksumSpecs checksumSpecs,
+ PayloadChecksumStore checksumStore) {
this.underlyingInputStreamProvider = underlyingInputStreamProvider;
this.sdkChecksum = software.amazon.awssdk.core.checksums.SdkChecksum.forAlgorithm(
checksumSpecs.algorithm());
+ this.checksumAlgorithm = HttpChecksumUtils.toNewChecksumAlgorithm(checksumSpecs.algorithm());
this.checksumHeaderForTrailer = checksumSpecs.headerName();
this.checksumSpecs = checksumSpecs;
+ this.checksumStore = checksumStore;
}
@Override
@@ -356,7 +377,9 @@ public InputStream newStream() {
closeCurrentStream();
currentStream = AwsUnsignedChunkedEncodingInputStream.builder()
.inputStream(underlyingInputStreamProvider.newStream())
+ .checksumAlgorithm(checksumAlgorithm)
.sdkChecksum(sdkChecksum)
+ .checksumStore(checksumStore)
.checksumHeaderForTrailer(checksumHeaderForTrailer)
.build();
return currentStream;
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java
index ec4870f5e686..0b98608a67e7 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java
@@ -20,8 +20,11 @@
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm;
import software.amazon.awssdk.core.checksums.SdkChecksum;
+import software.amazon.awssdk.core.internal.checksums.NoOpPayloadChecksumStore;
import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig;
+import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore;
import software.amazon.awssdk.utils.Validate;
/**
@@ -45,7 +48,9 @@ public abstract class AwsChunkedEncodingInputStream extends AwsChunkedInputStrea
protected boolean isTrailingTerminated = true;
private final int chunkSize;
private final int maxBufferSize;
+ private final ChecksumAlgorithm checksumAlgorithm;
private final SdkChecksum sdkChecksum;
+ private final PayloadChecksumStore checksumStore;
private boolean isLastTrailingCrlf;
/**
@@ -58,7 +63,10 @@ public abstract class AwsChunkedEncodingInputStream extends AwsChunkedInputStrea
* See {@link AwsChunkedEncodingConfig} for default values.
*/
protected AwsChunkedEncodingInputStream(InputStream in,
- SdkChecksum sdkChecksum, String checksumHeaderForTrailer,
+ ChecksumAlgorithm checksumAlgorithm,
+ SdkChecksum sdkChecksum,
+ PayloadChecksumStore checksumStore,
+ String checksumHeaderForTrailer,
AwsChunkedEncodingConfig config) {
AwsChunkedEncodingConfig awsChunkedEncodingConfig = config == null ? AwsChunkedEncodingConfig.create() : config;
@@ -78,14 +86,18 @@ protected AwsChunkedEncodingInputStream(InputStream in,
if (maxBufferSize < chunkSize) {
throw new IllegalArgumentException("Max buffer size should not be less than chunk size");
}
+ this.checksumAlgorithm = checksumAlgorithm;
this.sdkChecksum = sdkChecksum;
+ this.checksumStore = checksumStore == null ? NoOpPayloadChecksumStore.create() : checksumStore;
this.checksumHeaderForTrailer = checksumHeaderForTrailer;
}
protected abstract static class Builder {
protected InputStream inputStream;
+ protected ChecksumAlgorithm checksumAlgorithm;
protected SdkChecksum sdkChecksum;
+ protected PayloadChecksumStore checksumStore;
protected String checksumHeaderForTrailer;
protected AwsChunkedEncodingConfig awsChunkedEncodingConfig;
@@ -110,6 +122,11 @@ public T awsChunkedEncodingConfig(AwsChunkedEncodingConfig awsChunkedEncodingCon
return (T) this;
}
+ public T checksumAlgorithm(ChecksumAlgorithm checksumAlgorithm) {
+ this.checksumAlgorithm = checksumAlgorithm;
+ return (T) this;
+ }
+
/**
*
* @param sdkChecksum Instance of SdkChecksum, this can be null if we do not want to calculate Checksum
@@ -120,6 +137,11 @@ public T sdkChecksum(SdkChecksum sdkChecksum) {
return (T) this;
}
+ public T checksumStore(PayloadChecksumStore checksumStore) {
+ this.checksumStore = checksumStore;
+ return (T) this;
+ }
+
/**
*
* @param checksumHeaderForTrailer String value of Trailer header where checksum will be updated.
@@ -166,7 +188,11 @@ private boolean setUpTrailingChunks() {
return true;
}
if (calculatedChecksum == null) {
- calculatedChecksum = sdkChecksum.getChecksumBytes();
+ calculatedChecksum = checksumStore.getChecksumValue(checksumAlgorithm);
+ if (calculatedChecksum == null) {
+ calculatedChecksum = sdkChecksum.getChecksumBytes();
+ checksumStore.putChecksumValue(checksumAlgorithm, calculatedChecksum);
+ }
currentChunkIterator = new ChunkContentIterator(createChecksumChunkHeader());
return false;
} else if (!isLastTrailingCrlf) {
diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java
index 4c7f46a248cf..14b1cab5bf6c 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java
@@ -18,9 +18,11 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm;
import software.amazon.awssdk.core.checksums.SdkChecksum;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig;
+import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore;
import software.amazon.awssdk.utils.BinaryUtils;
/**
@@ -29,10 +31,13 @@
@SdkInternalApi
public class AwsUnsignedChunkedEncodingInputStream extends AwsChunkedEncodingInputStream {
- private AwsUnsignedChunkedEncodingInputStream(InputStream in, AwsChunkedEncodingConfig awsChunkedEncodingConfig,
+ private AwsUnsignedChunkedEncodingInputStream(InputStream in,
+ AwsChunkedEncodingConfig awsChunkedEncodingConfig,
+ ChecksumAlgorithm checksumAlgorithm,
SdkChecksum sdkChecksum,
+ PayloadChecksumStore checksumStore,
String checksumHeaderForTrailer) {
- super(in, sdkChecksum, checksumHeaderForTrailer, awsChunkedEncodingConfig);
+ super(in, checksumAlgorithm, sdkChecksum, checksumStore, checksumHeaderForTrailer, awsChunkedEncodingConfig);
}
public static Builder builder() {
@@ -85,8 +90,12 @@ protected byte[] createChecksumChunkHeader() {
public static final class Builder extends AwsChunkedEncodingInputStream.Builder {
public AwsUnsignedChunkedEncodingInputStream build() {
return new AwsUnsignedChunkedEncodingInputStream(
- this.inputStream, this.awsChunkedEncodingConfig,
- this.sdkChecksum, this.checksumHeaderForTrailer);
+ this.inputStream,
+ this.awsChunkedEncodingConfig,
+ this.checksumAlgorithm,
+ this.sdkChecksum,
+ this.checksumStore,
+ this.checksumHeaderForTrailer);
}
}
}
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java
index 0e1b9efe07e1..a099e91ab6ce 100644
--- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java
@@ -23,6 +23,7 @@
import static software.amazon.awssdk.http.Header.CONTENT_LENGTH;
import static software.amazon.awssdk.http.Header.CONTENT_MD5;
+import java.nio.charset.StandardCharsets;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;
@@ -41,6 +42,10 @@
import software.amazon.awssdk.core.signer.NoOpSigner;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.SdkHttpFullRequest;
+import software.amazon.awssdk.http.SdkHttpMethod;
+import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore;
+import software.amazon.awssdk.utils.BinaryUtils;
+import software.amazon.awssdk.utils.IoUtils;
import utils.ValidSdkObjects;
@RunWith(MockitoJUnitRunner.class)
@@ -48,6 +53,8 @@ public class HttpChecksumStageNonSraTest {
private static final String CHECKSUM_SPECS_HEADER = "x-amz-checksum-sha256";
private static final RequestBody REQUEST_BODY = RequestBody.fromString("TestBody");
private static final AsyncRequestBody ASYNC_REQUEST_BODY = AsyncRequestBody.fromString("TestBody");
+ private static final String PAYLOAD_CHECKSUM_SHA256 = "/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI=";
+
private final HttpChecksumStage syncStage = new HttpChecksumStage(ClientType.SYNC);
private final HttpChecksumStage asyncStage = new HttpChecksumStage(ClientType.ASYNC);
@@ -69,6 +76,40 @@ public void sync_md5Required_addsMd5Checksum_doesNotAddFlexibleChecksums() throw
assertThat(requestBuilder.firstMatchingHeader(CHECKSUM_SPECS_HEADER)).isEmpty();
}
+ @Test
+ public void sync_md5Required_checksumValueInStore_usesExistingValue() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+ boolean isAsyncStreaming = false;
+ RequestExecutionContext ctx = md5RequiredRequestContext(isAsyncStreaming);
+
+ byte[] checksumValue = "my-md5".getBytes(StandardCharsets.UTF_8);
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+ store.putChecksumValue(DefaultChecksumAlgorithm.MD5, checksumValue);
+
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ assertThat(requestBuilder.headers().get(CONTENT_MD5)).containsExactly(BinaryUtils.toBase64(checksumValue));
+ }
+
+ @Test
+ public void sync_md5Required_checksumStoreEmpty_storesComputedMd5() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+ boolean isAsyncStreaming = false;
+ RequestExecutionContext ctx = md5RequiredRequestContext(isAsyncStreaming);
+
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ String expectedChecksum = "9dzKaiLL99all2ZyHa76RA==";
+
+ assertThat(requestBuilder.headers().get(CONTENT_MD5)).containsExactly(expectedChecksum);
+ assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.MD5)).isEqualTo(BinaryUtils.fromBase64(expectedChecksum));
+ }
+
@Test
public void async_nonStreaming_md5Required_addsMd5Checksum_doesNotAddFlexibleChecksums() throws Exception {
SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
@@ -119,6 +160,40 @@ public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_addsFlexibleC
assertThat(requestBuilder.firstMatchingHeader(CHECKSUM_SPECS_HEADER)).isEmpty();
}
+ @Test
+ public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_storeEmpty_storesComputedValue() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+
+ RequestExecutionContext ctx = noOpSignerRequestContext(ClientType.SYNC);
+
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ String content = IoUtils.toUtf8String(requestBuilder.build().contentStreamProvider().get().newStream());
+ assertThat(getTrailingChecksum(content)).isEqualTo(String.format("%s:%s", CHECKSUM_SPECS_HEADER, PAYLOAD_CHECKSUM_SHA256));
+ assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.SHA256)).isEqualTo(BinaryUtils.fromBase64(PAYLOAD_CHECKSUM_SHA256));
+ }
+
+ @Test
+ public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_checksumValueInStore_usesExistingValue() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+
+ RequestExecutionContext ctx = noOpSignerRequestContext(ClientType.SYNC);
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+ byte[] checksumValue = "my-sha256".getBytes(StandardCharsets.UTF_8);
+ store.putChecksumValue(DefaultChecksumAlgorithm.SHA256, checksumValue);
+
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ String content = IoUtils.toUtf8String(requestBuilder.build().contentStreamProvider().get().newStream());
+ assertThat(getTrailingChecksum(content)).isEqualTo(String.format("%s:%s", CHECKSUM_SPECS_HEADER,
+ BinaryUtils.toBase64(checksumValue)));
+ }
+
@Test
public void asyncWithCustomSigner_flexibleChecksumInTrailerRequired_addsFlexibleChecksumInTrailer() throws Exception {
SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
@@ -181,7 +256,7 @@ public void sync_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_d
syncStage.execute(requestBuilder, ctx);
- assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly("/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI=");
+ assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256);
assertThat(requestBuilder.firstMatchingHeader(HEADER_FOR_TRAILER_REFERENCE)).isEmpty();
assertThat(requestBuilder.firstMatchingHeader("Content-encoding")).isEmpty();
@@ -191,6 +266,44 @@ public void sync_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_d
assertThat(requestBuilder.firstMatchingHeader(CONTENT_MD5)).isEmpty();
}
+ @Test
+ public void sync_flexibleChecksumInHeaderRequired_checksumValueInStore_usesExistingValue() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+ boolean isStreaming = false;
+
+ RequestExecutionContext ctx = syncFlexibleChecksumRequiredRequestContext(isStreaming);
+
+ byte[] checksumValue = "my-sha256".getBytes(StandardCharsets.UTF_8);
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+ // Test context uses SHA-256 as the flexible checksum
+ store.putChecksumValue(DefaultChecksumAlgorithm.SHA256, checksumValue);
+
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(BinaryUtils.toBase64(checksumValue));
+
+ }
+
+ @Test
+ public void sync_flexibleChecksumInHeaderRequired_checksumStoreEmpty_storesComputedSha256() throws Exception {
+ SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
+ boolean isStreaming = false;
+
+ RequestExecutionContext ctx = syncFlexibleChecksumRequiredRequestContext(isStreaming);
+
+ PayloadChecksumStore store = PayloadChecksumStore.create();
+
+ ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store);
+
+ syncStage.execute(requestBuilder, ctx);
+
+ assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256);
+ // Test context uses SHA-256 as the flexible checksum
+ assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.SHA256)).isEqualTo(BinaryUtils.fromBase64(PAYLOAD_CHECKSUM_SHA256));
+ }
+
@Test
public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_doesNotAddMd5ChecksumAndFlexibleChecksumInTrailer() throws Exception {
SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder();
@@ -199,7 +312,7 @@ public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_
asyncStage.execute(requestBuilder, ctx);
- assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly("/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI=");
+ assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256);
assertThat(requestBuilder.firstMatchingHeader(HEADER_FOR_TRAILER_REFERENCE)).isEmpty();
assertThat(requestBuilder.firstMatchingHeader("Content-encoding")).isEmpty();
@@ -209,8 +322,21 @@ public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_
assertThat(requestBuilder.firstMatchingHeader(CONTENT_MD5)).isEmpty();
}
+ private static String getTrailingChecksum(String payload) {
+ for (String line : payload.split("\r\n")) {
+ if (line.startsWith("x-amz-checksum")) {
+ return line;
+ }
+ }
+ return null;
+ }
+
private SdkHttpFullRequest.Builder createHttpRequestBuilder() {
- return SdkHttpFullRequest.builder().contentStreamProvider(REQUEST_BODY.contentStreamProvider());
+ return SdkHttpFullRequest.builder()
+ .method(SdkHttpMethod.GET)
+ .protocol("https")
+ .host("sdk.aws")
+ .contentStreamProvider(REQUEST_BODY.contentStreamProvider());
}
private RequestExecutionContext md5RequiredRequestContext(boolean isAsyncStreaming) {
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java
index 4158945ef302..80f3009734d6 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java
@@ -37,6 +37,7 @@
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
+import software.amazon.awssdk.auth.signer.AwsS3V4Signer;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.checksums.RequestChecksumCalculation;
import software.amazon.awssdk.core.exception.SdkException;
@@ -109,6 +110,35 @@ public void putObject_serverResponds500_usesSameChecksumOnRetries() {
assertAllTrailingChecksumsMatch(httpClient.requestPayloads);
}
+ @Test
+ public void putObject_nonSra_serverResponds500_usesSameChecksumOnRetries() {
+ MockHttpClient httpClient = new MockHttpClient();
+
+ S3Client s3 = S3Client.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(CREDENTIALS_PROVIDER)
+ .requestChecksumCalculation(RequestChecksumCalculation.WHEN_SUPPORTED)
+ .httpClient(httpClient)
+ .overrideConfiguration(o -> o.retryStrategy(StandardRetryStrategy.builder()
+ .maxAttempts(4)
+ .backoffStrategy(BackoffStrategy.retryImmediately())
+ .build()))
+ .build();
+
+ RequestBody requestBody = RequestBody.fromInputStream(new RandomInputStream(), 4096);
+
+ assertThatThrownBy(() -> s3.putObject(r -> r.bucket(BUCKET)
+ .key(KEY)
+ .checksumAlgorithm(ChecksumAlgorithm.CRC32)
+ .overrideConfiguration(o -> o.signer(AwsS3V4Signer.create())),
+ requestBody))
+ .isInstanceOf(S3Exception.class)
+ // Ensure we actually retried
+ .matches(e -> ((SdkException) e).numAttempts() == 4);
+
+ assertAllTrailingChecksumsMatch(httpClient.requestPayloads);
+ }
+
@Test
void asyncPutObject_serverResponds500_usesSameChecksumOnRetries() {
MockAsyncHttpClient httpClient = new MockAsyncHttpClient();
@@ -138,6 +168,39 @@ void asyncPutObject_serverResponds500_usesSameChecksumOnRetries() {
assertAllTrailingChecksumsMatch(httpClient.requestPayloads);
}
+ @Test
+ void asyncPutObject_nonSra_serverResponds500_usesSameChecksumOnRetries() {
+ MockAsyncHttpClient httpClient = new MockAsyncHttpClient();
+
+ S3AsyncClient s3 = S3AsyncClient.builder()
+ .region(Region.US_WEST_2)
+ .credentialsProvider(CREDENTIALS_PROVIDER)
+ .requestChecksumCalculation(RequestChecksumCalculation.WHEN_SUPPORTED)
+ .httpClient(httpClient)
+ .overrideConfiguration(o -> o.retryStrategy(StandardRetryStrategy.builder()
+ .maxAttempts(4)
+ .backoffStrategy(BackoffStrategy.retryImmediately())
+ .build()))
+ .build();
+
+ AsyncRequestBody requestBody = AsyncRequestBody.fromInputStream(new RandomInputStream(),
+ 4096L,
+ executorService);
+
+ CompletableFuture responseFuture =
+ s3.putObject(r -> r.bucket(BUCKET)
+ .key(KEY)
+ .checksumAlgorithm(ChecksumAlgorithm.CRC32)
+ .overrideConfiguration(o -> o.signer(AwsS3V4Signer.create())),
+ requestBody);
+
+ assertThatThrownBy(responseFuture::join)
+ .hasCauseInstanceOf(S3Exception.class)
+ .matches(e -> ((SdkException) e.getCause()).numAttempts() == 4);
+
+ assertAllTrailingChecksumsMatch(httpClient.requestPayloads);
+ }
+
private void assertAllTrailingChecksumsMatch(List requestPayloads) {
List trailingChecksumHeaders = new ArrayList<>();