Skip to content

Commit

Permalink
Add checksum param to transfer manager
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Jul 31, 2023
1 parent c84dab7 commit 8a9550f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ namespace Aws
bool IsLastPart() { return m_lastPart; }
void SetLastPart() { m_lastPart = true; }

Aws::String GetChecksum() const { return m_checksum; };
void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; }
private:

int m_partId;
Expand All @@ -90,6 +92,7 @@ namespace Aws
std::atomic<Aws::IOStream *> m_downloadPartStream;
std::atomic<unsigned char*> m_downloadBuffer;
bool m_lastPart;
Aws::String m_checksum;
};

using PartPointer = std::shared_ptr< PartState >;
Expand Down
17 changes: 16 additions & 1 deletion src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/UploadPartRequest.h>
#include <aws/s3/model/CompletedPart.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/ResourceManager.h>
Expand Down Expand Up @@ -54,7 +55,7 @@ namespace Aws
/**
* When true, TransferManager will calculate the MD5 digest of the content being uploaded.
* The digest is sent to S3 via an HTTP header enabling the service to perform integrity checks.
* This option is disabled by default.
* This option is disabled by default. Defer to checksumAlgorithm to use other checksum algorithms.
*/
bool computeContentMD5;
/**
Expand Down Expand Up @@ -116,6 +117,13 @@ namespace Aws
* key/val of map entries will be key/val of query strings.
*/
Aws::Map<Aws::String, Aws::String> customizedAccessLogTag;

/**
* Set the Checksum Algorithm for the transfer manager to use for multipart
* upload. Defaults to CRC32. Will be overwritten to use MD5 if computeContentMD5
* is set to true.
*/
Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC32;
};

/**
Expand Down Expand Up @@ -328,6 +336,13 @@ namespace Aws
void AddTask(std::shared_ptr<TransferHandle> handle);
void RemoveTask(const std::shared_ptr<TransferHandle>& handle);

/**
* Sets the checksum on a Completed Part based on the state, and the algorithm selected.
* @param state The state of the completed part as tracker by the transfer manager.
* @param part The completed part of the MPU.
*/
void SetChecksumForAlgorithm(const std::shared_ptr<PartState> state, Aws::S3::Model::CompletedPart &part);

static Aws::String DetermineFilePath(const Aws::String& directory, const Aws::String& prefix, const Aws::String& keyName);

Aws::Utils::ExclusiveOwnershipResourceManager<unsigned char*> m_bufferManager;
Expand Down
98 changes: 78 additions & 20 deletions src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/platform/FileSystem.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/HeadObjectRequest.h>
Expand Down Expand Up @@ -360,12 +359,16 @@ namespace Aws

if (!isRetry)
{
Aws::S3::Model::CreateMultipartUploadRequest createMultipartRequest = m_transferConfig.createMultipartUploadTemplate;
createMultipartRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
createMultipartRequest.WithBucket(handle->GetBucketName());
createMultipartRequest.WithContentType(handle->GetContentType());
createMultipartRequest.WithKey(handle->GetKey());
createMultipartRequest.WithMetadata(handle->GetMetadata());
Aws::S3::Model::CreateMultipartUploadRequest createMultipartRequest = m_transferConfig.createMultipartUploadTemplate
.WithCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag)
.WithBucket(handle->GetBucketName())
.WithContentType(handle->GetContentType())
.WithKey(handle->GetKey())
.WithMetadata(handle->GetMetadata());

if (!m_transferConfig.computeContentMD5) {
createMultipartRequest.SetChecksumAlgorithm(m_transferConfig.checksumAlgorithm);
}

auto createMultipartResponse = m_transferConfig.s3Client->CreateMultipartUpload(createMultipartRequest);
if (createMultipartResponse.IsSuccess())
Expand Down Expand Up @@ -439,24 +442,30 @@ namespace Aws

auto self = shared_from_this(); // keep transfer manager alive until all callbacks are finished.
PartPointer partPtr = partsIter->second;
Aws::S3::Model::UploadPartRequest uploadPartRequest = m_transferConfig.uploadPartTemplate;
uploadPartRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
uploadPartRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });
uploadPartRequest.SetDataSentEventHandler([self, handle, partPtr](const Aws::Http::HttpRequest*, long long amount){ partPtr->OnDataTransferred(amount, handle); self->TriggerUploadProgressCallback(handle); });
uploadPartRequest.SetRequestRetryHandler([partPtr](const AmazonWebServiceRequest&){ partPtr->Reset(); });
uploadPartRequest.WithBucket(handle->GetBucketName())
Aws::S3::Model::UploadPartRequest uploadPartRequest = m_transferConfig.uploadPartTemplate
.WithCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag)
.WithBucket(handle->GetBucketName())
.WithContentLength(static_cast<long long>(lengthToWrite))
.WithKey(handle->GetKey())
.WithPartNumber(partsIter->first)
.WithUploadId(handle->GetMultiPartId());

