Skip to content

Commit

Permalink
fix trailing checksum
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Nov 12, 2024
1 parent e45f7f5 commit 061f67b
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 79 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion .github/workflows/clang-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
clang-format --version
if [ -s diff_output.patch ]; then
python3 clang-format-diff.py -p1 -style=file:.github/workflows/.clang-format < diff_output.patch > formatted_differences.patch 2> error.log || true
python3 clang-format-diff.py -p1 -style=file:.clang-format < diff_output.patch > formatted_differences.patch 2> error.log || true
if [ -s error.log ]; then
echo "Errors from clang-format-diff.py:"
cat error.log
Expand Down
103 changes: 103 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#pragma once
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/crypto/Hash.h>

#include <algorithm>
#include <cassert>
#include <cstring>

namespace Aws {
namespace Utils {
namespace Stream {

static const size_t AWS_DATA_BUFFER_SIZE = 65536;

template <size_t DataBufferSize = AWS_DATA_BUFFER_SIZE>
class AwsChunkedStream {
public:
AwsChunkedStream(Http::HttpRequest *request, const std::shared_ptr<Aws::IOStream> &stream)
: m_chunkingStream{Aws::MakeShared<StringStream>("AwsChunkedStream")}, m_request(request), m_stream(stream) {
assert(m_stream != nullptr);
if (m_stream == nullptr) {
AWS_LOGSTREAM_ERROR("AwsChunkedStream", "stream is null");
}
assert(request != nullptr);
if (request == nullptr) {
AWS_LOGSTREAM_ERROR("AwsChunkedStream", "request is null");
}
}

size_t BufferedRead(char *dst, size_t amountToRead) {
assert(dst != nullptr);
if (dst == nullptr) {
AWS_LOGSTREAM_ERROR("AwsChunkedStream", "dst is null");
}

// the chunk has ended and cannot be read from
if (m_chunkEnd) {
return 0;
}

// If we've read all of the underlying stream write the checksum trailing header
// the set that the chunked stream is over.
if (m_stream->eof() && !m_stream->bad() && (m_chunkingStream->eof() || m_chunkingStream->peek() == EOF)) {
return writeTrailer(dst, amountToRead);
}

// Try to read in a 64K chunk, if we cant we know the stream is over
size_t bytesRead = 0;
while (m_stream->good() && bytesRead < DataBufferSize) {
m_stream->read(&m_data[bytesRead], DataBufferSize - bytesRead);
bytesRead += static_cast<size_t>(m_stream->gcount());
}

if (bytesRead > 0) {
writeChunk(bytesRead);
}

// Read to destination buffer, return how much was read
m_chunkingStream->read(dst, amountToRead);
return static_cast<size_t>(m_chunkingStream->gcount());
}

private:
size_t writeTrailer(char *dst, size_t amountToRead) {
Aws::StringStream chunkedTrailerStream;
chunkedTrailerStream << "0\r\n";
if (m_request->GetRequestHash().second != nullptr) {
chunkedTrailerStream << "x-amz-checksum-" << m_request->GetRequestHash().first << ":"
<< HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) << "\r\n";
}
chunkedTrailerStream << "\r\n";
const auto chunkedTrailer = chunkedTrailerStream.str();
auto trailerSize = chunkedTrailer.size();
// unreferenced param for assert
AWS_UNREFERENCED_PARAM(amountToRead);
assert(amountToRead >= trailerSize);
memcpy(dst, chunkedTrailer.c_str(), trailerSize);
m_chunkEnd = true;
return trailerSize;
}

void writeChunk(size_t bytesRead) {
if (m_request->GetRequestHash().second != nullptr) {
m_request->GetRequestHash().second->Update(reinterpret_cast<unsigned char *>(m_data.GetUnderlyingData()), bytesRead);
}

if (m_chunkingStream != nullptr && !m_chunkingStream->bad()) {
*m_chunkingStream << Aws::Utils::StringUtils::ToHexString(bytesRead) << "\r\n";
m_chunkingStream->write(m_data.GetUnderlyingData(), bytesRead);
*m_chunkingStream << "\r\n";
}
}

Aws::Utils::Array<char> m_data{DataBufferSize};
std::shared_ptr<Aws::IOStream> m_chunkingStream;
bool m_chunkEnd{false};
Http::HttpRequest *m_request{nullptr};
std::shared_ptr<Aws::IOStream> m_stream;
};
} // namespace Stream
} // namespace Utils
} // namespace Aws
125 changes: 47 additions & 78 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,29 @@
* SPDX-License-Identifier: Apache-2.0.
*/

#include <aws/core/http/curl/CurlHttpClient.h>
#include <aws/core/http/HttpRequest.h>
#include <aws/core/http/curl/CurlHttpClient.h>
#include <aws/core/http/standard/StandardHttpResponse.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/monitoring/HttpClientMetrics.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/ratelimiter/RateLimiterInterface.h>
#include <aws/core/utils/DateTime.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/Outcome.h>
#include <aws/core/monitoring/HttpClientMetrics.h>
#include <cassert>
#include <aws/core/utils/stream/AwsChunkedStream.h>

