Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ChecksumIntegrationTest #4798

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
<rxjava.version>2.2.21</rxjava.version>
<commons-codec.verion>1.15</commons-codec.verion>
<jmh.version>1.29</jmh.version>
<awscrt.version>0.29.1</awscrt.version>
<awscrt.version>0.29.2</awscrt.version>

<!--Test dependencies -->
<junit5.version>5.10.0</junit5.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@
import java.nio.file.Files;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.ResponseBytes;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.crt.CrtResource;
import software.amazon.awssdk.core.checksums.Algorithm;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3IntegrationTestBase;
import software.amazon.awssdk.services.s3.internal.crt.S3CrtAsyncClient;
Expand All @@ -37,15 +36,20 @@
import software.amazon.awssdk.services.s3.model.PutObjectTaggingResponse;
import software.amazon.awssdk.services.s3.model.Tag;
import software.amazon.awssdk.services.s3.model.Tagging;
import software.amazon.awssdk.services.s3.utils.ChecksumUtils;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.testutils.service.AwsTestBase;

public class ChecksumIntegrationTest extends S3IntegrationTestBase {
private static final String TEST_BUCKET = temporaryBucketName(ChecksumIntegrationTest.class);
public class CrtChecksumIntegrationTest extends S3IntegrationTestBase {
private static final String TEST_BUCKET = temporaryBucketName(CrtChecksumIntegrationTest.class);
private static final String TEST_KEY = "10mib_file.dat";
private static final int OBJ_SIZE = 10 * 1024 * 1024;

private static RandomTempFile testFile;

private static String testFileSha1;
private static String testFileCrc32;

private static S3AsyncClient s3Crt;

@BeforeAll
Expand All @@ -54,10 +58,15 @@ public static void setup() throws Exception {
S3IntegrationTestBase.createBucket(TEST_BUCKET);

testFile = new RandomTempFile(TEST_KEY, OBJ_SIZE);
testFileSha1 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.SHA1);
testFileCrc32 = ChecksumUtils.calculatedChecksum(testFile.toPath(), Algorithm.CRC32);

s3Crt = S3CrtAsyncClient.builder()
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
// make sure we don't do a multipart upload, it will mess with validation against the precomputed
// checksums above
.thresholdInBytes(2L * OBJ_SIZE)
.build();
}

Expand All @@ -72,37 +81,33 @@ public static void teardown() throws IOException {
@Test
void noChecksumCustomization_crc32ShouldBeUsed() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
assertThat(putObjectResponse).isNotNull();
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();

ResponseBytes<GetObjectResponse> getObjectResponseResponseBytes =
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumCRC32();
assertThat(getObjectChecksum).isNotNull();
assertThat(getObjectChecksum).isEqualTo(testFileCrc32);
}

@Test
void putObject_checksumProvidedInRequest_shouldTakePrecendence() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();
assertThat(putObjectResponse).isNotNull();
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).checksumAlgorithm(ChecksumAlgorithm.SHA1), body).join();

ResponseBytes<GetObjectResponse> getObjectResponseResponseBytes =
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY).partNumber(1), AsyncResponseTransformer.toBytes()).join();
s3Crt.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncResponseTransformer.toBytes()).join();
String getObjectChecksum = getObjectResponseResponseBytes.response().checksumSHA1();
assertThat(getObjectChecksum).isNotNull();
assertThat(getObjectChecksum).isEqualTo(testFileSha1);
}

@Test
void checksumDisabled_shouldNotPerformChecksumValidationByDefault() {

try (S3AsyncClient s3Crt = S3CrtAsyncClient.builder()
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
.checksumValidationEnabled(Boolean.FALSE)
.build()) {
.credentialsProvider(AwsTestBase.CREDENTIALS_PROVIDER_CHAIN)
.region(S3IntegrationTestBase.DEFAULT_REGION)
.checksumValidationEnabled(Boolean.FALSE)
.build()) {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
PutObjectResponse putObjectResponse =
s3Crt.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), body).join();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ public final class S3CrtResponseHandlerAdapter implements S3MetaRequestResponseH

private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();

private final SdkHttpResponse.Builder respBuilder = SdkHttpResponse.builder();
private final SdkHttpResponse.Builder initialHeadersResponse = SdkHttpResponse.builder();
private volatile S3MetaRequest metaRequest;

private final PublisherListener<S3MetaRequestProgress> progressListener;

private volatile boolean responseHandlingInitiated;

public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,
SdkAsyncHttpResponseHandler responseHandler,
PublisherListener<S3MetaRequestProgress> progressListener) {
Expand All @@ -60,17 +62,17 @@ public S3CrtResponseHandlerAdapter(CompletableFuture<Void> executeFuture,

@Override
public void onResponseHeaders(int statusCode, HttpHeader[] headers) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}

respBuilder.statusCode(statusCode);
responseHandler.onHeaders(respBuilder.build());
responseHandler.onStream(responsePublisher);
// Note, we cannot call responseHandler.onHeaders() here because the response status code and headers may not represent
// whether the request has succeeded or not (e.g. if this is for a HeadObject call that CRT calls under the hood). We
// need to rely on onResponseBody/onFinished being called to determine this.
populateSdkHttpResponse(initialHeadersResponse, statusCode, headers);
}

