Skip to content

Commit

Permalink
Revert "Revert "Wait until response body or error body received to pr…
Browse files Browse the repository at this point in the history
…ocess request (#4786)""

This reverts commit 045bcc4.
  • Loading branch information
dagnir committed Dec 21, 2023
1 parent 0d06311 commit 694cef6
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
<rxjava.version>2.2.21</rxjava.version>
<commons-codec.verion>1.15</commons-codec.verion>
<jmh.version>1.29</jmh.version>
<awscrt.version>0.29.1</awscrt.version>
<awscrt.version>0.29.2</awscrt.version>

<!--Test dependencies -->
<junit5.version>5.10.0</junit5.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH

private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();

private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder();
private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder();
private volatile S3MetaRequest metaRequest;

private final PublisherListener<S3MetaRequestProgress> progressListener;

private volatile boolean responseHandlingInitiated;

public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,
SdkAsyncHttpResponseHandler responseHandler,
PublisherListener<S3MetaRequestProgress> progressListener) {
Expand All @@ -60,17 +62,17 @@ public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,

@Override
public void onResponseHeaders(int statusCode, HttpHeader[] headers) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}

respBuilder.statusCode(statusCode);
responseHandler.onHeaders(respBuilder.build());
responseHandler.onStream(responsePublisher);
// Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent
// whether the request has succeeded or not (e.g. if this is for a HeadObject call that CRT calls under the hood). We
// need to rely on onResponseBody/onFinished being called to determine this.
populateSdkHttpResponse(initialHeadersResponse, statusCode, headers);
}

@Override
public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long objectRangeEnd) {
// See reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());

if (bodyBytesIn == null) {
failResponseHandlerAndFuture(new IllegalStateException("ByteBuffer delivered is null"));
return 0;
Expand Down Expand Up @@ -98,6 +100,10 @@ public void onFinished(S3FinishedResponseContext context) {
if (crtCode != CRT.AWS_CRT_SUCCESS) {
handleError(context);
} else {
// onResponseBody() is not invoked for responses with no content, so we may not have invoked
// SdkAsyncHttpResponseHandler#onHeaders yet.
// See also reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());
onSuccessfulResponseComplete();
}
}
Expand Down Expand Up @@ -127,10 +133,14 @@ public void cancelRequest() {

private void handleError(S3FinishedResponseContext context) {
int crtCode = context.getErrorCode();
HttpHeader[] headers = context.getErrorHeaders();
int responseStatus = context.getResponseStatus();
byte[] errorPayload = context.getErrorPayload();

if (isErrorResponse(responseStatus) && errorPayload != null) {
SdkHttpResponse.Builder errorResponse = populateSdkHttpResponse(SdkHttpResponse.builder(),
responseStatus, headers);
initiateResponseHandling(errorResponse.build());
onErrorResponseComplete(errorPayload);
} else {
Throwable cause = context.getCause();
Expand All @@ -142,6 +152,14 @@ private void handleError(S3FinishedResponseContext context) {
}
}

private void initiateResponseHandling(SdkHttpResponse response) {
if (!responseHandlingInitiated) {
responseHandlingInitiated = true;
responseHandler.onHeaders(response);
responseHandler.onStream(responsePublisher);
}
}

private void onErrorResponseComplete(byte[] errorPayload) {
responsePublisher.send(ByteBuffer.wrap(errorPayload))
.thenRun(responsePublisher::complete)
Expand Down Expand Up @@ -176,6 +194,17 @@ public void onProgress(S3MetaRequestProgress progress) {
this.progressListener.subscriberOnNext(progress);
}

private static SdkHttpResponse.Builder populateSdkHttpResponse(SdkHttpResponse.Builder respBuilder,
int statusCode, HttpHeader[] headers) {
if (headers != null) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}
}
respBuilder.statusCode(statusCode);
return respBuilder;
}

private static class NoOpPublisherListener implements PublisherListener<S3MetaRequestProgress> {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.crt;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.crt.Log;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.S3Exception;

public class CrtDownloadErrorTest {
private static final String BUCKET = "my-bucket";
private static final String KEY = "my-key";
private static final WireMockServer WM = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort());
private S3AsyncClient s3;

@BeforeAll
public static void setup() {
WM.start();
// Execute this statement before constructing the SDK service client.
Log.initLoggingToStdout(Log.LogLevel.Trace);
}

@AfterAll
public static void teardown() {
WM.stop();
}

@AfterEach
public void methodTeardown() {
if (s3 != null) {
s3.close();
}
s3 = null;
}

@Test
public void getObject_headObjectOk_getObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", "5")));

String errorContent = ""
+ "<Error>\n"
+ " <Code>AccessDenied</Code>\n"
+ " <Message>User does not have permission</Message>\n"
+ " <RequestId>request-id</RequestId>\n"
+ " <HostId>host-id</HostId>\n"
+ "</Error>";
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)
.withBody(errorContent)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("User does not have permission")
.hasMessageContaining("Status Code: 403");
}

@Test
public void getObject_headObjectOk_getObjectOk_operationSucceeds() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

byte[] content = "hello".getBytes(StandardCharsets.UTF_8);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", Integer.toString(content.length))));
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("Content-Type", "text/plain")
.withBody(content)));

