Skip to content

Commit

Permalink
Fix ChecksumIntegrationTest (#4798)
Browse files Browse the repository at this point in the history
* Fix ChecksumIntegrationTest

- Some tests specificy a part number, but CRT may do a range get under the hood.
  S3 will throw an error if both a range and part number are specified. This is
an issue that needs to be fixed in CRT, but part number is not required in this
test, so removing it.

 - Rename test file to CrtCheckIntegrationTest so it gets added to CRT test
   suite

* Revert "Revert "Wait until response body or error body received to process request (#4786)""

This reverts commit 045bcc4.
  • Loading branch information
dagnir committed Dec 22, 2023
1 parent 3568ec6 commit aa51b12
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 42 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
<rxjava.version>2.2.21</rxjava.version>
<commons-codec.verion>1.15</commons-codec.verion>
<jmh.version>1.29</jmh.version>
<awscrt.version>0.29.1</awscrt.version>
<awscrt.version>0.29.2</awscrt.version>

<!--Test dependencies -->
<junit5.version>5.10.0</junit5.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
import java.nio.file.Files;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.crt.CrtResource;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient;
Expand All @@ -37,15 +36,20 @@
import software.amazon.awssdk.services.s3.model.PutObjectTaggingResponse;
import software.amazon.awssdk.services.s3.model.Tag;
import software.amazon.awssdk.services.s3.model.Tagging;
import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.testutils.service.AwsTestBase;

public class ChecksumIntegrationTest extends S3IntegrationTestBase {
private static final String TEST_BUCKET = temporaryBucketName(ChecksumIntegrationTest.class);
public class CrtChecksumIntegrationTest extends S3IntegrationTestBase {
private static final String TEST_BUCKET = temporaryBucketName(CrtChecksumIntegrationTest.class);
private static final String TEST_KEY = "10mib_file.dat";
private static final int OBJ_SIZE = 10 * 1024 * 1024;

private static RandomTempFile testFile;

private static String testFileSha1;
private static String testFileCrc32;

private static S3AsyncClient s3Crt;

@BeforeAll
Expand All @@ -54,10 +58,15 @@ public static void setup() throws Exception {
S3IntegrationTestBase.createBucket(TEST_BUCKET);

testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE);
testFileSha1 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.SHA1);
testFileCrc32 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.CRC32);

s3Crt = S3CrtAsyncClient.builder()
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
// make sure we don't do a multipart upload, it will mess with validation against the precomputed
// checksums above
.thresholdInBytes(2L * OBJ_SIZE)
.build();
}

Expand All @@ -72,37 +81,33 @@ public static void teardown() throws IOException {
@Test
void noChecksumCustomization_crc32ShouldBeUsed() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
assertThat(putObjectResponse).isNotNull();
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

ResponseBytes<GetObjectResponse> getObjectResponseResponseBytes =
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumCRC32();
assertThat(getObjectChecksum).isNotNull();
assertThat(getObjectChecksum).isEqualTo(testFileCrc32);
}

@Test
void putObject_checksumProvidedInRequest_shouldTakePrecendence() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();
assertThat(putObjectResponse).isNotNull();
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();

ResponseBytes<GetObjectResponse> getObjectResponseResponseBytes =
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumSHA1();
assertThat(getObjectChecksum).isNotNull();
assertThat(getObjectChecksum).isEqualTo(testFileSha1);
}

@Test
void checksumDisabled_shouldNotPerformChecksumValidationByDefault() {

try (S3AsyncClient s3Crt = S3CrtAsyncClient.builder()
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
.checksumValidationEnabled(Boolean.FALSE)
.build()) {
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
.checksumValidationEnabled(Boolean.FALSE)
.build()) {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH

private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();

private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder();
private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder();
private volatile S3MetaRequest metaRequest;

private final PublisherListener<S3MetaRequestProgress> progressListener;

private volatile boolean responseHandlingInitiated;

public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,
SdkAsyncHttpResponseHandler responseHandler,
PublisherListener<S3MetaRequestProgress> progressListener) {
Expand All @@ -60,17 +62,17 @@ public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,

@Override
public void onResponseHeaders(int statusCode, HttpHeader[] headers) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}

respBuilder.statusCode(statusCode);
responseHandler.onHeaders(respBuilder.build());
responseHandler.onStream(responsePublisher);
// Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent
// whether the request has succeeded or not (e.g. if this is for a HeadObject call that CRT calls under the hood). We
// need to rely on onResponseBody/onFinished being called to determine this.
populateSdkHttpResponse(initialHeadersResponse, statusCode, headers);
}

