From 8a9550f1db04b33b3606602ba181d68377f763df Mon Sep 17 00:00:00 2001 From: sbiscigl Date: Wed, 26 Jul 2023 15:31:27 -0400 Subject: [PATCH] Add checksum param to transfer manager --- .../include/aws/transfer/TransferHandle.h | 3 + .../include/aws/transfer/TransferManager.h | 17 +++- .../source/transfer/TransferManager.cpp | 98 +++++++++++++++---- .../TransferTests.cpp | 34 +++---- 4 files changed, 114 insertions(+), 38 deletions(-) diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h index 12dbfc5d2a1..f32a179502d 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h @@ -77,6 +77,8 @@ namespace Aws bool IsLastPart() { return m_lastPart; } void SetLastPart() { m_lastPart = true; } + Aws::String GetChecksum() const { return m_checksum; }; + void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; } private: int m_partId; @@ -90,6 +92,7 @@ namespace Aws std::atomic m_downloadPartStream; std::atomic m_downloadBuffer; bool m_lastPart; + Aws::String m_checksum; }; using PartPointer = std::shared_ptr< PartState >; diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h index 2a0d8e962bc..74c6417e0d0 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -54,7 +55,7 @@ namespace Aws /** * When true, TransferManager will calculate the MD5 digest of the content being uploaded. * The digest is sent to S3 via an HTTP header enabling the service to perform integrity checks. - * This option is disabled by default. + * This option is disabled by default. Defer to checksumAlgorithm to use other checksum algorithms. 
*/ bool computeContentMD5; /** @@ -116,6 +117,13 @@ namespace Aws * key/val of map entries will be key/val of query strings. */ Aws::Map customizedAccessLogTag; + + /** + * Set the Checksum Algorithm for the transfer manager to use for multipart + * upload. Defaults to CRC32. Will be overwritten to use MD5 if computeContentMD5 + * is set to true. + */ + Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC32; }; /** @@ -328,6 +336,13 @@ namespace Aws void AddTask(std::shared_ptr handle); void RemoveTask(const std::shared_ptr& handle); + /** + * Sets the checksum on a Completed Part based on the state, and the algorithm selected. + * @param state The state of the completed part as tracked by the transfer manager. + * @param part The completed part of the MPU. + */ + void SetChecksumForAlgorithm(const std::shared_ptr state, Aws::S3::Model::CompletedPart &part); + static Aws::String DetermineFilePath(const Aws::String& directory, const Aws::String& prefix, const Aws::String& keyName); Aws::Utils::ExclusiveOwnershipResourceManager m_bufferManager; diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 30073d0057b..4c1d4799203 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include #include @@ -360,12 +359,16 @@ namespace Aws if (!isRetry) { - Aws::S3::Model::CreateMultipartUploadRequest createMultipartRequest = m_transferConfig.createMultipartUploadTemplate; - createMultipartRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); - createMultipartRequest.WithBucket(handle->GetBucketName()); - createMultipartRequest.WithContentType(handle->GetContentType()); - createMultipartRequest.WithKey(handle->GetKey()); - createMultipartRequest.WithMetadata(handle->GetMetadata()); + 
Aws::S3::Model::CreateMultipartUploadRequest createMultipartRequest = m_transferConfig.createMultipartUploadTemplate + .WithCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag) + .WithBucket(handle->GetBucketName()) + .WithContentType(handle->GetContentType()) + .WithKey(handle->GetKey()) + .WithMetadata(handle->GetMetadata()); + + if (!m_transferConfig.computeContentMD5) { + createMultipartRequest.SetChecksumAlgorithm(m_transferConfig.checksumAlgorithm); + } auto createMultipartResponse = m_transferConfig.s3Client->CreateMultipartUpload(createMultipartRequest); if (createMultipartResponse.IsSuccess()) @@ -439,17 +442,22 @@ namespace Aws auto self = shared_from_this(); // keep transfer manager alive until all callbacks are finished. PartPointer partPtr = partsIter->second; - Aws::S3::Model::UploadPartRequest uploadPartRequest = m_transferConfig.uploadPartTemplate; - uploadPartRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); - uploadPartRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); - uploadPartRequest.SetDataSentEventHandler([self, handle, partPtr](const Aws::Http::HttpRequest*, long long amount){ partPtr->OnDataTransferred(amount, handle); self->TriggerUploadProgressCallback(handle); }); - uploadPartRequest.SetRequestRetryHandler([partPtr](const AmazonWebServiceRequest&){ partPtr->Reset(); }); - uploadPartRequest.WithBucket(handle->GetBucketName()) + Aws::S3::Model::UploadPartRequest uploadPartRequest = m_transferConfig.uploadPartTemplate + .WithCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag) + .WithBucket(handle->GetBucketName()) .WithContentLength(static_cast(lengthToWrite)) .WithKey(handle->GetKey()) .WithPartNumber(partsIter->first) .WithUploadId(handle->GetMultiPartId()); + if (!m_transferConfig.computeContentMD5) { + uploadPartRequest.SetChecksumAlgorithm(m_transferConfig.checksumAlgorithm); + } + + 
uploadPartRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); + uploadPartRequest.SetDataSentEventHandler([self, handle, partPtr](const Aws::Http::HttpRequest*, long long amount){ partPtr->OnDataTransferred(amount, handle); self->TriggerUploadProgressCallback(handle); }); + uploadPartRequest.SetRequestRetryHandler([partPtr](const AmazonWebServiceRequest&){ partPtr->Reset(); }); + handle->AddPendingPart(partsIter->second); uploadPartRequest.SetBody(preallocatedStreamReader); @@ -457,6 +465,7 @@ namespace Aws if (m_transferConfig.computeContentMD5) { uploadPartRequest.SetContentMD5(Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(*preallocatedStreamReader))); } + auto asyncContext = Aws::MakeShared(CLASS_TAG); asyncContext->handle = handle; asyncContext->partState = partsIter->second; @@ -515,14 +524,16 @@ namespace Aws handle->AddPendingPart(partState); TriggerTransferStatusUpdatedCallback(handle); - auto putObjectRequest = m_transferConfig.putObjectTemplate; - putObjectRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); - putObjectRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); - putObjectRequest.WithBucket(handle->GetBucketName()) + auto putObjectRequest = m_transferConfig.putObjectTemplate + .WithChecksumAlgorithm(m_transferConfig.checksumAlgorithm) + .WithBucket(handle->GetBucketName()) .WithKey(handle->GetKey()) .WithContentLength(static_cast(handle->GetBytesTotalSize())) .WithMetadata(handle->GetMetadata()); + putObjectRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); + putObjectRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); + putObjectRequest.SetContentType(handle->GetContentType()); auto buffer = m_bufferManager.Acquire(); @@ -586,6 +597,26 @@ namespace Aws { if (handle->ShouldContinue()) { + 
partState->SetChecksum([&]() -> Aws::String { + if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) + { + return outcome.GetResult().GetChecksumCRC32(); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) + { + return outcome.GetResult().GetChecksumCRC32C(); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) + { + return outcome.GetResult().GetChecksumSHA1(); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256) + { + return outcome.GetResult().GetChecksumSHA256(); + } + //Return empty checksum for not set. + return ""; + }()); handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId() << " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: [" @@ -624,15 +655,20 @@ namespace Aws if (pendingParts.size() == 0 && queuedParts.size() == 0 && handle->LockForCompletion()) { - if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize()) + if (failedParts.size() == 0 && (handle->GetBytesTransferred() >= handle->GetBytesTotalSize())) { Aws::S3::Model::CompletedMultipartUpload completedUpload; for (auto& part : handle->GetCompletedParts()) { - Aws::S3::Model::CompletedPart completedPart; - completedPart.WithPartNumber(part.first) + auto completedPart = Aws::S3::Model::CompletedPart() + .WithPartNumber(part.first) .WithETag(part.second->GetETag()); + + if(!m_transferConfig.computeContentMD5) { + SetChecksumForAlgorithm(part.second, completedPart); + } + completedUpload.AddParts(completedPart); } @@ -1419,5 +1455,27 @@ namespace Aws auto handle = CreateUploadFileHandle(fileStream.get(), bucketName, keyName, contentType, metadata, context, fileName); return SubmitUpload(handle); } + + void TransferManager::SetChecksumForAlgorithm(const std::shared_ptr state, + Aws::S3::Model::CompletedPart &part) + { + 
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) + { + part.SetChecksumCRC32(state->GetChecksum()); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) + { + part.SetChecksumCRC32C(state->GetChecksum()); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) + { + part.SetChecksumSHA1(state->GetChecksum()); + } + else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256) + { + part.SetChecksumSHA256(state->GetChecksum()); + } + // Set no checksum on part if none is specified + } } } diff --git a/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp b/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp index e17442ed3d4..990194184d8 100644 --- a/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp +++ b/tests/aws-cpp-sdk-transfer-tests/TransferTests.cpp @@ -866,7 +866,7 @@ TEST_P(TransferTests, TransferManager_EmptyFileTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(0u, fileSize); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -944,7 +944,7 @@ TEST_P(TransferTests, TransferManager_SmallTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, (SMALL_TEST_SIZE / testStrLen * testStrLen)); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -983,7 +983,7 @@ TEST_P(TransferTests, TransferManager_ContentTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, strlen(CONTENT_TEST_FILE_TEXT)); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), CONTENT_FILE_KEY)); @@ 
-1153,7 +1153,7 @@ TEST_P(TransferTests, TransferManager_MediumTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, MEDIUM_TEST_SIZE / testStrLen * testStrLen); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -1201,7 +1201,7 @@ TEST_P(TransferTests, TransferManager_BigTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, BIG_TEST_SIZE / testStrLen * testStrLen); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), BIG_FILE_KEY)); @@ -1285,7 +1285,7 @@ TEST_P(TransferTests, TransferManager_MultipartTestWithStreamOffset) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, BIG_TEST_SIZE / testStrLen * testStrLen - inputOffset); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), BIG_FILE_KEY)); @@ -1335,7 +1335,7 @@ TEST_P(TransferTests, TransferManager_UnicodeFileNameTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, MEDIUM_TEST_SIZE / testStrLen * testStrLen); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), UNICODE_FILE_KEY)); @@ -1438,7 +1438,7 @@ TEST_P(TransferTests, TransferManager_CancelAndRetryUploadTest) ASSERT_TRUE(completedPartsStayedCompletedDuringRetry); ASSERT_STREQ("text/plain", requestPtr->GetContentType().c_str()); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); listMultipartOutcome = m_s3Clients[GetParam()]->ListMultipartUploads(listMultipartRequest); @@ -1545,7 +1545,7 
@@ TEST_P(TransferTests, TransferManager_AbortAndRetryUploadTest) ASSERT_EQ(30u, requestPtr->GetCompletedParts().size()); ASSERT_TRUE(completionCheckDone); ASSERT_FALSE(completedPartsStayedCompletedDuringRetry); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), CANCEL_FILE_KEY)); @@ -1581,7 +1581,7 @@ TEST_P(TransferTests, TransferManager_MultiPartContentTest) ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); ASSERT_EQ(PARTS_IN_MEDIUM_TEST, requestPtr->GetCompletedParts().size()); // > 1 part - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); VerifyUploadedFile(*transferManager, multiPartContentFileName, @@ -1634,7 +1634,7 @@ TEST_P(TransferTests, TransferManager_MultiPartStreamableByteTest) ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); ASSERT_EQ(PARTS_IN_MEDIUM_TEST, requestPtr->GetCompletedParts().size()); // > 1 part - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); VerifyUploadedFile(*transferManager, multiPartContentFileName, @@ -1671,7 +1671,7 @@ TEST_P(TransferTests, TransferManager_SinglePartUploadWithMetadataTest) requestPtr->WaitUntilFinished(); ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -1723,7 +1723,7 @@ TEST_P(TransferTests, MultipartUploadWithMetadataTest) requestPtr->WaitUntilFinished(); } ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); - ASSERT_EQ(requestPtr->GetBytesTotalSize(), 
requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -1871,7 +1871,7 @@ TEST_P(TransferTests, TransferManager_CancelAndRetryDownloadTest) ASSERT_TRUE(completedPartsStayedCompletedDuringRetry); ASSERT_STREQ("text/plain", requestPtr->GetContentType().c_str()); - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); ASSERT_TRUE(AreFilesSame(downloadFileName, cancelTestFileName)); } @@ -1901,7 +1901,7 @@ TEST_P(TransferTests, TransferManager_SinglePartUploadWithComputeContentMd5Test) requestPtr->WaitUntilFinished(); ASSERT_FALSE(requestPtr->IsMultipart()); ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -1956,7 +1956,7 @@ TEST_P(TransferTests, MultipartUploadWithComputeContentMd5Test) } ASSERT_TRUE(requestPtr->IsMultipart()); ASSERT_EQ(TransferStatus::COMPLETED, requestPtr->GetStatus()); - ASSERT_EQ(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); + ASSERT_LE(requestPtr->GetBytesTotalSize(), requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str())); @@ -2022,7 +2022,7 @@ TEST_P(TransferTests, TransferManager_TemplatesTest) uint64_t fileSize = requestPtr->GetBytesTotalSize(); ASSERT_EQ(fileSize, MEDIUM_TEST_SIZE / testStrLen * testStrLen); - ASSERT_EQ(fileSize, requestPtr->GetBytesTransferred()); + ASSERT_LE(fileSize, requestPtr->GetBytesTransferred()); ASSERT_TRUE(WaitForObjectToPropagate(GetTestBucketName(), RandomFileName.c_str()));