String objectContent = s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())
.join()
.asUtf8String();

assertThat(objectContent.getBytes(StandardCharsets.UTF_8)).isEqualTo(content);
}

@Test
public void getObject_headObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);


WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("Status Code: 403");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand All @@ -30,7 +31,6 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
Expand Down Expand Up @@ -59,7 +59,7 @@ public class S3CrtResponseHandlerAdapterTest {
@Before
public void setup() {
future = new CompletableFuture<>();
sdkResponseHandler = new TestResponseHandler();
sdkResponseHandler = spy(new TestResponseHandler());
responseHandlerAdapter = new S3CrtResponseHandlerAdapter(future,
sdkResponseHandler,
null);
Expand All @@ -75,17 +75,20 @@ public void successfulResponse_shouldCompleteFutureSuccessfully() throws Excepti
int statusCode = 200;
responseHandlerAdapter.onResponseHeaders(statusCode, httpHeaders);

stubOnResponseBody();

responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null));
future.get(5, TimeUnit.SECONDS);

SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(statusCode);
assertThat(actualSdkHttpResponse.firstMatchingHeader("foo")).contains("1");
assertThat(actualSdkHttpResponse.firstMatchingHeader("bar")).contains("2");
stubOnResponseBody();

responseHandlerAdapter.onFinished(stubResponseContext(0, 0, null));
future.get(5, TimeUnit.SECONDS);
assertThat(future).isCompleted();
verify(s3MetaRequest, times(2)).incrementReadWindow(11L);
verify(s3MetaRequest).close();
verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}

@Test
Expand All @@ -103,24 +106,25 @@ public void nullByteBuffer_shouldCompleteFutureExceptionally() {
+ "null");
assertThat(future).isCompletedExceptionally();
verify(s3MetaRequest).close();
verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}

@Test
public void errorResponse_shouldCompleteFutureSuccessfully() {
int statusCode = 400;
responseHandlerAdapter.onResponseHeaders(statusCode, new HttpHeader[0]);

SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400);
assertThat(actualSdkHttpResponse.headers()).isEmpty();

byte[] errorPayload = "errorResponse".getBytes(StandardCharsets.UTF_8);
stubOnResponseBody();

responseHandlerAdapter.onFinished(stubResponseContext(1, statusCode, errorPayload));

SdkHttpResponse actualSdkHttpResponse = sdkResponseHandler.sdkHttpResponse;
assertThat(actualSdkHttpResponse.statusCode()).isEqualTo(400);
assertThat(actualSdkHttpResponse.headers()).isEmpty();

assertThat(future).isCompleted();
verify(s3MetaRequest).close();
verify(sdkResponseHandler).onHeaders(any(SdkHttpResponse.class));
}

@Test
Expand Down Expand Up @@ -164,7 +168,7 @@ private void stubOnResponseBody() {
responseHandlerAdapter.onResponseBody(ByteBuffer.wrap("helloworld2".getBytes()), 1, 2);
}

private static final class TestResponseHandler implements SdkAsyncHttpResponseHandler {
private static class TestResponseHandler implements SdkAsyncHttpResponseHandler {
private SdkHttpResponse sdkHttpResponse;
private Throwable error;
@Override
Expand Down

0 comments on commit 694cef6

Please sign in to comment.