if (!m_transferConfig.computeContentMD5) {
uploadPartRequest.SetChecksumAlgorithm(m_transferConfig.checksumAlgorithm);
}

uploadPartRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });
uploadPartRequest.SetDataSentEventHandler([self, handle, partPtr](const Aws::Http::HttpRequest*, long long amount){ partPtr->OnDataTransferred(amount, handle); self->TriggerUploadProgressCallback(handle); });
uploadPartRequest.SetRequestRetryHandler([partPtr](const AmazonWebServiceRequest&){ partPtr->Reset(); });

handle->AddPendingPart(partsIter->second);

uploadPartRequest.SetBody(preallocatedStreamReader);
uploadPartRequest.SetContentType(handle->GetContentType());
if (m_transferConfig.computeContentMD5) {
uploadPartRequest.SetContentMD5(Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(*preallocatedStreamReader)));
}

auto asyncContext = Aws::MakeShared<TransferHandleAsyncContext>(CLASS_TAG);
asyncContext->handle = handle;
asyncContext->partState = partsIter->second;
Expand Down Expand Up @@ -515,14 +524,16 @@ namespace Aws
handle->AddPendingPart(partState);
TriggerTransferStatusUpdatedCallback(handle);

auto putObjectRequest = m_transferConfig.putObjectTemplate;
putObjectRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
putObjectRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });
putObjectRequest.WithBucket(handle->GetBucketName())
auto putObjectRequest = m_transferConfig.putObjectTemplate
.WithChecksumAlgorithm(m_transferConfig.checksumAlgorithm)
.WithBucket(handle->GetBucketName())
.WithKey(handle->GetKey())
.WithContentLength(static_cast<long long>(handle->GetBytesTotalSize()))
.WithMetadata(handle->GetMetadata());

putObjectRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
putObjectRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });

putObjectRequest.SetContentType(handle->GetContentType());

auto buffer = m_bufferManager.Acquire();
Expand Down Expand Up @@ -586,6 +597,26 @@ namespace Aws
{
if (handle->ShouldContinue())
{
partState->SetChecksum([&]() -> Aws::String {
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32)
{
return outcome.GetResult().GetChecksumCRC32();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C)
{
return outcome.GetResult().GetChecksumCRC32C();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1)
{
return outcome.GetResult().GetChecksumSHA1();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256)
{
return outcome.GetResult().GetChecksumSHA256();
}
//Return empty checksum for not set.
return "";
}());
handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag());
AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: ["
Expand Down Expand Up @@ -624,15 +655,20 @@ namespace Aws

if (pendingParts.size() == 0 && queuedParts.size() == 0 && handle->LockForCompletion())
{
if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize())
if (failedParts.size() == 0 && (handle->GetBytesTransferred() >= handle->GetBytesTotalSize()))
{
Aws::S3::Model::CompletedMultipartUpload completedUpload;

for (auto& part : handle->GetCompletedParts())
{
Aws::S3::Model::CompletedPart completedPart;
completedPart.WithPartNumber(part.first)
auto completedPart = Aws::S3::Model::CompletedPart()
.WithPartNumber(part.first)
.WithETag(part.second->GetETag());

if(!m_transferConfig.computeContentMD5) {
SetChecksumForAlgorithm(part.second, completedPart);
}

completedUpload.AddParts(completedPart);
}

Expand Down Expand Up @@ -1419,5 +1455,27 @@ namespace Aws
auto handle = CreateUploadFileHandle(fileStream.get(), bucketName, keyName, contentType, metadata, context, fileName);
return SubmitUpload(handle);
}

void TransferManager::SetChecksumForAlgorithm(const std::shared_ptr<Aws::Transfer::PartState> state,
Aws::S3::Model::CompletedPart &part)
{
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32)
{
part.SetChecksumCRC32(state->GetChecksum());
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C)
{
part.SetChecksumCRC32C(state->GetChecksum());
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1)
{
part.SetChecksumSHA1(state->GetChecksum());
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256)
{
part.SetChecksumSHA256(state->GetChecksum());
}
// Set no checksum on part if none is specified
}
}
}
Loading

0 comments on commit 8a9550f

Please sign in to comment.