#include <algorithm>
#include <cassert>
#include <thread>


using namespace Aws::Client;
using namespace Aws::Http;
using namespace Aws::Http::Standard;
using namespace Aws::Utils;
using namespace Aws::Utils::Logging;
using namespace Aws::Utils::Stream;
using namespace Aws::Monitoring;

#ifdef USE_AWS_MEMORY_MANAGEMENT
Expand Down Expand Up @@ -144,25 +146,28 @@ struct CurlWriteCallbackContext
int64_t m_numBytesResponseReceived;
};

static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient";

struct CurlReadCallbackContext
{
CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, Aws::Utils::RateLimits::RateLimiterInterface* limiter) :
m_client(client),
CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request,
Aws::Utils::RateLimits::RateLimiterInterface* limiter,
std::shared_ptr<AwsChunkedStream<>> chunkedStream = nullptr)
: m_client(client),
m_curlHandle(curlHandle),
m_rateLimiter(limiter),
m_request(request),
m_chunkEnd(false)
{}

const CurlHttpClient* m_client;
CURL* m_curlHandle;
Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter;
HttpRequest* m_request;
bool m_chunkEnd;
m_chunkEnd(false),
m_chunkedStream{std::move(chunkedStream)} {}

const CurlHttpClient* m_client;
CURL* m_curlHandle;
Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter;
HttpRequest* m_request;
bool m_chunkEnd;
std::shared_ptr<Stream::AwsChunkedStream<>> m_chunkedStream;
};

static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient";

static int64_t GetContentLengthFromHeader(CURL* connectionHandle,
bool& hasContentLength) {
#if LIBCURL_VERSION_NUM >= 0x073700 // 7.55.0
Expand Down Expand Up @@ -293,67 +298,24 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
size_t amountToRead = size * nmemb;
bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) &&
request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE;
// aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF
// Needs to reserve bytes of sizeof(hex(chunk-size)) + sizeof(CRLF) + sizeof(CRLF)
if (isAwsChunked)
{
Aws::String amountToReadHexString = Aws::Utils::StringUtils::ToHexString(amountToRead);
amountToRead -= (amountToReadHexString.size() + 4);
}

if (ioStream != nullptr && amountToRead > 0)
{
size_t amountRead = 0;
if (isStreaming)
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
return CURL_READFUNC_PAUSE;
}
}
else
{
ioStream->read(ptr, amountToRead);
amountRead = static_cast<size_t>(ioStream->gcount());
}

if (isAwsChunked)
{
if (amountRead > 0)
{
if (request->GetRequestHash().second != nullptr)
{
request->GetRequestHash().second->Update(reinterpret_cast<unsigned char*>(ptr), amountRead);
}

Aws::String hex = Aws::Utils::StringUtils::ToHexString(amountRead);
memmove(ptr + hex.size() + 2, ptr, amountRead);
memmove(ptr + hex.size() + 2 + amountRead, "\r\n", 2);
memmove(ptr, hex.c_str(), hex.size());
memmove(ptr + hex.size(), "\r\n", 2);
amountRead += hex.size() + 4;
}
else if (!context->m_chunkEnd)
{
Aws::StringStream chunkedTrailer;
chunkedTrailer << "0\r\n";
if (request->GetRequestHash().second != nullptr)
{
chunkedTrailer << "x-amz-checksum-"
<< request->GetRequestHash().first
<< ":"
<< HashingUtils::Base64Encode(request->GetRequestHash().second->GetHash().GetResult())
<< "\r\n";
}
chunkedTrailer << "\r\n";
amountRead = chunkedTrailer.str().size();
memcpy(ptr, chunkedTrailer.str().c_str(), amountRead);
context->m_chunkEnd = true;
}
if (isStreaming) {
if (!ioStream->eof() && ioStream->peek() != EOF) {
amountRead = (size_t)ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof()) {
return CURL_READFUNC_PAUSE;
}
} else if (isAwsChunked && context->m_chunkedStream != nullptr) {
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Called with size: " << amountToRead);
amountRead = context->m_chunkedStream->BufferedRead(ptr, amountToRead);
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "read: " << amountRead);
} else {
ioStream->read(ptr, amountToRead);
amountRead = static_cast<size_t>(ioStream->gcount());
}

auto& sentHandler = request->GetDataSentEventHandler();
Expand Down Expand Up @@ -724,7 +686,14 @@ std::shared_ptr<HttpResponse> CurlHttpClient::MakeRequest(const std::shared_ptr<
}

CurlWriteCallbackContext writeContext(this, request.get(), response.get(), readLimiter);
CurlReadCallbackContext readContext(this, connectionHandle, request.get(), writeLimiter);

