From 061f67bf89ce6fe428252c482c3daeeb7617cab8 Mon Sep 17 00:00:00 2001 From: sbiscigl Date: Thu, 7 Nov 2024 17:50:35 -0500 Subject: [PATCH] fix trailing checksum --- .../workflows/.clang-format => .clang-format | 0 .github/workflows/clang-format.yml | 2 +- .../aws/core/utils/stream/AwsChunkedStream.h | 103 +++++++++++++++ .../source/http/curl/CurlHttpClient.cpp | 125 +++++++----------- .../utils/stream/AwsChunkedStreamTest.cpp | 41 ++++++ .../BucketAndObjectOperationTest.cpp | 34 +++++ 6 files changed, 226 insertions(+), 79 deletions(-) rename .github/workflows/.clang-format => .clang-format (100%) create mode 100644 src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h create mode 100644 tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp diff --git a/.github/workflows/.clang-format b/.clang-format similarity index 100% rename from .github/workflows/.clang-format rename to .clang-format diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index a4de9d5c79b..03c5af5f0f3 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -56,7 +56,7 @@ jobs: run: | clang-format --version if [ -s diff_output.patch ]; then - python3 clang-format-diff.py -p1 -style=file:.github/workflows/.clang-format < diff_output.patch > formatted_differences.patch 2> error.log || true + python3 clang-format-diff.py -p1 -style=file:.clang-format < diff_output.patch > formatted_differences.patch 2> error.log || true if [ -s error.log ]; then echo "Errors from clang-format-diff.py:" cat error.log diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h new file mode 100644 index 00000000000..1eb6d44c14a --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/AwsChunkedStream.h @@ -0,0 +1,103 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#pragma once +#include +#include +#include + +namespace Aws { +namespace Utils { +namespace Stream { + +static const size_t AWS_DATA_BUFFER_SIZE = 65536; + +template +class AwsChunkedStream { + public: + AwsChunkedStream(Http::HttpRequest *request, const std::shared_ptr &stream) + : m_chunkingStream{Aws::MakeShared("AwsChunkedStream")}, m_request(request), m_stream(stream) { + assert(m_stream != nullptr); + if (m_stream == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "stream is null"); + } + assert(request != nullptr); + if (request == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "request is null"); + } + } + + size_t BufferedRead(char *dst, size_t amountToRead) { + assert(dst != nullptr); + if (dst == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "dst is null"); + } + + // the chunk has ended and cannot be read from + if (m_chunkEnd) { + return 0; + } + + // If we've read all of the underlying stream write the checksum trailing header + // the set that the chunked stream is over. + if (m_stream->eof() && !m_stream->bad() && (m_chunkingStream->eof() || m_chunkingStream->peek() == EOF)) { + return writeTrailer(dst, amountToRead); + } + + // Try to read in a 64K chunk, if we cant we know the stream is over + size_t bytesRead = 0; + while (m_stream->good() && bytesRead < DataBufferSize) { + m_stream->read(&m_data[bytesRead], DataBufferSize - bytesRead); + bytesRead += static_cast(m_stream->gcount()); + } + + if (bytesRead > 0) { + writeChunk(bytesRead); + } + + // Read to destination buffer, return how much was read + m_chunkingStream->read(dst, amountToRead); + return static_cast(m_chunkingStream->gcount()); + } + + private: + size_t writeTrailer(char *dst, size_t amountToRead) { + Aws::StringStream chunkedTrailerStream; + chunkedTrailerStream << "0\r\n"; + if (m_request->GetRequestHash().second != nullptr) { + chunkedTrailerStream << "x-amz-checksum-" << m_request->GetRequestHash().first << ":" + << HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) << "\r\n"; + } + chunkedTrailerStream << "\r\n"; + const auto chunkedTrailer = chunkedTrailerStream.str(); + auto trailerSize = chunkedTrailer.size(); + // unreferenced param for assert + AWS_UNREFERENCED_PARAM(amountToRead); + assert(amountToRead >= trailerSize); + memcpy(dst, chunkedTrailer.c_str(), trailerSize); + m_chunkEnd = true; + return trailerSize; + } + + void writeChunk(size_t bytesRead) { + if (m_request->GetRequestHash().second != nullptr) { + m_request->GetRequestHash().second->Update(reinterpret_cast(m_data.GetUnderlyingData()), bytesRead); + } + + if (m_chunkingStream != nullptr && !m_chunkingStream->bad()) { + *m_chunkingStream << Aws::Utils::StringUtils::ToHexString(bytesRead) << "\r\n"; + m_chunkingStream->write(m_data.GetUnderlyingData(), bytesRead); + *m_chunkingStream << "\r\n"; + } + } + + Aws::Utils::Array m_data{DataBufferSize}; + std::shared_ptr m_chunkingStream; + bool m_chunkEnd{false}; + Http::HttpRequest *m_request{nullptr}; + std::shared_ptr m_stream; +}; +} // namespace Stream +} // namespace Utils +} // namespace Aws diff --git a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp index 9ceefe03794..b03e4070ff1 100644 --- a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -3,27 +3,29 @@ * SPDX-License-Identifier: Apache-2.0. */ -#include #include +#include #include -#include +#include +#include #include +#include +#include +#include #include #include -#include -#include -#include -#include -#include +#include + #include +#include #include - using namespace Aws::Client; using namespace Aws::Http; using namespace Aws::Http::Standard; using namespace Aws::Utils; using namespace Aws::Utils::Logging; +using namespace Aws::Utils::Stream; using namespace Aws::Monitoring; #ifdef USE_AWS_MEMORY_MANAGEMENT @@ -144,25 +146,28 @@ struct CurlWriteCallbackContext int64_t m_numBytesResponseReceived; }; +static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient"; + struct CurlReadCallbackContext { - CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, Aws::Utils::RateLimits::RateLimiterInterface* limiter) : - m_client(client), + CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, + Aws::Utils::RateLimits::RateLimiterInterface* limiter, + std::shared_ptr> chunkedStream = nullptr) + : m_client(client), m_curlHandle(curlHandle), m_rateLimiter(limiter), m_request(request), - m_chunkEnd(false) - {} - - const CurlHttpClient* m_client; - CURL* m_curlHandle; - Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; - HttpRequest* m_request; - bool m_chunkEnd; + m_chunkEnd(false), + m_chunkedStream{std::move(chunkedStream)} {} + + const CurlHttpClient* m_client; + CURL* m_curlHandle; + Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; + HttpRequest* m_request; + bool m_chunkEnd; + std::shared_ptr> m_chunkedStream; }; -static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient"; - static int64_t GetContentLengthFromHeader(CURL* connectionHandle, bool& hasContentLength) { #if LIBCURL_VERSION_NUM >= 0x073700 // 7.55.0 @@ -293,67 +298,24 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo size_t amountToRead = size * nmemb; bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE; - // aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF - // Needs to reserve bytes of sizeof(hex(chunk-size)) + sizeof(CRLF) + sizeof(CRLF) - if (isAwsChunked) - { - Aws::String amountToReadHexString = Aws::Utils::StringUtils::ToHexString(amountToRead); - amountToRead -= (amountToReadHexString.size() + 4); - } if (ioStream != nullptr && amountToRead > 0) { size_t amountRead = 0; - if (isStreaming) - { - if (!ioStream->eof() && ioStream->peek() != EOF) - { - amountRead = (size_t) ioStream->readsome(ptr, amountToRead); - } - if (amountRead == 0 && !ioStream->eof()) - { - return CURL_READFUNC_PAUSE; - } - } - else - { - ioStream->read(ptr, amountToRead); - amountRead = static_cast(ioStream->gcount()); - } - - if (isAwsChunked) - { - if (amountRead > 0) - { - if (request->GetRequestHash().second != nullptr) - { - request->GetRequestHash().second->Update(reinterpret_cast(ptr), amountRead); - } - - Aws::String hex = Aws::Utils::StringUtils::ToHexString(amountRead); - memmove(ptr + hex.size() + 2, ptr, amountRead); - memmove(ptr + hex.size() + 2 + amountRead, "\r\n", 2); - memmove(ptr, hex.c_str(), hex.size()); - memmove(ptr + hex.size(), "\r\n", 2); - amountRead += hex.size() + 4; - } - else if (!context->m_chunkEnd) - { - Aws::StringStream chunkedTrailer; - chunkedTrailer << "0\r\n"; - if (request->GetRequestHash().second != nullptr) - { - chunkedTrailer << "x-amz-checksum-" - << request->GetRequestHash().first - << ":" - << HashingUtils::Base64Encode(request->GetRequestHash().second->GetHash().GetResult()) - << "\r\n"; - } - chunkedTrailer << "\r\n"; - amountRead = chunkedTrailer.str().size(); - memcpy(ptr, chunkedTrailer.str().c_str(), amountRead); - context->m_chunkEnd = true; - } + if (isStreaming) { + if (!ioStream->eof() && ioStream->peek() != EOF) { + amountRead = (size_t)ioStream->readsome(ptr, amountToRead); + } + if (amountRead == 0 && !ioStream->eof()) { + return CURL_READFUNC_PAUSE; + } + } else if (isAwsChunked && context->m_chunkedStream != nullptr) { + AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Called with size: " << amountToRead); + amountRead = context->m_chunkedStream->BufferedRead(ptr, amountToRead); + AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "read: " << amountRead); + } else { + ioStream->read(ptr, amountToRead); + amountRead = static_cast(ioStream->gcount()); } auto& sentHandler = request->GetDataSentEventHandler(); @@ -724,7 +686,14 @@ std::shared_ptr CurlHttpClient::MakeRequest(const std::shared_ptr< } CurlWriteCallbackContext writeContext(this, request.get(), response.get(), readLimiter); - CurlReadCallbackContext readContext(this, connectionHandle, request.get(), writeLimiter); + + const auto readContext = [this, &connectionHandle, &request, &writeLimiter]() -> CurlReadCallbackContext { + if (request->GetContentBody() != nullptr) { + auto chunkedBodyPtr = Aws::MakeShared>(CURL_HTTP_CLIENT_TAG, request.get(), request->GetContentBody()); + return {this, connectionHandle, request.get(), writeLimiter, std::move(chunkedBodyPtr)}; + } + return {this, connectionHandle, request.get(), writeLimiter}; + }(); SetOptCodeForHttpMethod(connectionHandle, request); diff --git a/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp b/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp new file mode 100644 index 00000000000..075b46e4558 --- /dev/null +++ b/tests/aws-cpp-sdk-core-tests/utils/stream/AwsChunkedStreamTest.cpp @@ -0,0 +1,41 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include +#include +#include +#include + +using namespace Aws; +using namespace Aws::Http::Standard; +using namespace Aws::Utils::Stream; +using namespace Aws::Utils::Crypto; + +class AwsChunkedStreamTest : public Aws::Testing::AwsCppSdkGTestSuite {}; + +const char* TEST_LOG_TAG = "AWS_CHUNKED_STREAM_TEST"; + +TEST_F(AwsChunkedStreamTest, ChunkedStreamShouldWork) { + StandardHttpRequest request{"www.elda.com/will", Http::HttpMethod::HTTP_GET}; + auto requestHash = Aws::MakeShared(TEST_LOG_TAG); + request.SetRequestHash("crc32", requestHash); + std::shared_ptr inputStream = Aws::MakeShared(TEST_LOG_TAG, "1234567890123456789012345"); + AwsChunkedStream<10> chunkedStream{&request, inputStream}; + Aws::Utils::Array outputBuffer{100}; + Aws::StringStream output; + size_t readIterations{4}; + size_t bufferOffset{0}; + while (readIterations > 0) { + bufferOffset = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 10); + std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + bufferOffset, std::ostream_iterator(output)); + readIterations--; + } + // Read trailing checksum that is greater than 10 chars + bufferOffset = chunkedStream.BufferedRead(outputBuffer.GetUnderlyingData(), 40); + EXPECT_EQ(36ul, bufferOffset); + std::copy(outputBuffer.GetUnderlyingData(), outputBuffer.GetUnderlyingData() + bufferOffset, std::ostream_iterator(output)); + const auto encodedStr = output.str(); + auto expectedStreamWithChecksum = "A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n"; + EXPECT_EQ(expectedStreamWithChecksum, encodedStr); +} diff --git a/tests/aws-cpp-sdk-s3-integration-tests/BucketAndObjectOperationTest.cpp b/tests/aws-cpp-sdk-s3-integration-tests/BucketAndObjectOperationTest.cpp index ae22f9c0271..dbc087b9b4f 100644 --- a/tests/aws-cpp-sdk-s3-integration-tests/BucketAndObjectOperationTest.cpp +++ b/tests/aws-cpp-sdk-s3-integration-tests/BucketAndObjectOperationTest.cpp @@ -2518,4 +2518,38 @@ namespace } } } + + TEST_F(BucketAndObjectOperationTest, PutObjectChecksumWithGuarunteedChunkedObject) { + struct ChecksumTestCase { + std::function chucksumRequestMutator; + String body; + }; + + const String fullBucketName = CalculateBucketName(BASE_CHECKSUMS_BUCKET_NAME.c_str()); + SCOPED_TRACE(Aws::String("FullBucketName ") + fullBucketName); + CreateBucketRequest createBucketRequest; + createBucketRequest.SetBucket(fullBucketName); + createBucketRequest.SetACL(BucketCannedACL::private_); + CreateBucketOutcome createBucketOutcome = CreateBucket(createBucketRequest); + AWS_ASSERT_SUCCESS(createBucketOutcome); + + Vector testCases{ + {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32); }, + Aws::String(1024 * 1024, 'e')}, + {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32C); }, + Aws::String(1024 * 1024, 'l')}, + {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA1); }, + Aws::String(1024 * 1024, 'd')}, + {[](PutObjectRequest request) -> PutObjectRequest { return request.WithChecksumAlgorithm(ChecksumAlgorithm::SHA256); }, + Aws::String(1024 * 1024, 'a')}}; + + for (const auto& testCase : testCases) { + auto request = testCase.chucksumRequestMutator(PutObjectRequest().WithBucket(fullBucketName).WithKey("Metaphor")); + std::shared_ptr body = + Aws::MakeShared(ALLOCATION_TAG, testCase.body, std::ios_base::in | std::ios_base::binary); + request.SetBody(body); + const auto response = Client->PutObject(request); + EXPECT_TRUE(response.IsSuccess()); + } + } }