diff --git a/pom.xml b/pom.xml
index 1f7679a34d04..79de46b6356d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -119,7 +119,7 @@
2.2.21
1.15
1.29
- 0.29.1
+ 0.29.2
5.10.0
diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/ChecksumIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/CrtChecksumIntegrationTest.java
similarity index 76%
rename from services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/ChecksumIntegrationTest.java
rename to services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/CrtChecksumIntegrationTest.java
index 5dca87203369..d5123f340b17 100644
--- a/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/ChecksumIntegrationTest.java
+++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/crt/CrtChecksumIntegrationTest.java
@@ -22,12 +22,11 @@
import java.nio.file.Files;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
-import software.amazon.awssdk.crt.CrtResource;
+import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient;
@@ -37,15 +36,20 @@
import software.amazon.awssdk.services.s3.model.PutObjectTaggingResponse;
import software.amazon.awssdk.services.s3.model.Tag;
import software.amazon.awssdk.services.s3.model.Tagging;
+import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.testutils.service.AwsTestBase;
-public class ChecksumIntegrationTest extends S3IntegrationTestBase {
- private static final String TEST_BUCKET = temporaryBucketName(ChecksumIntegrationTest.class);
+public class CrtChecksumIntegrationTest extends S3IntegrationTestBase {
+ private static final String TEST_BUCKET = temporaryBucketName(CrtChecksumIntegrationTest.class);
private static final String TEST_KEY = "10mib_file.dat";
private static final int OBJ_SIZE = 10 * 1024 * 1024;
private static RandomTempFile testFile;
+
+ private static String testFileSha1;
+ private static String testFileCrc32;
+
private static S3AsyncClient s3Crt;
@BeforeAll
@@ -54,10 +58,15 @@ public static void setup() throws Exception {
S3IntegrationTestBase.createBucket(TEST_BUCKET);
testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE);
+ testFileSha1 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.SHA1);
+ testFileCrc32 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.CRC32);
s3Crt = S3CrtAsyncClient.builder()
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
+ // make sure we don't do a multipart upload, it will mess with validation against the precomputed
+ // checksums above
+ .thresholdInBytes(2L * OBJ_SIZE)
.build();
}
@@ -72,37 +81,33 @@ public static void teardown() throws IOException {
@Test
void noChecksumCustomization_crc32ShouldBeUsed() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
- PutObjectResponse putObjectResponse =
- s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
- assertThat(putObjectResponse).isNotNull();
+ s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
ResponseBytes getObjectResponseResponseBytes =
- s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
+ s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumCRC32();
- assertThat(getObjectChecksum).isNotNull();
+ assertThat(getObjectChecksum).isEqualTo(testFileCrc32);
}
@Test
void putObject_checksumProvidedInRequest_shouldTakePrecendence() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
- PutObjectResponse putObjectResponse =
- s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();
- assertThat(putObjectResponse).isNotNull();
+ s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();
ResponseBytes getObjectResponseResponseBytes =
- s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
+ s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumSHA1();
- assertThat(getObjectChecksum).isNotNull();
+ assertThat(getObjectChecksum).isEqualTo(testFileSha1);
}
@Test
void checksumDisabled_shouldNotPerformChecksumValidationByDefault() {
try (S3AsyncClient s3Crt = S3CrtAsyncClient.builder()
- .credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
- .region(S3IntegrationTestBase.DEFAULT_REGION)
- .checksumValidationEnabled(Boolean.FALSE)
- .build()) {
+ .credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
+ .region(S3IntegrationTestBase.DEFAULT_REGION)
+ .checksumValidationEnabled(Boolean.FALSE)
+ .build()) {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java
index b9979e1d5628..de04329326a5 100644
--- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java
+++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java
@@ -45,11 +45,13 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH
private final SimplePublisher responsePublisher = new SimplePublisher<>();
- private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder();
+ private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder();
private volatile S3MetaRequest metaRequest;
private final PublisherListener progressListener;
+ private volatile boolean responseHandlingInitiated;
+
public S3CrtResponseHandlerAdapter(CompletableFuture executeFuture,
SdkAsyncHttpResponseHandler responseHandler,
PublisherListener progressListener) {
@@ -60,17 +62,17 @@ public S3CrtResponseHandlerAdapter(CompletableFuture executeFuture,
@Override
public void onResponseHeaders(int statusCode, HttpHeader[] headers) {
- for (HttpHeader h : headers) {
- respBuilder.appendHeader(h.getName(), h.getValue());
- }
-
- respBuilder.statusCode(statusCode);
- responseHandler.onHeaders(respBuilder.build());
- responseHandler.onStream(responsePublisher);
+ // Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent
+ // whether the request has succeeded or not (e.g. if this is for a HeadObject call that CRT calls under the hood). We
+ // need to rely on onResponseBody/onFinished being called to determine this.
+ populateSdkHttpResponse(initialHeadersResponse, statusCode, headers);
}
@Override
public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long objectRangeEnd) {
+ // See reasoning in onResponseHeaders for why we call this here and not there.
+ initiateResponseHandling(initialHeadersResponse.build());
+
if (bodyBytesIn == null) {
failResponseHandlerAndFuture(new IllegalStateException("ByteBuffer delivered is null"));
return 0;
@@ -98,6 +100,10 @@ public void onFinished(S3FinishedResponseContext context) {
if (crtCode != CRT.AWS_CRT_SUCCESS) {
handleError(context);
} else {
+ // onResponseBody() is not invoked for responses with no content, so we may not have invoked
+ // SdkAsyncHttpResponseHandler#onHeaders yet.
+ // See also reasoning in onResponseHeaders for why we call this here and not there.
+ initiateResponseHandling(initialHeadersResponse.build());
onSuccessfulResponseComplete();
}
}
@@ -127,10 +133,14 @@ public void cancelRequest() {
private void handleError(S3FinishedResponseContext context) {
int crtCode = context.getErrorCode();
+ HttpHeader[] headers = context.getErrorHeaders();
int responseStatus = context.getResponseStatus();
byte[] errorPayload = context.getErrorPayload();
if (isErrorResponse(responseStatus) && errorPayload != null) {
+ SdkHttpResponse.Builder errorResponse = populateSdkHttpResponse(SdkHttpResponse.builder(),
+ responseStatus, headers);
+ initiateResponseHandling(errorResponse.build());
onErrorResponseComplete(errorPayload);
} else {
Throwable cause = context.getCause();
@@ -142,6 +152,14 @@ private void handleError(S3FinishedResponseContext context) {
}
}
+ private void initiateResponseHandling(SdkHttpResponse response) {
+ if (!responseHandlingInitiated) {
+ responseHandlingInitiated = true;
+ responseHandler.onHeaders(response);
+ responseHandler.onStream(responsePublisher);
+ }
+ }
+
private void onErrorResponseComplete(byte[] errorPayload) {
responsePublisher.send(ByteBuffer.wrap(errorPayload))
.thenRun(responsePublisher::complete)
@@ -176,6 +194,17 @@ public void onProgress(S3MetaRequestProgress progress) {
this.progressListener.subscriberOnNext(progress);
}
+ private static SdkHttpResponse.Builder populateSdkHttpResponse(SdkHttpResponse.Builder respBuilder,
+ int statusCode, HttpHeader[] headers) {
+ if (headers != null) {
+ for (HttpHeader h : headers) {
+ respBuilder.appendHeader(h.getName(), h.getValue());
+ }
+ }
+ respBuilder.statusCode(statusCode);
+ return respBuilder;
+ }
+
private static class NoOpPublisherListener implements PublisherListener {
}
}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java
new file mode 100644
index 000000000000..df1d717d866e
--- /dev/null
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.services.s3.crt;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import com.github.tomakehurst.wiremock.WireMockServer;
+import com.github.tomakehurst.wiremock.client.WireMock;
+import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.core.async.AsyncResponseTransformer;
+import software.amazon.awssdk.crt.Log;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.S3Exception;
+
+public class CrtDownloadErrorTest {
+ private static final String BUCKET = "my-bucket";
+ private static final String KEY = "my-key";
+ private static final WireMockServer WM = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort());
+ private S3AsyncClient s3;
+
+ @BeforeAll
+ public static void setup() {
+ WM.start();
+ // Execute this statement before constructing the SDK service client.
+ Log.initLoggingToStdout(Log.LogLevel.Trace);
+ }
+
+ @AfterAll
+ public static void teardown() {
+ WM.stop();
+ }
+
+ @AfterEach
+ public void methodTeardown() {
+ if (s3 != null) {
+ s3.close();
+ }
+ s3 = null;
+ }
+
+ @Test
+ public void getObject_headObjectOk_getObjectThrows_operationThrows() {
+ s3 = S3AsyncClient.crtBuilder()
+ .endpointOverride(URI.create("http://localhost:" + WM.port()))
+ .forcePathStyle(true)
+ .region(Region.US_EAST_1)
+ .build();
+
+ String path = String.format("/%s/%s", BUCKET, KEY);
+
+ WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
+ .willReturn(WireMock.aResponse()
+ .withStatus(200)
+ .withHeader("ETag", "etag")
+ .withHeader("Content-Length", "5")));
+
+ String errorContent = ""
+ + "\n"
+ + " AccessDenied
\n"
+ + " User does not have permission\n"
+ + " request-id\n"
+ + " host-id\n"
+ + "";
+ WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
+ .willReturn(WireMock.aResponse()
+ .withStatus(403)
+ .withBody(errorContent)));
+
+ assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
+ .hasCauseInstanceOf(S3Exception.class)
+ .hasMessageContaining("User does not have permission")
+ .hasMessageContaining("Status Code: 403");
+ }
+
+ @Test
+ public void getObject_headObjectOk_getObjectOk_operationSucceeds() {
+ s3 = S3AsyncClient.crtBuilder()
+ .endpointOverride(URI.create("http://localhost:" + WM.port()))
+ .forcePathStyle(true)
+ .region(Region.US_EAST_1)
+ .build();
+
+ String path = String.format("/%s/%s", BUCKET, KEY);
+
+ byte[] content = "hello".getBytes(StandardCharsets.UTF_8);
+
+ WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
+ .willReturn(WireMock.aResponse()
+ .withStatus(200)
+ .withHeader("ETag", "etag")
+ .withHeader("Content-Length", Integer.toString(content.length))));
+ WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
+ .willReturn(WireMock.aResponse()
+ .withStatus(200)
+ .withHeader("Content-Type", "text/plain")
+ .withBody(content)));
+
+ String objectContent = s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())
+ .join()
+ .asUtf8String();
+
+ assertThat(objectContent.getBytes(StandardCharsets.UTF_8)).isEqualTo(content);
+ }
+
+ @Test
+ public void getObject_headObjectThrows_operationThrows() {
+ s3 = S3AsyncClient.crtBuilder()
+ .endpointOverride(URI.create("http://localhost:" + WM.port()))
+ .forcePathStyle(true)
+ .region(Region.US_EAST_1)
+ .build();
+
+ String path = String.format("/%s/%s", BUCKET, KEY);
+
+
+ WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
+ .willReturn(WireMock.aResponse()
+ .withStatus(403)));
+
+ assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
+ .hasCauseInstanceOf(S3Exception.class)
+ .hasMessageContaining("Status Code: 403");
+ }
+}
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java
index bd1aea915ef2..dbd86d3be6d8 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java
@@ -19,6 +19,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -30,7 +31,6 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
@@ -59,7 +59,7 @@ public class S3CrtResponseHandlerAdapterTest {
@Before
public void setup() {
future = new CompletableFuture<>();
- sdkResponseHandler = new TestResponseHandler();
+ sdkResponseHandler = spy(new TestResponseHandler());
responseHandlerAdapter = new S3CrtResponseHandlerAdapter(future,
sdkResponseHandler,
null);
@@ -75,17 +75,20 @@ public void successfulResponse_shouldCompleteFutureSuccessfully() throws Excepti
int statusCode = 200;
responseHandlerAdapter.onResponseHeaders(statusCode, httpHeaders);
+ stubOnResponseBody();
+
+ responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null));
+ future.get(5, TimeUnit.SECONDS);
+
SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(statusCode);
assertThat(actualSdkHttpResponse.firstMatchingHeader("foo")).contains("1");
assertThat(actualSdkHttpResponse.firstMatchingHeader("bar")).contains("2");
- stubOnResponseBody();
- responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null));
- future.get(5, TimeUnit.SECONDS);
assertThat(future).isCompleted();
verify(s3MetaRequest, times(2)).incrementReadWindow(11L);
verify(s3MetaRequest).close();
+ verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}
@Test
@@ -103,6 +106,7 @@ public void nullByteBuffer_shouldCompleteFutureExceptionally() {
+ "null");
assertThat(future).isCompletedExceptionally();
verify(s3MetaRequest).close();
+ verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}
@Test
@@ -110,17 +114,17 @@ public void errorResponse_shouldCompleteFutureSuccessfully() {
int statusCode = 400;
responseHandlerAdapter.onResponseHeaders(statusCode, new HttpHeader[0]);
- SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
- assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400);
- assertThat(actualSdkHttpResponse.headers()).isEmpty();
-
byte[] errorPayload = "errorResponse".getBytes(StandardCharsets.UTF_8);
stubOnResponseBody();
-
responseHandlerAdapter.onFinished(stubResponseContext(1, statusCode, errorPayload));
+ SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
+ assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400);
+ assertThat(actualSdkHttpResponse.headers()).isEmpty();
+
assertThat(future).isCompleted();
verify(s3MetaRequest).close();
+ verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}
@Test
@@ -164,7 +168,7 @@ private void stubOnResponseBody() {
responseHandlerAdapter.onResponseBody(ByteBuffer.wrap("helloworld2".getBytes()), 1, 2);
}
- private static final class TestResponseHandler implements SdkAsyncHttpResponseHandler {
+ private static class TestResponseHandler implements SdkAsyncHttpResponseHandler {
private SdkHttpResponse sdkHttpResponse;
private Throwable error;
@Override
diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java
index 22a313f4186e..0c4664072b70 100644
--- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java
+++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/utils/ChecksumUtils.java
@@ -15,17 +15,15 @@
package software.amazon.awssdk.services.s3.utils;
-import java.io.File;
import java.io.IOException;
import java.io.InputStream;
-import java.io.PrintWriter;
-import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.List;
-import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import software.amazon.awssdk.core.checksums.Algorithm;
@@ -88,6 +86,20 @@ public static String calculatedChecksum(String contentString, Algorithm algorith
return BinaryUtils.toBase64(sdkChecksum.getChecksumBytes());
}
+ public static String calculatedChecksum(Path path, Algorithm algorithm) {
+ SdkChecksum sdkChecksum = SdkChecksum.forAlgorithm(algorithm);
+ try (InputStream is = Files.newInputStream(path)) {
+ byte[] buffer = new byte[4096];
+ int read;
+ while ((read = is.read(buffer)) != -1) {
+ sdkChecksum.update(buffer, 0, read);
+ }
+ return BinaryUtils.toBase64(sdkChecksum.getChecksumBytes());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
public static String createDataOfSize(int dataSize, char contentCharacter) {
return IntStream.range(0, dataSize).mapToObj(i -> String.valueOf(contentCharacter)).collect(Collectors.joining());
}