From d7832d86aeeed822657371422bf1569bc1c20f68 Mon Sep 17 00:00:00 2001 From: Dongie Agnir Date: Wed, 24 Sep 2025 13:13:44 -0700 Subject: [PATCH 1/2] Reuse checksums in legacy signing codepath This commit adds support for reusing calculated payload checksums over retries in the legacy (i.e. non-SRA) signing codepaths. --- core/auth/pom.xml | 5 + .../AwsSignedChunkedEncodingInputStream.java | 18 ++- .../checksums/LegacyPayloadChecksumCache.java | 37 ----- .../pipeline/stages/HttpChecksumStage.java | 49 +++++-- .../io/AwsChunkedEncodingInputStream.java | 30 +++- ...AwsUnsignedChunkedEncodingInputStream.java | 17 ++- .../stages/HttpChecksumStageNonSraTest.java | 132 +++++++++++++++++- 7 files changed, 225 insertions(+), 63 deletions(-) delete mode 100644 core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java diff --git a/core/auth/pom.xml b/core/auth/pom.xml index 64d46909f109..57e87fea7b68 100644 --- a/core/auth/pom.xml +++ b/core/auth/pom.xml @@ -93,6 +93,11 @@ http-auth-spi ${awsjavasdk.version} + + software.amazon.awssdk + checksums-spi + ${awsjavasdk.version} + software.amazon.eventstream eventstream diff --git a/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java b/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java index 3174eb7c6caa..a7772fe9f54e 100644 --- a/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java +++ b/core/auth/src/main/java/software/amazon/awssdk/auth/signer/internal/chunkedencoding/AwsSignedChunkedEncodingInputStream.java @@ -19,11 +19,13 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm; import software.amazon.awssdk.core.checksums.Algorithm; import software.amazon.awssdk.core.checksums.SdkChecksum; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig; import software.amazon.awssdk.core.internal.io.AwsChunkedEncodingInputStream; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.BinaryUtils; /** @@ -60,12 +62,15 @@ public final class AwsSignedChunkedEncodingInputStream extends AwsChunkedEncodin * @param config The configuration allows the user to customize chunk size and buffer size. * See {@link AwsChunkedEncodingConfig} for default values. */ - private AwsSignedChunkedEncodingInputStream(InputStream in, SdkChecksum sdkChecksum, + private AwsSignedChunkedEncodingInputStream(InputStream in, + ChecksumAlgorithm checksumAlgorithm, + SdkChecksum sdkChecksum, + PayloadChecksumStore checksumStore, String checksumHeaderForTrailer, String headerSignature, AwsChunkSigner chunkSigner, AwsChunkedEncodingConfig config) { - super(in, sdkChecksum, checksumHeaderForTrailer, config); + super(in, checksumAlgorithm, sdkChecksum, checksumStore, checksumHeaderForTrailer, config); this.chunkSigner = chunkSigner; this.previousChunkSignature = headerSignature; this.headerSignature = headerSignature; @@ -103,9 +108,14 @@ public Builder awsChunkSigner(AwsChunkSigner awsChunkSigner) { public AwsSignedChunkedEncodingInputStream build() { - return new AwsSignedChunkedEncodingInputStream(this.inputStream, this.sdkChecksum, this.checksumHeaderForTrailer, + return new AwsSignedChunkedEncodingInputStream(this.inputStream, + this.checksumAlgorithm, + this.sdkChecksum, + this.checksumStore, + this.checksumHeaderForTrailer, this.headerSignature, - this.awsChunkSigner, this.awsChunkedEncodingConfig); + this.awsChunkSigner, + this.awsChunkedEncodingConfig); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java deleted file mode 100644 index 7f3f5ed93b4b..000000000000 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/checksums/LegacyPayloadChecksumCache.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at - * - * http://aws.amazon.com/apache2.0 - * - * or in the "license" file accompanying this file. This file is distributed - * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either - * express or implied. See the License for the specific language governing - * permissions and limitations under the License. - */ - -package software.amazon.awssdk.core.internal.checksums; - -import java.util.concurrent.ConcurrentHashMap; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.core.checksums.Algorithm; - -/** - * Cache for storing computed payload checksums. Only to be used in the legacy signing paths. - */ -@SdkInternalApi -@SuppressWarnings("deprecation") -public class LegacyPayloadChecksumCache { - private final ConcurrentHashMap cache = new ConcurrentHashMap<>(); - - public byte[] putChecksumValue(Algorithm algorithm, byte[] checksumValue) { - return cache.put(algorithm, checksumValue); - } - - public byte[] getChecksumValue(Algorithm algorithm) { - return cache.get(algorithm); - } -} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java index b2925f55c3d3..ca4b4d8f7f2c 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStage.java @@ -77,8 +77,9 @@ public HttpChecksumStage(ClientType clientType) { public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, RequestExecutionContext context) throws Exception { + ensurePayloadChecksumStorePresent(context.executionAttributes()); + if (sraSigningEnabled(context)) { - ensurePayloadChecksumStorePresent(context.executionAttributes()); return sraChecksum(request, context); } @@ -87,9 +88,10 @@ public SdkHttpFullRequest.Builder execute(SdkHttpFullRequest.Builder request, Re private SdkHttpFullRequest.Builder legacyChecksum(SdkHttpFullRequest.Builder request, RequestExecutionContext context) { ChecksumSpecs resolvedChecksumSpecs = getResolvedChecksumSpecs(context.executionAttributes()); + PayloadChecksumStore checksumStore = getPayloadChecksumStore(context.executionAttributes()); if (md5ChecksumRequired(request, context)) { - addMd5ChecksumInHeader(request); + addMd5ChecksumInHeader(request, checksumStore); return request; } @@ -99,7 +101,7 @@ private SdkHttpFullRequest.Builder legacyChecksum(SdkHttpFullRequest.Builder req } if (flexibleChecksumInHeaderRequired(context, resolvedChecksumSpecs)) { - addFlexibleChecksumInHeader(request, context, resolvedChecksumSpecs); + addFlexibleChecksumInHeader(request, context, resolvedChecksumSpecs, checksumStore); return request; } @@ -174,10 +176,14 @@ private boolean md5ChecksumRequired(SdkHttpFullRequest.Builder request, RequestE * request body to use that buffered content. We obviously don't want to do that for giant streams, so we haven't opted to do * that yet. */ - private void addMd5ChecksumInHeader(SdkHttpFullRequest.Builder request) { + private void addMd5ChecksumInHeader(SdkHttpFullRequest.Builder request, PayloadChecksumStore checksumStore) { try { - String payloadMd5 = Md5Utils.md5AsBase64(request.contentStreamProvider().newStream()); - request.putHeader(Header.CONTENT_MD5, payloadMd5); + byte[] payloadMd5 = checksumStore.getChecksumValue(DefaultChecksumAlgorithm.MD5); + if (payloadMd5 == null) { + payloadMd5 = Md5Utils.computeMD5Hash(request.contentStreamProvider().newStream()); + checksumStore.putChecksumValue(DefaultChecksumAlgorithm.MD5, payloadMd5); + } + request.putHeader(Header.CONTENT_MD5, BinaryUtils.toBase64(payloadMd5)); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -237,7 +243,11 @@ private void addFlexibleChecksumInTrailer(SdkHttpFullRequest.Builder request, Re int chunkSize = 0; if (clientType == ClientType.SYNC) { - request.contentStreamProvider(new ChecksumCalculatingStreamProvider(request.contentStreamProvider(), checksumSpecs)); + request.contentStreamProvider( + new ChecksumCalculatingStreamProvider(request.contentStreamProvider(), + checksumSpecs, + getPayloadChecksumStore(context.executionAttributes()) + )); originalContentLength = context.executionContext().interceptorContext().requestBody().get().optionalContentLength().orElse(0L); chunkSize = DEFAULT_CHUNK_SIZE; @@ -311,13 +321,19 @@ private boolean flexibleChecksumInHeaderRequired(RequestExecutionContext context * that yet. */ private void addFlexibleChecksumInHeader(SdkHttpFullRequest.Builder request, RequestExecutionContext context, - ChecksumSpecs checksumSpecs) { + ChecksumSpecs checksumSpecs, PayloadChecksumStore checksumStore) { try { Algorithm legacyAlgorithm = checksumSpecs.algorithm(); - String payloadChecksum = BinaryUtils.toBase64(HttpChecksumUtils.computeChecksum( - context.executionContext().interceptorContext().requestBody().get().contentStreamProvider().newStream(), - legacyAlgorithm)); - request.putHeader(checksumSpecs.headerName(), payloadChecksum); + ChecksumAlgorithm newAlgorithm = HttpChecksumUtils.toNewChecksumAlgorithm(legacyAlgorithm); + byte[] payloadChecksum = checksumStore.getChecksumValue(newAlgorithm); + if (payloadChecksum == null) { + payloadChecksum = HttpChecksumUtils.computeChecksum( + context.executionContext().interceptorContext().requestBody().get().contentStreamProvider().newStream(), + legacyAlgorithm); + checksumStore.putChecksumValue(newAlgorithm, payloadChecksum); + } + String headerValue = BinaryUtils.toBase64(payloadChecksum); + request.putHeader(checksumSpecs.headerName(), headerValue); } catch (IOException e) { throw new UncheckedIOException(e); } @@ -339,16 +355,21 @@ static final class ChecksumCalculatingStreamProvider implements ContentStreamPro private final ContentStreamProvider underlyingInputStreamProvider; private final String checksumHeaderForTrailer; private final ChecksumSpecs checksumSpecs; + private final PayloadChecksumStore checksumStore; private InputStream currentStream; + private final ChecksumAlgorithm checksumAlgorithm; private software.amazon.awssdk.core.checksums.SdkChecksum sdkChecksum; ChecksumCalculatingStreamProvider(ContentStreamProvider underlyingInputStreamProvider, - ChecksumSpecs checksumSpecs) { + ChecksumSpecs checksumSpecs, + PayloadChecksumStore checksumStore) { this.underlyingInputStreamProvider = underlyingInputStreamProvider; this.sdkChecksum = software.amazon.awssdk.core.checksums.SdkChecksum.forAlgorithm( checksumSpecs.algorithm()); + this.checksumAlgorithm = HttpChecksumUtils.toNewChecksumAlgorithm(checksumSpecs.algorithm()); this.checksumHeaderForTrailer = checksumSpecs.headerName(); this.checksumSpecs = checksumSpecs; + this.checksumStore = checksumStore; } @Override @@ -356,7 +377,9 @@ public InputStream newStream() { closeCurrentStream(); currentStream = AwsUnsignedChunkedEncodingInputStream.builder() .inputStream(underlyingInputStreamProvider.newStream()) + .checksumAlgorithm(checksumAlgorithm) .sdkChecksum(sdkChecksum) + .checksumStore(checksumStore) .checksumHeaderForTrailer(checksumHeaderForTrailer) .build(); return currentStream; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java index ec4870f5e686..0b98608a67e7 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsChunkedEncodingInputStream.java @@ -20,8 +20,11 @@ import java.nio.charset.StandardCharsets; import java.util.Arrays; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm; import software.amazon.awssdk.core.checksums.SdkChecksum; +import software.amazon.awssdk.core.internal.checksums.NoOpPayloadChecksumStore; import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.Validate; /** @@ -45,7 +48,9 @@ public abstract class AwsChunkedEncodingInputStream extends AwsChunkedInputStrea protected boolean isTrailingTerminated = true; private final int chunkSize; private final int maxBufferSize; + private final ChecksumAlgorithm checksumAlgorithm; private final SdkChecksum sdkChecksum; + private final PayloadChecksumStore checksumStore; private boolean isLastTrailingCrlf; /** @@ -58,7 +63,10 @@ public abstract class AwsChunkedEncodingInputStream extends AwsChunkedInputStrea * See {@link AwsChunkedEncodingConfig} for default values. */ protected AwsChunkedEncodingInputStream(InputStream in, - SdkChecksum sdkChecksum, String checksumHeaderForTrailer, + ChecksumAlgorithm checksumAlgorithm, + SdkChecksum sdkChecksum, + PayloadChecksumStore checksumStore, + String checksumHeaderForTrailer, AwsChunkedEncodingConfig config) { AwsChunkedEncodingConfig awsChunkedEncodingConfig = config == null ? AwsChunkedEncodingConfig.create() : config; @@ -78,14 +86,18 @@ protected AwsChunkedEncodingInputStream(InputStream in, if (maxBufferSize < chunkSize) { throw new IllegalArgumentException("Max buffer size should not be less than chunk size"); } + this.checksumAlgorithm = checksumAlgorithm; this.sdkChecksum = sdkChecksum; + this.checksumStore = checksumStore == null ? NoOpPayloadChecksumStore.create() : checksumStore; this.checksumHeaderForTrailer = checksumHeaderForTrailer; } protected abstract static class Builder { protected InputStream inputStream; + protected ChecksumAlgorithm checksumAlgorithm; protected SdkChecksum sdkChecksum; + protected PayloadChecksumStore checksumStore; protected String checksumHeaderForTrailer; protected AwsChunkedEncodingConfig awsChunkedEncodingConfig; @@ -110,6 +122,11 @@ public T awsChunkedEncodingConfig(AwsChunkedEncodingConfig awsChunkedEncodingCon return (T) this; } + public T checksumAlgorithm(ChecksumAlgorithm checksumAlgorithm) { + this.checksumAlgorithm = checksumAlgorithm; + return (T) this; + } + /** * * @param sdkChecksum Instance of SdkChecksum, this can be null if we do not want to calculate Checksum @@ -120,6 +137,11 @@ public T sdkChecksum(SdkChecksum sdkChecksum) { return (T) this; } + public T checksumStore(PayloadChecksumStore checksumStore) { + this.checksumStore = checksumStore; + return (T) this; + } + /** * * @param checksumHeaderForTrailer String value of Trailer header where checksum will be updated. @@ -166,7 +188,11 @@ private boolean setUpTrailingChunks() { return true; } if (calculatedChecksum == null) { - calculatedChecksum = sdkChecksum.getChecksumBytes(); + calculatedChecksum = checksumStore.getChecksumValue(checksumAlgorithm); + if (calculatedChecksum == null) { + calculatedChecksum = sdkChecksum.getChecksumBytes(); + checksumStore.putChecksumValue(checksumAlgorithm, calculatedChecksum); + } currentChunkIterator = new ChunkContentIterator(createChecksumChunkHeader()); return false; } else if (!isLastTrailingCrlf) { diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java index 4c7f46a248cf..14b1cab5bf6c 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/io/AwsUnsignedChunkedEncodingInputStream.java @@ -18,9 +18,11 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm; import software.amazon.awssdk.core.checksums.SdkChecksum; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.internal.chunked.AwsChunkedEncodingConfig; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; import software.amazon.awssdk.utils.BinaryUtils; /** @@ -29,10 +31,13 @@ @SdkInternalApi public class AwsUnsignedChunkedEncodingInputStream extends AwsChunkedEncodingInputStream { - private AwsUnsignedChunkedEncodingInputStream(InputStream in, AwsChunkedEncodingConfig awsChunkedEncodingConfig, + private AwsUnsignedChunkedEncodingInputStream(InputStream in, + AwsChunkedEncodingConfig awsChunkedEncodingConfig, + ChecksumAlgorithm checksumAlgorithm, SdkChecksum sdkChecksum, + PayloadChecksumStore checksumStore, String checksumHeaderForTrailer) { - super(in, sdkChecksum, checksumHeaderForTrailer, awsChunkedEncodingConfig); + super(in, checksumAlgorithm, sdkChecksum, checksumStore, checksumHeaderForTrailer, awsChunkedEncodingConfig); } public static Builder builder() { @@ -85,8 +90,12 @@ protected byte[] createChecksumChunkHeader() { public static final class Builder extends AwsChunkedEncodingInputStream.Builder { public AwsUnsignedChunkedEncodingInputStream build() { return new AwsUnsignedChunkedEncodingInputStream( - this.inputStream, this.awsChunkedEncodingConfig, - this.sdkChecksum, this.checksumHeaderForTrailer); + this.inputStream, + this.awsChunkedEncodingConfig, + this.checksumAlgorithm, + this.sdkChecksum, + this.checksumStore, + this.checksumHeaderForTrailer); } } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java index 0e1b9efe07e1..a099e91ab6ce 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/http/pipeline/stages/HttpChecksumStageNonSraTest.java @@ -23,6 +23,7 @@ import static software.amazon.awssdk.http.Header.CONTENT_LENGTH; import static software.amazon.awssdk.http.Header.CONTENT_MD5; +import java.nio.charset.StandardCharsets; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; @@ -41,6 +42,10 @@ import software.amazon.awssdk.core.signer.NoOpSigner; import software.amazon.awssdk.core.sync.RequestBody; import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.http.SdkHttpMethod; +import software.amazon.awssdk.http.auth.spi.signer.PayloadChecksumStore; +import software.amazon.awssdk.utils.BinaryUtils; +import software.amazon.awssdk.utils.IoUtils; import utils.ValidSdkObjects; @RunWith(MockitoJUnitRunner.class) @@ -48,6 +53,8 @@ public class HttpChecksumStageNonSraTest { private static final String CHECKSUM_SPECS_HEADER = "x-amz-checksum-sha256"; private static final RequestBody REQUEST_BODY = RequestBody.fromString("TestBody"); private static final AsyncRequestBody ASYNC_REQUEST_BODY = AsyncRequestBody.fromString("TestBody"); + private static final String PAYLOAD_CHECKSUM_SHA256 = "/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI="; + private final HttpChecksumStage syncStage = new HttpChecksumStage(ClientType.SYNC); private final HttpChecksumStage asyncStage = new HttpChecksumStage(ClientType.ASYNC); @@ -69,6 +76,40 @@ public void sync_md5Required_addsMd5Checksum_doesNotAddFlexibleChecksums() throw assertThat(requestBuilder.firstMatchingHeader(CHECKSUM_SPECS_HEADER)).isEmpty(); } + @Test + public void sync_md5Required_checksumValueInStore_usesExistingValue() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + boolean isAsyncStreaming = false; + RequestExecutionContext ctx = md5RequiredRequestContext(isAsyncStreaming); + + byte[] checksumValue = "my-md5".getBytes(StandardCharsets.UTF_8); + PayloadChecksumStore store = PayloadChecksumStore.create(); + store.putChecksumValue(DefaultChecksumAlgorithm.MD5, checksumValue); + + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + assertThat(requestBuilder.headers().get(CONTENT_MD5)).containsExactly(BinaryUtils.toBase64(checksumValue)); + } + + @Test + public void sync_md5Required_checksumStoreEmpty_storesComputedMd5() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + boolean isAsyncStreaming = false; + RequestExecutionContext ctx = md5RequiredRequestContext(isAsyncStreaming); + + PayloadChecksumStore store = PayloadChecksumStore.create(); + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + String expectedChecksum = "9dzKaiLL99all2ZyHa76RA=="; + + assertThat(requestBuilder.headers().get(CONTENT_MD5)).containsExactly(expectedChecksum); + assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.MD5)).isEqualTo(BinaryUtils.fromBase64(expectedChecksum)); + } + @Test public void async_nonStreaming_md5Required_addsMd5Checksum_doesNotAddFlexibleChecksums() throws Exception { SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); @@ -119,6 +160,40 @@ public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_addsFlexibleC assertThat(requestBuilder.firstMatchingHeader(CHECKSUM_SPECS_HEADER)).isEmpty(); } + @Test + public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_storeEmpty_storesComputedValue() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + + RequestExecutionContext ctx = noOpSignerRequestContext(ClientType.SYNC); + + PayloadChecksumStore store = PayloadChecksumStore.create(); + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + String content = IoUtils.toUtf8String(requestBuilder.build().contentStreamProvider().get().newStream()); + assertThat(getTrailingChecksum(content)).isEqualTo(String.format("%s:%s", CHECKSUM_SPECS_HEADER, PAYLOAD_CHECKSUM_SHA256)); + assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.SHA256)).isEqualTo(BinaryUtils.fromBase64(PAYLOAD_CHECKSUM_SHA256)); + } + + @Test + public void syncWithCustomSigner_flexibleChecksumInTrailerRequired_checksumValueInStore_usesExistingValue() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + + RequestExecutionContext ctx = noOpSignerRequestContext(ClientType.SYNC); + PayloadChecksumStore store = PayloadChecksumStore.create(); + byte[] checksumValue = "my-sha256".getBytes(StandardCharsets.UTF_8); + store.putChecksumValue(DefaultChecksumAlgorithm.SHA256, checksumValue); + + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + String content = IoUtils.toUtf8String(requestBuilder.build().contentStreamProvider().get().newStream()); + assertThat(getTrailingChecksum(content)).isEqualTo(String.format("%s:%s", CHECKSUM_SPECS_HEADER, + BinaryUtils.toBase64(checksumValue))); + } + @Test public void asyncWithCustomSigner_flexibleChecksumInTrailerRequired_addsFlexibleChecksumInTrailer() throws Exception { SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); @@ -181,7 +256,7 @@ public void sync_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_d syncStage.execute(requestBuilder, ctx); - assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly("/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI="); + assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256); assertThat(requestBuilder.firstMatchingHeader(HEADER_FOR_TRAILER_REFERENCE)).isEmpty(); assertThat(requestBuilder.firstMatchingHeader("Content-encoding")).isEmpty(); @@ -191,6 +266,44 @@ public void sync_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_d assertThat(requestBuilder.firstMatchingHeader(CONTENT_MD5)).isEmpty(); } + @Test + public void sync_flexibleChecksumInHeaderRequired_checksumValueInStore_usesExistingValue() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + boolean isStreaming = false; + + RequestExecutionContext ctx = syncFlexibleChecksumRequiredRequestContext(isStreaming); + + byte[] checksumValue = "my-sha256".getBytes(StandardCharsets.UTF_8); + PayloadChecksumStore store = PayloadChecksumStore.create(); + // Test context uses SHA-256 as the flexible checksum + store.putChecksumValue(DefaultChecksumAlgorithm.SHA256, checksumValue); + + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(BinaryUtils.toBase64(checksumValue)); + + } + + @Test + public void sync_flexibleChecksumInHeaderRequired_checksumStoreEmpty_storesComputedSha256() throws Exception { + SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); + boolean isStreaming = false; + + RequestExecutionContext ctx = syncFlexibleChecksumRequiredRequestContext(isStreaming); + + PayloadChecksumStore store = PayloadChecksumStore.create(); + + ctx.executionAttributes().putAttribute(SdkInternalExecutionAttribute.CHECKSUM_STORE, store); + + syncStage.execute(requestBuilder, ctx); + + assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256); + // Test context uses SHA-256 as the flexible checksum + assertThat(store.getChecksumValue(DefaultChecksumAlgorithm.SHA256)).isEqualTo(BinaryUtils.fromBase64(PAYLOAD_CHECKSUM_SHA256)); + } + @Test public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_doesNotAddMd5ChecksumAndFlexibleChecksumInTrailer() throws Exception { SdkHttpFullRequest.Builder requestBuilder = createHttpRequestBuilder(); @@ -199,7 +312,7 @@ public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_ asyncStage.execute(requestBuilder, ctx); - assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly("/T5YuTxNWthvWXg+TJMwl60XKcAnLMrrOZe/jA9Y+eI="); + assertThat(requestBuilder.headers().get(CHECKSUM_SPECS_HEADER)).containsExactly(PAYLOAD_CHECKSUM_SHA256); assertThat(requestBuilder.firstMatchingHeader(HEADER_FOR_TRAILER_REFERENCE)).isEmpty(); assertThat(requestBuilder.firstMatchingHeader("Content-encoding")).isEmpty(); @@ -209,8 +322,21 @@ public void async_flexibleChecksumInHeaderRequired_addsFlexibleChecksumInHeader_ assertThat(requestBuilder.firstMatchingHeader(CONTENT_MD5)).isEmpty(); } + private static String getTrailingChecksum(String payload) { + for (String line : payload.split("\r\n")) { + if (line.startsWith("x-amz-checksum")) { + return line; + } + } + return null; + } + private SdkHttpFullRequest.Builder createHttpRequestBuilder() { - return SdkHttpFullRequest.builder().contentStreamProvider(REQUEST_BODY.contentStreamProvider()); + return SdkHttpFullRequest.builder() + .method(SdkHttpMethod.GET) + .protocol("https") + .host("sdk.aws") + .contentStreamProvider(REQUEST_BODY.contentStreamProvider()); } private RequestExecutionContext md5RequiredRequestContext(boolean isAsyncStreaming) { From 0f3dd8aab9ca744914e2a19b9c55c77d28b09aff Mon Sep 17 00:00:00 2001 From: Dongie Agnir Date: Thu, 25 Sep 2025 15:30:36 -0700 Subject: [PATCH 2/2] S3 testing with non-SRA --- .../s3/checksums/ChecksumReuseTest.java | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java index 4158945ef302..80f3009734d6 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/checksums/ChecksumReuseTest.java @@ -37,6 +37,7 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.auth.signer.AwsS3V4Signer; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.checksums.RequestChecksumCalculation; import software.amazon.awssdk.core.exception.SdkException; @@ -109,6 +110,35 @@ public void putObject_serverResponds500_usesSameChecksumOnRetries() { assertAllTrailingChecksumsMatch(httpClient.requestPayloads); } + @Test + public void putObject_nonSra_serverResponds500_usesSameChecksumOnRetries() { + MockHttpClient httpClient = new MockHttpClient(); + + S3Client s3 = S3Client.builder() + .region(Region.US_WEST_2) + .credentialsProvider(CREDENTIALS_PROVIDER) + .requestChecksumCalculation(RequestChecksumCalculation.WHEN_SUPPORTED) + .httpClient(httpClient) + .overrideConfiguration(o -> o.retryStrategy(StandardRetryStrategy.builder() + .maxAttempts(4) + .backoffStrategy(BackoffStrategy.retryImmediately()) + .build())) + .build(); + + RequestBody requestBody = RequestBody.fromInputStream(new RandomInputStream(), 4096); + + assertThatThrownBy(() -> s3.putObject(r -> r.bucket(BUCKET) + .key(KEY) + .checksumAlgorithm(ChecksumAlgorithm.CRC32) + .overrideConfiguration(o -> o.signer(AwsS3V4Signer.create())), + requestBody)) + .isInstanceOf(S3Exception.class) + // Ensure we actually retried + .matches(e -> ((SdkException) e).numAttempts() == 4); + + assertAllTrailingChecksumsMatch(httpClient.requestPayloads); + } + @Test void asyncPutObject_serverResponds500_usesSameChecksumOnRetries() { MockAsyncHttpClient httpClient = new MockAsyncHttpClient(); @@ -138,6 +168,39 @@ void asyncPutObject_serverResponds500_usesSameChecksumOnRetries() { assertAllTrailingChecksumsMatch(httpClient.requestPayloads); } + @Test + void asyncPutObject_nonSra_serverResponds500_usesSameChecksumOnRetries() { + MockAsyncHttpClient httpClient = new MockAsyncHttpClient(); + + S3AsyncClient s3 = S3AsyncClient.builder() + .region(Region.US_WEST_2) + .credentialsProvider(CREDENTIALS_PROVIDER) + .requestChecksumCalculation(RequestChecksumCalculation.WHEN_SUPPORTED) + .httpClient(httpClient) + .overrideConfiguration(o -> o.retryStrategy(StandardRetryStrategy.builder() + .maxAttempts(4) + .backoffStrategy(BackoffStrategy.retryImmediately()) + .build())) + .build(); + + AsyncRequestBody requestBody = AsyncRequestBody.fromInputStream(new RandomInputStream(), + 4096L, + executorService); + + CompletableFuture responseFuture = + s3.putObject(r -> r.bucket(BUCKET) + .key(KEY) + .checksumAlgorithm(ChecksumAlgorithm.CRC32) + .overrideConfiguration(o -> o.signer(AwsS3V4Signer.create())), + requestBody); + + assertThatThrownBy(responseFuture::join) + .hasCauseInstanceOf(S3Exception.class) + .matches(e -> ((SdkException) e.getCause()).numAttempts() == 4); + + assertAllTrailingChecksumsMatch(httpClient.requestPayloads); + } + private void assertAllTrailingChecksumsMatch(List requestPayloads) { List trailingChecksumHeaders = new ArrayList<>();