@Override
public int onResponseBody(ByteBuffer bodyBytesIn, long objectRangeStart, long objectRangeEnd) {
// See reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());

if (bodyBytesIn == null) {
failResponseHandlerAndFuture(new IllegalStateException("ByteBuffer delivered is null"));
return 0;
Expand Down Expand Up @@ -98,6 +100,10 @@ public void onFinished(S3FinishedResponseContext context) {
if (crtCode != CRT.AWS_CRT_SUCCESS) {
handleError(context);
} else {
// onResponseBody() is not invoked for responses with no content, so we may not have invoked
// SdkAsyncHttpResponseHandler#onHeaders yet.
// See also reasoning in onResponseHeaders for why we call this here and not there.
initiateResponseHandling(initialHeadersResponse.build());
onSuccessfulResponseComplete();
}
}
Expand Down Expand Up @@ -127,10 +133,14 @@ public void cancelRequest() {

private void handleError(S3FinishedResponseContext context) {
int crtCode = context.getErrorCode();
HttpHeader[] headers = context.getErrorHeaders();
int responseStatus = context.getResponseStatus();
byte[] errorPayload = context.getErrorPayload();

if (isErrorResponse(responseStatus) && errorPayload != null) {
SdkHttpResponse.Builder errorResponse = populateSdkHttpResponse(SdkHttpResponse.builder(),
responseStatus, headers);
initiateResponseHandling(errorResponse.build());
onErrorResponseComplete(errorPayload);
} else {
Throwable cause = context.getCause();
Expand All @@ -142,6 +152,14 @@ private void handleError(S3FinishedResponseContext context) {
}
}

private void initiateResponseHandling(SdkHttpResponse response) {
if (!responseHandlingInitiated) {
responseHandlingInitiated = true;
responseHandler.onHeaders(response);
responseHandler.onStream(responsePublisher);
}
}

private void onErrorResponseComplete(byte[] errorPayload) {
responsePublisher.send(ByteBuffer.wrap(errorPayload))
.thenRun(responsePublisher::complete)
Expand Down Expand Up @@ -176,6 +194,17 @@ public void onProgress(S3MetaRequestProgress progress) {
this.progressListener.subscriberOnNext(progress);
}

private static SdkHttpResponse.Builder populateSdkHttpResponse(SdkHttpResponse.Builder respBuilder,
int statusCode, HttpHeader[] headers) {
if (headers != null) {
for (HttpHeader h : headers) {
respBuilder.appendHeader(h.getName(), h.getValue());
}
}
respBuilder.statusCode(statusCode);
return respBuilder;
}

private static class NoOpPublisherListener implements PublisherListener<S3MetaRequestProgress> {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.services.s3.crt;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
import software.amazon.awssdk.crt.Log;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.model.S3Exception;

public class CrtDownloadErrorTest {
private static final String BUCKET = "my-bucket";
private static final String KEY = "my-key";
private static final WireMockServer WM = new WireMockServer(WireMockConfiguration.wireMockConfig().dynamicPort());
private S3AsyncClient s3;

@BeforeAll
public static void setup() {
WM.start();
// Execute this statement before constructing the SDK service client.
Log.initLoggingToStdout(Log.LogLevel.Trace);
}

@AfterAll
public static void teardown() {
WM.stop();
}

@AfterEach
public void methodTeardown() {
if (s3 != null) {
s3.close();
}
s3 = null;
}

@Test
public void getObject_headObjectOk_getObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", "5")));

String errorContent = ""
+ "<Error>\n"
+ " <Code>AccessDenied</Code>\n"
+ " <Message>User does not have permission</Message>\n"
+ " <RequestId>request-id</RequestId>\n"
+ " <HostId>host-id</HostId>\n"
+ "</Error>";
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)
.withBody(errorContent)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("User does not have permission")
.hasMessageContaining("Status Code: 403");
}

@Test
public void getObject_headObjectOk_getObjectOk_operationSucceeds() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);

byte[] content = "hello".getBytes(StandardCharsets.UTF_8);

WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("ETag", "etag")
.withHeader("Content-Length", Integer.toString(content.length))));
WM.stubFor(WireMock.get(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(200)
.withHeader("Content-Type", "text/plain")
.withBody(content)));

String objectContent = s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())
.join()
.asUtf8String();

assertThat(objectContent.getBytes(StandardCharsets.UTF_8)).isEqualTo(content);
}

@Test
public void getObject_headObjectThrows_operationThrows() {
s3 = S3AsyncClient.crtBuilder()
.endpointOverride(URI.create("http://localhost:" + WM.port()))
.forcePathStyle(true)
.region(Region.US_EAST_1)
.build();

String path = String.format("/%s/%s", BUCKET, KEY);


WM.stubFor(WireMock.head(WireMock.urlPathEqualTo(path))
.willReturn(WireMock.aResponse()
.withStatus(403)));

assertThatThrownBy(s3.getObject(r -> r.bucket(BUCKET).key(KEY), AsyncResponseTransformer.toBytes())::join)
.hasCauseInstanceOf(S3Exception.class)
.hasMessageContaining("Status Code: 403");
}
}
Loading
Loading