@Override
public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long objectRangeEnd) {
// See reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());

if (bodyBytesIn == null) {
failResponseHandlerAndFuture(new IllegalStateException("ByteBuffer delivered is null"));
return 0;
Expand Down Expand Up @@ -98,6 +100,10 @@ public void onFinished(S3FinishedResponseContext context) {
if (crtCode != CRT.AWS_CRT_SUCCESS) {
handleError(context);
} else {
// onResponseBody() is not invoked for responses with no content, so we may not have invoked
// SdkAsyncHttpResponseHandler#onHeaders yet.
// See also reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());
onSuccessfulResponseComplete();
}
}
Expand Down Expand Up @@ -127,10 +133,14 @@ public void cancelRequest() {

private void handleError(S3FinishedResponseContext context) {
int crtCode = context.getErrorCode();
HttpHeader[] headers = context.getErrorHeaders();
int responseStatus = context.getResponseStatus();
byte[] errorPayload = context.getErrorPayload();

if (isErrorResponse(responseStatus) && errorPayload != null) {
SdkHttpResponse.Builder errorResponse = populateSdkHttpResponse(SdkHttpResponse.builder(),
responseStatus, headers);
initiateResponseHandling(errorResponse.build());
onErrorResponseComplete(errorPayload);
} else {
Throwable cause = context.getCause();
Expand All @@ -142,6 +152,14 @@ private void handleError(S3FinishedResponseContext context) {
}
}

private void initiateResponseHandling(SdkHttpResponse response) {
if (!responseHandlingInitiated) {
responseHandlingInitiated = true;
responseHandler.onHeaders(response);
responseHandler.onStream(responsePublisher);
}
}

private void onErrorResponseComplete(byte[] errorPayload) {
responsePublisher.send(ByteBuffer.wrap(errorPayload))
.thenRun(responsePublisher::complete)
Expand Down Expand Up @@ -176,6 +194,17 @@ public void onProgress(S3MetaRequestProgress progress) {
this.progressListener.subscriberOnNext(progress);
}

private static SdkHttpResponse.Builder populateSdkHttpResponse(SdkHttpResponse.Builder respBuilder,
int statusCode, HttpHeader[] headers) {
if (headers != null) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}
}
respBuilder.statusCode(statusCode);
return respBuilder;
}

private static class NoOpPublisherListener implements PublisherListener<S3MetaRequestProgress> {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.crt;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.crt.Log;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.S3Exception;

public class CrtDownloadErrorTest {
private static final String BUCKET = "my-bucket";
private static final String KEY = "my-key";
private static final WireMockServer WM = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort());
private S3AsyncClient s3;

@BeforeAll
public static void setup() {
WM.start();
// Execute this statement before constructing the SDK service client.
Log.initLoggingToStdout(Log.LogLevel.Trace);
}

@AfterAll
public static void teardown() {
WM.stop();
}

@AfterEach
public void methodTeardown() {
if (s3 != null) {
s3.close();
}
s3 = null;
}

@Test
public void getObject_headObjectOk_getObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", "5")));

String errorContent = ""
+ "<Error>\n"
+ " <Code>AccessDenied</Code>\n"
+ " <Message>User does not have permission</Message>\n"
+ " <RequestId>request-id</RequestId>\n"
+ " <HostId>host-id</HostId>\n"
+ "</Error>";
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)
.withBody(errorContent)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("User does not have permission")
.hasMessageContaining("Status Code: 403");
}

@Test
public void getObject_headObjectOk_getObjectOk_operationSucceeds() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

byte[] content = "hello".getBytes(StandardCharsets.UTF_8);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", Integer.toString(content.length))));
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("Content-Type", "text/plain")
.withBody(content)));

String objectContent = s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())
.join()
.asUtf8String();

assertThat(objectContent.getBytes(StandardCharsets.UTF_8)).isEqualTo(content);
}

@Test
public void getObject_headObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);


WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("Status Code: 403");
}
}
Loading

0 comments on commit aa51b12

Please sign in to comment.