Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
davidh44 committed Feb 8, 2024
1 parent 37be109 commit ec9a173
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttr
Map<String, List<String>> headers = sdkHttpRequest.headers();
String checksumHeaderName = "x-amz-checksum-algorithm";
if (headers.containsKey(checksumHeaderName)) {
checksumHeader = headers.get(checksumHeaderName).get(0);
List<String> checksumHeaderVals = headers.get(checksumHeaderName);
assertThat(checksumHeaderVals).hasSize(1);
checksumHeader = checksumHeaderVals.get(0);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.List;
Expand Down Expand Up @@ -62,16 +63,14 @@ public class S3MultipartClientPutObjectIntegrationTest extends S3IntegrationTest
private static final String TEST_KEY = "testfile.dat";
private static final int OBJ_SIZE = 19 * 1024 * 1024;
private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor();
private static final byte[] CONTENT = RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset());
private static File testFile;
private static S3AsyncClient mpuS3Client;

@BeforeAll
public static void setup() throws Exception {
S3IntegrationTestBase.setUp();
S3IntegrationTestBase.createBucket(TEST_BUCKET);
byte[] CONTENT =
RandomStringUtils.randomAscii(OBJ_SIZE).getBytes(Charset.defaultCharset());

testFile = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString());
Files.write(testFile.toPath(), CONTENT);
mpuS3Client = S3AsyncClient
Expand Down Expand Up @@ -186,6 +185,45 @@ void putObject_withSSECAndChecksum_objectSentCorrectly() throws Exception {
assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
}

@Test
void putObject_withUserSpecifiedChecksumValue_objectSentCorrectly() throws Exception {
String sha1Val = calculateSHA1AsString();
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
.key(TEST_KEY)
.checksumSHA1(sha1Val),
body).join();

assertThat(CAPTURING_INTERCEPTOR.headers.get("x-amz-checksum-sha1")).contains(sha1Val);
assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isNull();

ResponseInputStream<GetObjectResponse> objContent =
S3IntegrationTestBase.s3.getObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY),
ResponseTransformer.toInputStream());

assertThat(objContent.response().contentLength()).isEqualTo(testFile.length());
byte[] expectedSum = ChecksumUtils.computeCheckSum(Files.newInputStream(testFile.toPath()));
assertThat(ChecksumUtils.computeCheckSum(objContent)).isEqualTo(expectedSum);
}

@Test
void putObject_withUserSpecifiedChecksumTypeOtherThanCrc32_shouldHonorChecksum() {
AsyncRequestBody body = AsyncRequestBody.fromFile(testFile.toPath());
mpuS3Client.putObject(r -> r.bucket(TEST_BUCKET)
.key(TEST_KEY)
.checksumAlgorithm(ChecksumAlgorithm.SHA1),
body).join();

assertThat(CAPTURING_INTERCEPTOR.checksumHeader).isEqualTo("SHA1");
}

private static String calculateSHA1AsString() throws Exception {
MessageDigest md = MessageDigest.getInstance("SHA-1");
md.update(CONTENT);
byte[] checksum = md.digest();
return Base64.getEncoder().encodeToString(checksum);
}

private static byte[] generateSecretKey() {
KeyGenerator generator;
try {
Expand All @@ -198,16 +236,17 @@ private static byte[] generateSecretKey() {
}

private static final class CapturingInterceptor implements ExecutionInterceptor {
private String checksumHeader;
String checksumHeader;
Map<String, List<String>> headers;
@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
SdkHttpRequest sdkHttpRequest = context.httpRequest();
Map<String, List<String>> headers = sdkHttpRequest.headers();
headers = sdkHttpRequest.headers();
String checksumHeaderName = "x-amz-sdk-checksum-algorithm";
if (headers.containsKey(checksumHeaderName)) {
checksumHeader = headers.get(checksumHeaderName).get(0);

System.out.println(headers);
List<String> checksumHeaderVals = headers.get(checksumHeaderName);
assertThat(checksumHeaderVals).hasSize(1);
checksumHeader = checksumHeaderVals.get(0);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.awssdk.services.s3.internal.multipart;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -47,13 +48,16 @@
public final class SdkPojoConversionUtils {
private static final Logger log = Logger.loggerFor(SdkPojoConversionUtils.class);

private static final HashSet<String> PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE =
new HashSet<>(Arrays.asList("ChecksumSHA1", "ChecksumSHA256", "ContentMD5", "ChecksumCRC32C", "ChecksumCRC32"));

private SdkPojoConversionUtils() {
}

public static UploadPartRequest toUploadPartRequest(PutObjectRequest putObjectRequest, int partNumber, String uploadId) {

UploadPartRequest.Builder builder = UploadPartRequest.builder();
setSdkFields(builder, putObjectRequest);
setSdkFields(builder, putObjectRequest, PUT_OBJECT_REQUEST_TO_UPLOAD_PART_FIELDS_TO_IGNORE);
return builder.uploadId(uploadId).partNumber(partNumber).build();
}

Expand Down

0 comments on commit ec9a173

Please sign in to comment.