diff --git a/pom.xml b/pom.xml index 1f7679a34d04..79de46b6356d 100644 --- a/pom.xml +++ b/pom.xml @@ -119,7 +119,7 @@ 2.2.21 1.15 1.29 - 0.29.1 + 0.29.2 5.10.0 diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java index b9979e1d5628..de04329326a5 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapter.java @@ -45,11 +45,13 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH private final SimplePublisher responsePublisher = new SimplePublisher<>(); - private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder(); + private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder(); private volatile S3MetaRequest metaRequest; private final PublisherListener progressListener; + private volatile boolean responseHandlingInitiated; + public S3CrtResponseHandlerAdapter(CompletableFuture executeFuture, SdkAsyncHttpResponseHandler responseHandler, PublisherListener progressListener) { @@ -60,17 +62,17 @@ public S3CrtResponseHandlerAdapter(CompletableFuture executeFuture, @Override public void onResponseHeaders(int statusCode, HttpHeader[] headers) { - for (HttpHeader h : headers) { - respBuilder.appendHeader(h.getName(), h.getValue()); - } - - respBuilder.statusCode(statusCode); - responseHandler.onHeaders(respBuilder.build()); - responseHandler.onStream(responsePublisher); + // Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent + // whether the request has succeeded or not (e.g. if this is for a HeadObject call that CRT calls under the hood). We + // need to rely on onResponseBody/onFinished being called to determine this. + populateSdkHttpResponse(initialHeadersResponse, statusCode, headers); } @Override public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long objectRangeEnd) { + // See reasoning in onResponseHeaders for why we call this here and not there. + initiateResponseHandling(initialHeadersResponse.build()); + if (bodyBytesIn == null) { failResponseHandlerAndFuture(new IllegalStateException("ByteBuffer delivered is null")); return 0; @@ -98,6 +100,10 @@ public void onFinished(S3FinishedResponseContext context) { if (crtCode != CRT.AWS_CRT_SUCCESS) { handleError(context); } else { + // onResponseBody() is not invoked for responses with no content, so we may not have invoked + // SdkAsyncHttpResponseHandler#onHeaders yet. + // See also reasoning in onResponseHeaders for why we call this here and not there. + initiateResponseHandling(initialHeadersResponse.build()); onSuccessfulResponseComplete(); } } @@ -127,10 +133,14 @@ public void cancelRequest() { private void handleError(S3FinishedResponseContext context) { int crtCode = context.getErrorCode(); + HttpHeader[] headers = context.getErrorHeaders(); int responseStatus = context.getResponseStatus(); byte[] errorPayload = context.getErrorPayload(); if (isErrorResponse(responseStatus) && errorPayload != null) { + SdkHttpResponse.Builder errorResponse = populateSdkHttpResponse(SdkHttpResponse.builder(), + responseStatus, headers); + initiateResponseHandling(errorResponse.build()); onErrorResponseComplete(errorPayload); } else { Throwable cause = context.getCause(); @@ -142,6 +152,14 @@ private void handleError(S3FinishedResponseContext context) { } } + private void initiateResponseHandling(SdkHttpResponse response) { + if (!responseHandlingInitiated) { + responseHandlingInitiated = true; + responseHandler.onHeaders(response); + responseHandler.onStream(responsePublisher); + } + } + private void onErrorResponseComplete(byte[] errorPayload) { responsePublisher.send(ByteBuffer.wrap(errorPayload)) .thenRun(responsePublisher::complete) @@ -176,6 +194,17 @@ public void onProgress(S3MetaRequestProgress progress) { this.progressListener.subscriberOnNext(progress); } + private static SdkHttpResponse.Builder populateSdkHttpResponse(SdkHttpResponse.Builder respBuilder, + int statusCode, HttpHeader[] headers) { + if (headers != null) { + for (HttpHeader h : headers) { + respBuilder.appendHeader(h.getName(), h.getValue()); + } + } + respBuilder.statusCode(statusCode); + return respBuilder; + } + private static class NoOpPublisherListener implements PublisherListener { } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java new file mode 100644 index 000000000000..df1d717d866e --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/crt/CrtDownloadErrorTest.java @@ -0,0 +1,145 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.crt; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.core.WireMockConfiguration; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.crt.Log; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.S3Exception; + +public class CrtDownloadErrorTest { + private static final String BUCKET = "my-bucket"; + private static final String KEY = "my-key"; + private static final WireMockServer WM = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort()); + private S3AsyncClient s3; + + @BeforeAll + public static void setup() { + WM.start(); + // Execute this statement before constructing the SDK service client. + Log.initLoggingToStdout(Log.LogLevel.Trace); + } + + @AfterAll + public static void teardown() { + WM.stop(); + } + + @AfterEach + public void methodTeardown() { + if (s3 != null) { + s3.close(); + } + s3 = null; + } + + @Test + public void getObject_headObjectOk_getObjectThrows_operationThrows() { + s3 = S3AsyncClient.crtBuilder() + .endpointOverride(URI.create("http://localhost:" + WM.port())) + .forcePathStyle(true) + .region(Region.US_EAST_1) + .build(); + + String path = String.format("/%s/%s", BUCKET, KEY); + + WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path)) + .willReturn(WireMock.aResponse() + .withStatus(200) + .withHeader("ETag", "etag") + .withHeader("Content-Length", "5"))); + + String errorContent = "" + + "\n" + + " AccessDenied\n" + + " User does not have permission\n" + + " request-id\n" + + " host-id\n" + + ""; + WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path)) + .willReturn(WireMock.aResponse() + .withStatus(403) + .withBody(errorContent))); + + assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join) + .hasCauseInstanceOf(S3Exception.class) + .hasMessageContaining("User does not have permission") + .hasMessageContaining("Status Code: 403"); + } + + @Test + public void getObject_headObjectOk_getObjectOk_operationSucceeds() { + s3 = S3AsyncClient.crtBuilder() + .endpointOverride(URI.create("http://localhost:" + WM.port())) + .forcePathStyle(true) + .region(Region.US_EAST_1) + .build(); + + String path = String.format("/%s/%s", BUCKET, KEY); + + byte[] content = "hello".getBytes(StandardCharsets.UTF_8); + + WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path)) + .willReturn(WireMock.aResponse() + .withStatus(200) + .withHeader("ETag", "etag") + .withHeader("Content-Length", Integer.toString(content.length)))); + WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path)) + .willReturn(WireMock.aResponse() + .withStatus(200) + .withHeader("Content-Type", "text/plain") + .withBody(content))); + + String objectContent = s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes()) + .join() + .asUtf8String(); + + assertThat(objectContent.getBytes(StandardCharsets.UTF_8)).isEqualTo(content); + } + + @Test + public void getObject_headObjectThrows_operationThrows() { + s3 = S3AsyncClient.crtBuilder() + .endpointOverride(URI.create("http://localhost:" + WM.port())) + .forcePathStyle(true) + .region(Region.US_EAST_1) + .build(); + + String path = String.format("/%s/%s", BUCKET, KEY); + + + WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path)) + .willReturn(WireMock.aResponse() + .withStatus(403))); + + assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join) + .hasCauseInstanceOf(S3Exception.class) + .hasMessageContaining("Status Code: 403"); + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java index bd1aea915ef2..dbd86d3be6d8 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/crt/S3CrtResponseHandlerAdapterTest.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -30,7 +31,6 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; @@ -59,7 +59,7 @@ public class S3CrtResponseHandlerAdapterTest { @Before public void setup() { future = new CompletableFuture<>(); - sdkResponseHandler = new TestResponseHandler(); + sdkResponseHandler = spy(new TestResponseHandler()); responseHandlerAdapter = new S3CrtResponseHandlerAdapter(future, sdkResponseHandler, null); @@ -75,17 +75,20 @@ public void successfulResponse_shouldCompleteFutureSuccessfully() throws Excepti int statusCode = 200; responseHandlerAdapter.onResponseHeaders(statusCode, httpHeaders); + stubOnResponseBody(); + + responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null)); + future.get(5, TimeUnit.SECONDS); + SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse; assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(statusCode); assertThat(actualSdkHttpResponse.firstMatchingHeader("foo")).contains("1"); assertThat(actualSdkHttpResponse.firstMatchingHeader("bar")).contains("2"); - stubOnResponseBody(); - responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null)); - future.get(5, TimeUnit.SECONDS); assertThat(future).isCompleted(); verify(s3MetaRequest, times(2)).incrementReadWindow(11L); verify(s3MetaRequest).close(); + verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class)); } @Test @@ -103,6 +106,7 @@ public void nullByteBuffer_shouldCompleteFutureExceptionally() { + "null"); assertThat(future).isCompletedExceptionally(); verify(s3MetaRequest).close(); + verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class)); } @Test @@ -110,17 +114,17 @@ public void errorResponse_shouldCompleteFutureSuccessfully() { int statusCode = 400; responseHandlerAdapter.onResponseHeaders(statusCode, new HttpHeader[0]); - SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse; - assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400); - assertThat(actualSdkHttpResponse.headers()).isEmpty(); - byte[] errorPayload = "errorResponse".getBytes(StandardCharsets.UTF_8); stubOnResponseBody(); - responseHandlerAdapter.onFinished(stubResponseContext(1, statusCode, errorPayload)); + SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse; + assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400); + assertThat(actualSdkHttpResponse.headers()).isEmpty(); + assertThat(future).isCompleted(); verify(s3MetaRequest).close(); + verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class)); } @Test @@ -164,7 +168,7 @@ private void stubOnResponseBody() { responseHandlerAdapter.onResponseBody(ByteBuffer.wrap("helloworld2".getBytes()), 1, 2); } - private static final class TestResponseHandler implements SdkAsyncHttpResponseHandler { + private static class TestResponseHandler implements SdkAsyncHttpResponseHandler { private SdkHttpResponse sdkHttpResponse; private Throwable error; @Override