From 7c3111da533d48f2717a636b9ecf4208ebf3f378 Mon Sep 17 00:00:00 2001 From: SergeyRyabinin Date: Mon, 13 May 2024 21:50:11 +0000 Subject: [PATCH] Avoid excessive reply buffer copy in WinHTTP --- .../utils/stream/StreamBufProtectedWriter.h | 117 ++++++++++++++++++ .../source/http/curl/CurlHttpClient.cpp | 4 +- .../source/http/windows/WinSyncHttpClient.cpp | 48 +++---- tests/benchmark/benchmark.cpp | 7 ++ tests/benchmark/include/Configuration.h | 2 + tests/benchmark/src/Configuration.cpp | 12 +- tests/benchmark/src/service/S3GetObject.cpp | 26 +++- tests/benchmark/src/service/S3PutObject.cpp | 22 +++- tests/benchmark/src/service/S3Utils.cpp | 14 ++- 9 files changed, 215 insertions(+), 37 deletions(-) create mode 100644 src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h new file mode 100644 index 00000000000..81e624c0d06 --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h @@ -0,0 +1,117 @@ + +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#pragma once + +#include +#include +#include +#include + +namespace Aws +{ + namespace Utils + { + namespace Stream + { + /** + * This is a wrapper to perform a hack to write directly to the put area of the underlying streambuf + */ + class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf + { + public: + StreamBufProtectedWriter() = delete; + + using WriterFunc = std::function; + + static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) + { + uint64_t totalRead = 0; + + while (true) + { + StreamBufProtectedWriter* pBufferCasted = static_cast(ioStream.rdbuf()); + bool bufferPresent = pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() < pBufferCasted->epptr()); + uint64_t read = 0; + if (bufferPresent) + { + // have access to underlying put ptr. + read = WriteDirectlyToPtr(pBufferCasted, writerFunc); + } + else + { + // can't access underlying buffer, stream buffer maybe be customized to not use put ptr. + // or underlying put buffer is simply not initialized yet. + read = WriteWithHelperBuffer(ioStream, writerFunc); + } + if (!read) + { + return totalRead; + } + totalRead += read; + + if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr())) + { + if(!ForceOverflow(pBufferCasted, writerFunc)) + { + return totalRead; + } else { + totalRead++; + } + } + } + return totalRead; + } + protected: + static size_t ForceOverflow(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc) + { + char dstChar; + uint64_t read = 0; + if (writerFunc(&dstChar, 1, read) && read > 0) + { + pBuffer->overflow(dstChar); + return 1; + } + return 0; + } + + static uint64_t WriteWithHelperBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) + { + uint64_t read = 0; + char tmpBuf[1024]; + uint64_t tmpBufSz = sizeof(tmpBuf); + + if(writerFunc(tmpBuf, tmpBufSz, read) && read > 0) + { + ioStream.write(tmpBuf, read); + if (ioStream.fail()) { + AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write " << tmpBufSz + << " (eof: " << ioStream.eof() << ", bad: " << ioStream.bad() << ")"); + return 0; + } + } + + return read; + } + + static uint64_t WriteDirectlyToPtr(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc) + { + auto dstBegin = pBuffer->pptr(); + uint64_t dstSz = pBuffer->epptr() - dstBegin; + uint64_t read = 0; + std::cout << "dst size = " << dstSz << "\n"; + if(writerFunc(dstBegin, dstSz, read) && read > 0) + { + assert(read <= dstSz); + pBuffer->pbump((int) read); + } + + return read; + } + }; + } + } +} diff --git a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp index fb7b702650a..2b1527207e3 100644 --- a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata) return 0; } - size_t cur = response->GetResponseBody().tellp(); + auto cur = response->GetResponseBody().tellp(); if (response->GetResponseBody().fail()) { const auto& ref = response->GetResponseBody(); AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: " @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo { if (!ioStream->eof() && ioStream->peek() != EOF) { - amountRead = ioStream->readsome(ptr, amountToRead); + amountRead = (size_t) ioStream->readsome(ptr, amountToRead); } if (amountRead == 0 && !ioStream->eof()) { diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index 1cec38b129b..76caeafa7d9 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -245,43 +246,42 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr& if (request->GetMethod() != HttpMethod::HTTP_HEAD) { - char body[1024]; - uint64_t bodySize = sizeof(body); - int64_t numBytesResponseReceived = 0; - read = 0; - bool success = ContinueRequest(*request); - while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success) - { - response->GetResponseBody().write(body, read); - if (read > 0) + auto writerFunc = + [this, hHttpRequest, &request, readLimiter, &response](char* dst, uint64_t dstSz, uint64_t& read) -> bool{ + bool success = DoReadData(hHttpRequest, dst, dstSz, read); + if (success) { - for (const auto& hashIterator : request->GetResponseValidationHashes()) - { - hashIterator.second->Update(reinterpret_cast(body), static_cast(read)); - } - numBytesResponseReceived += read; - if (readLimiter != nullptr) + if (read > 0) { - readLimiter->ApplyAndPayForCost(read); - } - auto& receivedHandler = request->GetDataReceivedEventHandler(); - if (receivedHandler) - { - receivedHandler(request.get(), response.get(), (long long)read); + for (const auto& hashIterator : request->GetResponseValidationHashes()) + { + hashIterator.second->Update(reinterpret_cast(dst), static_cast(read)); + } + if (readLimiter != nullptr) + { + readLimiter->ApplyAndPayForCost(read); + } + auto& receivedHandler = request->GetDataReceivedEventHandler(); + if (receivedHandler) + { + receivedHandler(request.get(), response.get(), (long long)read); + } } } + return success; + }; + uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc); - success = success && ContinueRequest(*request) && IsRequestProcessingEnabled(); - } + success = success && ContinueRequest(*request) && IsRequestProcessingEnabled(); if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) { const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER); AWS_LOGSTREAM_TRACE(GetLogTag(), "Response content-length header: " << contentLength); AWS_LOGSTREAM_TRACE(GetLogTag(), "Response body length: " << numBytesResponseReceived); - if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived) + if ((uint64_t) StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived) { success = false; response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); diff --git a/tests/benchmark/benchmark.cpp b/tests/benchmark/benchmark.cpp index f7dcffc97aa..cb929bfcbdd 100644 --- a/tests/benchmark/benchmark.cpp +++ b/tests/benchmark/benchmark.cpp @@ -10,6 +10,13 @@ #include int main(int argc, char *argv[]) { + if (1 == argc || + 2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" )) + { + Benchmark::Configuration::PrintHelp(); + return 0; + } + Aws::SDKOptions options; Aws::InitAPI(options); { diff --git a/tests/benchmark/include/Configuration.h b/tests/benchmark/include/Configuration.h index d7c8c1f92f5..d47e5eb09ce 100644 --- a/tests/benchmark/include/Configuration.h +++ b/tests/benchmark/include/Configuration.h @@ -11,12 +11,14 @@ namespace Benchmark { std::string service; std::string api; long durationMillis; + size_t maxRepeats; bool shouldReportToCloudWatch; std::map dimensions; }; class Configuration { public: + static void PrintHelp(); static Configuration FromArgs(int argc, char *argv[]); inline RunConfiguration GetConfiguration() const { return this->runConfiguration; } private: diff --git a/tests/benchmark/src/Configuration.cpp b/tests/benchmark/src/Configuration.cpp index 19cf8dc098f..2b17001f7a4 100644 --- a/tests/benchmark/src/Configuration.cpp +++ b/tests/benchmark/src/Configuration.cpp @@ -7,15 +7,25 @@ #include #include #include +#include Benchmark::Configuration::Configuration(Benchmark::RunConfiguration runConfiguration) : runConfiguration(std::move(runConfiguration)) {} +void Benchmark::Configuration::PrintHelp() { + + std::cout << "usage: benchmark [--service SERVICE] [--api API] [--durationMillis DURATION_MS] [--withMetrics KEY1:VAL1,KEY2:VAL2]\n"; + std::cout << "\n"; + std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000\n"; + std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --dimensions BucketType:S3Express\n"; +} + Benchmark::Configuration Benchmark::Configuration::FromArgs(int argc, char *argv[]) { return Benchmark::Configuration({ Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--service"), Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--api"), std::stol(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--durationMillis")), + std::stoul(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--maxRepeats")), Benchmark::Configuration::CmdOptionExists(argv, argv + argc, "--withMetrics"), Benchmark::Configuration::GetCmdOptions(argv, argv + argc, "--dimensions") }); @@ -31,7 +41,7 @@ char *Benchmark::Configuration::GetCmdOption(char **begin, char **end, const std std::map Benchmark::Configuration::GetCmdOptions(char **begin, char **end, const std::string &option) { char **itr = std::find(begin, end, option); - auto value = ++itr; + auto value = itr != end ? ++itr : itr; if (itr != end && value != end) { //check to make sure the next entry is not another arg std::string nextArg(*value); diff --git a/tests/benchmark/src/service/S3GetObject.cpp b/tests/benchmark/src/service/S3GetObject.cpp index ed3d059d29b..cd59b69b265 100644 --- a/tests/benchmark/src/service/S3GetObject.cpp +++ b/tests/benchmark/src/service/S3GetObject.cpp @@ -40,7 +40,19 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { metricsEmitter->EmitMetricForOp("CreateBucket", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}), [&]() -> bool { - auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName)); + auto request = CreateBucketRequest() + .WithBucket(bucketName); + if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") { + request.WithCreateBucketConfiguration(CreateBucketConfiguration() + .WithLocation(LocationInfo() + .WithType(LocationType::AvailabilityZone) + .WithName("use1-az2")) + .WithBucket(BucketInfo() + .WithType(BucketType::Directory) + .WithDataRedundancy(DataRedundancy::SingleAvailabilityZone))); + } + + auto response = s3->CreateBucket(request); if (!response.IsSuccess()) { std::cout << "Create Bucket Failed With: " << response.GetError().GetMessage() @@ -64,7 +76,7 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Put Object Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); @@ -72,8 +84,10 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { // Run GetObject requests const auto timeToEnd = duration_cast(steady_clock::now().time_since_epoch()).count() + configuration.GetConfiguration().durationMillis; + size_t counter = 0; + size_t maxRepeats = configuration.GetConfiguration().maxRepeats; auto getObjectRequest = GetObjectRequest().WithBucket(bucketName).WithKey(testObjectKey); - while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { ; + while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { metricsEmitter->EmitMetricForOp( "GetObject", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "GetObject"}}), @@ -82,10 +96,14 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Get Object Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); + counter++; + if (maxRepeats && counter == maxRepeats) { + break; + } } // Clean up diff --git a/tests/benchmark/src/service/S3PutObject.cpp b/tests/benchmark/src/service/S3PutObject.cpp index 12d3a21c257..b48df4219e3 100644 --- a/tests/benchmark/src/service/S3PutObject.cpp +++ b/tests/benchmark/src/service/S3PutObject.cpp @@ -42,7 +42,19 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { metricsEmitter->EmitMetricForOp("CreateBucket", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}), [&]() -> bool { - auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName)); + auto request = CreateBucketRequest() + .WithBucket(bucketName); + if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") { + request.WithCreateBucketConfiguration(CreateBucketConfiguration() + .WithLocation(LocationInfo() + .WithType(LocationType::AvailabilityZone) + .WithName("use1-az2")) + .WithBucket(BucketInfo() + .WithType(BucketType::Directory) + .WithDataRedundancy(DataRedundancy::SingleAvailabilityZone))); + } + + auto response = s3->CreateBucket(request); if (!response.IsSuccess()) { std::cout << "Create Bucket Failed With: " << response.GetError().GetMessage() @@ -56,6 +68,8 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { std::vector keysToDelete; const auto timeToEnd = duration_cast(steady_clock::now().time_since_epoch()).count() + configuration.GetConfiguration().durationMillis; + size_t counter = 0; + size_t maxRepeats = configuration.GetConfiguration().maxRepeats; while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { auto key = RandomString(8); keysToDelete.push_back(key); @@ -71,10 +85,14 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Put Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); + counter++; + if (maxRepeats && counter == maxRepeats) { + break; + } } // Clean up diff --git a/tests/benchmark/src/service/S3Utils.cpp b/tests/benchmark/src/service/S3Utils.cpp index d7ae246a66a..c2e05e9a3b0 100644 --- a/tests/benchmark/src/service/S3Utils.cpp +++ b/tests/benchmark/src/service/S3Utils.cpp @@ -10,8 +10,7 @@ using S3ClientConfiguration = Aws::S3Crt::ClientConfiguration; using S3Client = S3CrtClient; #else #include -using namespace Aws::S3; -#endif +#include #ifdef USE_OTLP #include #include @@ -31,7 +30,14 @@ using namespace opentelemetry::exporter::otlp; const char* ALLOC_TAG = "S3_BENCHMARK"; std::unique_ptr S3Utils::makeClient(const std::map &cliDimensions) { - S3ClientConfiguration configuration; + Aws::Client::ClientConfigurationInitValues cfgInit; + cfgInit.shouldDisableIMDS = true; + S3ClientConfiguration configuration(cfgInit); + configuration.caFile = "C:/opt/curl-8.7.1_8-win64-mingw/bin/curl-ca-bundle.crt"; + // configuration.enableHttpClientTrace = true; + configuration.requestTimeoutMs = 10000; + configuration.connectTimeoutMs = 10000; + configuration.retryStrategy = Aws::MakeShared("RETRY", 0); if (cliDimensions.find("TelemetryHost") != cliDimensions.end() && cliDimensions.find("TelemetryPort") != cliDimensions.end()) { @@ -59,7 +65,7 @@ std::string Benchmark::S3Utils::getBucketPrefix(const std::map