const auto readContext = [this, &connectionHandle, &request, &writeLimiter]() -> CurlReadCallbackContext {
if (request->GetContentBody() != nullptr) {
auto chunkedBodyPtr = Aws::MakeShared<AwsChunkedStream<>>(CURL_HTTP_CLIENT_TAG, request.get(), request->GetContentBody());
return {this, connectionHandle, request.get(), writeLimiter, std::move(chunkedBodyPtr)};
}
return {this, connectionHandle, request.get(), writeLimiter};
}();

SetOptCodeForHttpMethod(connectionHandle, request);

Expand Down
41 changes: 41 additions & 0 deletions tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <aws/core/http/standard/StandardHttpRequest.h>
#include <aws/core/utils/crypto/CRC32.h>
#include <aws/core/utils/stream/AwsChunkedStream.h>
#include <aws/testing/AwsCppSdkGTestSuite.h>

using namespace Aws;
using namespace Aws::Http::Standard;
using namespace Aws::Utils::Stream;
using namespace Aws::Utils::Crypto;

class AwsChunkedStreamTest : public Aws::Testing::AwsCppSdkGTestSuite {};

const char* TEST_LOG_TAG = "AWS_CHUNKED_STREAM_TEST";

// Verifies aws-chunked encoding end to end with a 10-byte chunk buffer:
// 25 input bytes become two full 10-byte chunks, one 5-byte chunk, a
// terminating zero-length chunk, and a trailing CRC32 checksum header.
TEST_F(AwsChunkedStreamTest, ChunkedStreamShouldWork) {
  StandardHttpRequest request{"www.elda.com/will", Http::HttpMethod::HTTP_GET};
  // Attach a CRC32 request hash so the stream emits x-amz-checksum-crc32 in the trailer.
  auto requestHash = Aws::MakeShared<CRC32>(TEST_LOG_TAG);
  request.SetRequestHash("crc32", requestHash);
  std::shared_ptr<IOStream> inputStream = Aws::MakeShared<StringStream>(TEST_LOG_TAG, "1234567890123456789012345");
  AwsChunkedStream<10> chunkedStream{&request, inputStream};
  Aws::Utils::Array<char> outputBuffer{100};
  Aws::StringStream output;
  // 4 reads of 10 bytes drain the 40 bytes of encoded chunk data
  // ("A\r\n" + 10 + "\r\n" twice = 30, plus "5\r\n" + 5 + "\r\n" = 10).
  size_t readIterations{4};
  size_t bufferOffset{0};
  while (readIterations > 0) {
    bufferOffset = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 10);
    std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + bufferOffset, std::ostream_iterator<char>(output));
    readIterations--;
  }
  // Read trailing checksum that is greater than 10 chars
  bufferOffset = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 40);
  EXPECT_EQ(36ul, bufferOffset);
  std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + bufferOffset, std::ostream_iterator<char>(output));
  const auto encodedStr = output.str();
  auto expectedStreamWithChecksum = "A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n";
  EXPECT_EQ(expectedStreamWithChecksum, encodedStr);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2518,4 +2518,38 @@ namespace
}
}
}

// Uploads a 1 MiB body for each supported checksum algorithm, large enough to
// guarantee aws-chunked transfer encoding with a trailing checksum header.
// (Note: the registered test name misspells "Guaranteed"; kept as-is so
// existing test filters keep matching.)
TEST_F(BucketAndObjectOperationTest, PutObjectChecksumWithGuarunteedChunkedObject) {
  struct ChecksumTestCase {
    // Configures the checksum algorithm on the request under test.
    // Renamed from the misspelled "chucksumRequestMutator".
    std::function<PutObjectRequest(PutObjectRequest)> checksumRequestMutator;
    String body;
  };

  const String fullBucketName = CalculateBucketName(BASE_CHECKSUMS_BUCKET_NAME.c_str());
  SCOPED_TRACE(Aws::String("FullBucketName ") + fullBucketName);
  CreateBucketRequest createBucketRequest;
  createBucketRequest.SetBucket(fullBucketName);
  createBucketRequest.SetACL(BucketCannedACL::private_);
  CreateBucketOutcome createBucketOutcome = CreateBucket(createBucketRequest);
  AWS_ASSERT_SUCCESS(createBucketOutcome);

  // Distinct fill characters per algorithm make any cross-case mixup visible
  // in the uploaded object contents.
  Vector<ChecksumTestCase> testCases{
      {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32); },
       Aws::String(1024 * 1024, 'e')},
      {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32C); },
       Aws::String(1024 * 1024, 'l')},
      {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA1); },
       Aws::String(1024 * 1024, 'd')},
      {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA256); },
       Aws::String(1024 * 1024, 'a')}};

  for (const auto& testCase : testCases) {
    auto request = testCase.checksumRequestMutator(PutObjectRequest().WithBucket(fullBucketName).WithKey("Metaphor"));
    std::shared_ptr<IOStream> body =
        Aws::MakeShared<StringStream>(ALLOCATION_TAG, testCase.body, std::ios_base::in | std::ios_base::binary);
    request.SetBody(body);
    const auto response = Client->PutObject(request);
    EXPECT_TRUE(response.IsSuccess());
  }
}
}

0 comments on commit 061f67b

Please sign in to comment.