From 57521e69d99cbe57ca2bf0760d5ab3420cbc5fa7 Mon Sep 17 00:00:00 2001 From: SergeyRyabinin Date: Mon, 13 May 2024 21:50:11 +0000 Subject: [PATCH] Avoid excessive reply buffer copy in WinHTTP --- .../core/http/windows/WinHttpSyncHttpClient.h | 1 + .../core/http/windows/WinINetSyncHttpClient.h | 1 + .../aws/core/http/windows/WinSyncHttpClient.h | 1 + .../utils/stream/StreamBufProtectedWriter.h | 120 ++++++++++++++++++ .../source/http/curl/CurlHttpClient.cpp | 4 +- .../http/windows/WinHttpSyncHttpClient.cpp | 5 + .../http/windows/WinINetSyncHttpClient.cpp | 5 + .../source/http/windows/WinSyncHttpClient.cpp | 116 ++++++++++++----- .../CloudWatchLogsTests.cpp | 1 + .../TranscribeTests.cpp | 4 +- tests/benchmark/benchmark.cpp | 7 + tests/benchmark/include/Configuration.h | 2 + tests/benchmark/src/Configuration.cpp | 12 +- tests/benchmark/src/service/S3GetObject.cpp | 26 +++- tests/benchmark/src/service/S3PutObject.cpp | 22 +++- 15 files changed, 283 insertions(+), 44 deletions(-) create mode 100644 src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h index 1dcfa4313c8..fc80aca7132 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h @@ -51,6 +51,7 @@ namespace Aws bool DoReceiveResponse(void* httpRequest) const override; bool DoQueryHeaders(void* httpRequest, std::shared_ptr& response, Aws::StringStream& ss, uint64_t& read) const override; bool DoSendRequest(void* httpRequest) const override; + bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override; bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override; void* GetClientModule() const override; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h index 6a49767e21f..8b35a827bad 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h @@ -49,6 +49,7 @@ namespace Aws bool DoReceiveResponse(void* hHttpRequest) const override; bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr& response, Aws::StringStream& ss, uint64_t& read) const override; bool DoSendRequest(void* hHttpRequest) const override; + bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override; bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override; void* GetClientModule() const override; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinSyncHttpClient.h index e3a29a169d9..76a4fabd8fc 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinSyncHttpClient.h @@ -90,6 +90,7 @@ namespace Aws virtual bool DoReceiveResponse(void* hHttpRequest) const = 0; virtual bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr& response, Aws::StringStream& ss, uint64_t& read) const = 0; virtual bool DoSendRequest(void* hHttpRequest) const = 0; + virtual bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const = 0; virtual bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const = 0; virtual void* GetClientModule() const = 0; diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h new file mode 100644 index 00000000000..d4d8c3d9f3f --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/stream/StreamBufProtectedWriter.h @@ -0,0 +1,120 @@ + +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#pragma once + +#include +#include +#include +#include + +namespace Aws +{ + namespace Utils + { + namespace Stream + { + /** + * This is a wrapper to perform a hack to write directly to the put area of the underlying streambuf + */ + class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf + { + public: + StreamBufProtectedWriter() = delete; + + using WriterFunc = std::function; + + static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc) + { + uint64_t totalRead = 0; + + while (true) + { + StreamBufProtectedWriter* pBufferCasted = static_cast(ioStream.rdbuf()); + bool bufferPresent = pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() < pBufferCasted->epptr()); + uint64_t read = 0; + bool success = false; + if (bufferPresent) + { + // have access to underlying put ptr. + success = WriteDirectlyToPtr(pBufferCasted, writerFunc, read); + } + else + { + // can't access underlying buffer, stream buffer maybe be customized to not use put ptr. + // or underlying put buffer is simply not initialized yet. + success = WriteWithHelperBuffer(ioStream, writerFunc, read); + } + totalRead += read; + if (!success) + { + break; + } + + if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr())) + { + if(!ForceOverflow(ioStream, writerFunc)) + { + break; + } else { + totalRead++; + } + } + } + return totalRead; + } + protected: + static bool ForceOverflow(Aws::IOStream& ioStream, const WriterFunc& writerFunc) + { + char dstChar; + uint64_t read = 0; + if (writerFunc(&dstChar, 1, read) && read > 0) + { + ioStream.write(&dstChar, 1); + if (ioStream.fail()) { + AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write 1 byte (eof: " + << ioStream.eof() << ", bad: " << ioStream.bad() << ")"); + return false; + } + return true; + } + return false; + } + + static uint64_t WriteWithHelperBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc, uint64_t& read) + { + char tmpBuf[1024]; + uint64_t tmpBufSz = sizeof(tmpBuf); + + if(writerFunc(tmpBuf, tmpBufSz, read) && read > 0) + { + ioStream.write(tmpBuf, read); + if (ioStream.fail()) { + AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write " << tmpBufSz + << " (eof: " << ioStream.eof() << ", bad: " << ioStream.bad() << ")"); + return false; + } + return true; + } + return false; + } + + static uint64_t WriteDirectlyToPtr(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc, uint64_t& read) + { + auto dstBegin = pBuffer->pptr(); + uint64_t dstSz = pBuffer->epptr() - dstBegin; + if(writerFunc(dstBegin, dstSz, read) && read > 0) + { + assert(read <= dstSz); + pBuffer->pbump((int) read); + return true; + } + return false; + } + }; + } + } +} diff --git a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp index fb7b702650a..2b1527207e3 100644 --- a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata) return 0; } - size_t cur = response->GetResponseBody().tellp(); + auto cur = response->GetResponseBody().tellp(); if (response->GetResponseBody().fail()) { const auto& ref = response->GetResponseBody(); AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: " @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo { if (!ioStream->eof() && ioStream->peek() != EOF) { - amountRead = ioStream->readsome(ptr, amountToRead); + amountRead = (size_t) ioStream->readsome(ptr, amountToRead); } if (amountRead == 0 && !ioStream->eof()) { diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinHttpSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinHttpSyncHttpClient.cpp index 7e47dbfebd1..09b45744960 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinHttpSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinHttpSyncHttpClient.cpp @@ -578,6 +578,11 @@ bool WinHttpSyncHttpClient::DoSendRequest(void* hHttpRequest) const return (WinHttpSendRequest(hHttpRequest, NULL, NULL, 0, 0, 0, NULL) != 0); } +bool WinHttpSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const +{ + return (WinHttpQueryDataAvailable(hHttpRequest, (LPDWORD)&available) != 0); +} + bool WinHttpSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const { return (WinHttpReadData(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0); diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinINetSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinINetSyncHttpClient.cpp index 0c192974a5d..e030ba11e8f 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinINetSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinINetSyncHttpClient.cpp @@ -245,6 +245,11 @@ bool WinINetSyncHttpClient::DoSendRequest(void* hHttpRequest) const return (HttpSendRequestEx(hHttpRequest, NULL, NULL, 0, 0) != 0); } +bool WinINetSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const +{ + return (InternetQueryDataAvailable(hHttpRequest, (LPDWORD)&available, /*reserved*/ 0, /*reserved*/ 0) != 0); +} + bool WinINetSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const { return (InternetReadFile(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0); diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index 1cec38b129b..70f32c2f043 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -245,55 +246,104 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr& if (request->GetMethod() != HttpMethod::HTTP_HEAD) { - char body[1024]; - uint64_t bodySize = sizeof(body); - int64_t numBytesResponseReceived = 0; - read = 0; + if(!ContinueRequest(*request) || !IsRequestProcessingEnabled()) + { + response->SetClientErrorType(CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler."); + response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE); + return false; + } - bool success = ContinueRequest(*request); + if (response->GetResponseBody().fail()) { + const auto& ref = response->GetResponseBody(); + AWS_LOGSTREAM_ERROR(GetLogTag(), "Response output stream is in a bad state (eof: " << ref.eof() << ", bad: " << ref.bad() << ")"); + response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage("Response output stream is in a bad state."); + return false; + } - while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success) - { - response->GetResponseBody().write(body, read); - if (read > 0) - { - for (const auto& hashIterator : request->GetResponseValidationHashes()) - { - hashIterator.second->Update(reinterpret_cast(body), static_cast(read)); - } - numBytesResponseReceived += read; - if (readLimiter != nullptr) - { - readLimiter->ApplyAndPayForCost(read); - } - auto& receivedHandler = request->GetDataReceivedEventHandler(); - if (receivedHandler) + bool connectionOpen = true; + auto writerFunc = + [this, hHttpRequest, &request, readLimiter, &response, &connectionOpen](char* dst, uint64_t dstSz, uint64_t& read) -> bool { - receivedHandler(request.get(), response.get(), (long long)read); - } - } + bool success = true; + uint64_t available = 0; + connectionOpen = DoQueryDataAvailable(hHttpRequest, available); - success = success && ContinueRequest(*request) && IsRequestProcessingEnabled(); + if (connectionOpen && available) + { + dstSz = (std::min)(dstSz, available); + success = DoReadData(hHttpRequest, dst, dstSz, read); + if (success && read > 0) + { + for (const auto& hashIterator : request->GetResponseValidationHashes()) + { + hashIterator.second->Update(reinterpret_cast(dst), static_cast(read)); + } + if (readLimiter != nullptr) + { + readLimiter->ApplyAndPayForCost(read); + } + auto& receivedHandler = request->GetDataReceivedEventHandler(); + if (receivedHandler) + { + receivedHandler(request.get(), response.get(), (long long)read); + } + } + if (!ContinueRequest(*request) || !IsRequestProcessingEnabled()) + { + return false; + } + } + return connectionOpen && success && ContinueRequest(*request) && IsRequestProcessingEnabled(); + }; + uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc); + + if(!ContinueRequest(*request) || !IsRequestProcessingEnabled()) + { + response->SetClientErrorType(CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler."); + response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE); + return false; } - if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) + if (response->GetResponseBody().fail()) { + Aws::OStringStream errorMsgStr; + errorMsgStr << "Failed to write received response (eof: " + << response->GetResponseBody().eof() << ", bad: " << response->GetResponseBody().bad() << ")"; + + Aws::String errorMsg = errorMsgStr.str(); + AWS_LOGSTREAM_ERROR(GetLogTag(), errorMsg); + response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(errorMsg); + return false; + } + + if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE)) + { + response->GetResponseBody().flush(); + if (response->GetResponseBody().fail()) { + const auto& ref = response->GetResponseBody(); + AWS_LOGSTREAM_ERROR(GetLogTag(), "Failed to flush event response (eof: " << ref.eof() << ", bad: " << ref.bad() << ")"); + response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage("Failed to flush event stream event response"); + return false; + } + } + + if (response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) { const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER); AWS_LOGSTREAM_TRACE(GetLogTag(), "Response content-length header: " << contentLength); AWS_LOGSTREAM_TRACE(GetLogTag(), "Response body length: " << numBytesResponseReceived); - if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived) + if ((uint64_t) StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived) { - success = false; response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION); response->SetClientErrorMessage("Response body length doesn't match the content-length header."); AWS_LOGSTREAM_ERROR(GetLogTag(), "Response body length doesn't match the content-length header."); + return false; } } - - if(!success) - { - return false; - } } //go ahead and flush the response body. diff --git a/tests/aws-cpp-sdk-logs-integration-tests/CloudWatchLogsTests.cpp b/tests/aws-cpp-sdk-logs-integration-tests/CloudWatchLogsTests.cpp index b830a6f1a04..aead872d6c6 100644 --- a/tests/aws-cpp-sdk-logs-integration-tests/CloudWatchLogsTests.cpp +++ b/tests/aws-cpp-sdk-logs-integration-tests/CloudWatchLogsTests.cpp @@ -69,6 +69,7 @@ namespace auto cognitoClient = Aws::MakeShared(ALLOCATION_TAG, config); Aws::AccessManagement::AccessManagementClient accessManagementClient(iamClient, cognitoClient); accountId = accessManagementClient.GetAccountId(); + assert(!accountId.empty()); // AccountId must be set for this test } m_accountId = accountId; } diff --git a/tests/aws-cpp-sdk-transcribestreaming-integ-tests/TranscribeTests.cpp b/tests/aws-cpp-sdk-transcribestreaming-integ-tests/TranscribeTests.cpp index 349704864a9..bad6b6b40a6 100644 --- a/tests/aws-cpp-sdk-transcribestreaming-integ-tests/TranscribeTests.cpp +++ b/tests/aws-cpp-sdk-transcribestreaming-integ-tests/TranscribeTests.cpp @@ -557,7 +557,7 @@ unsigned int LevenshteinDistance(Aws::String s1, Aws::String s2) TEST_F(TranscribeStreamingTests, TranscribeStreamingCppSdkSample) { const Aws::Vector EXPECTED_ALTERNATIVES = {"This is a C plus plus test sample", "This is a C++ test sample"}; - for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 25) + for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50) { m_testTraces.clear(); TestTrace(Aws::String("### Starting TranscribeStreamingCppSdkSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##"); @@ -582,7 +582,7 @@ TEST_F(TranscribeStreamingTests, TranscribeStreamingKantSample) static const char expected[] = "Categorical imperative: Act only according to that maxim whereby you can at the same time will that it should become a universal law. " "Two things fill the mind with ever-increasing wonder and awe, the more often and the more intensely the mind of thought is drawn to them: " "the starry heavens above me and the moral law within me."; - for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 25) + for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50) { m_testTraces.clear(); TestTrace(Aws::String("### Starting TranscribeStreamingKantSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##"); diff --git a/tests/benchmark/benchmark.cpp b/tests/benchmark/benchmark.cpp index f7dcffc97aa..cb929bfcbdd 100644 --- a/tests/benchmark/benchmark.cpp +++ b/tests/benchmark/benchmark.cpp @@ -10,6 +10,13 @@ #include int main(int argc, char *argv[]) { + if (1 == argc || + 2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" )) + { + Benchmark::Configuration::PrintHelp(); + return 0; + } + Aws::SDKOptions options; Aws::InitAPI(options); { diff --git a/tests/benchmark/include/Configuration.h b/tests/benchmark/include/Configuration.h index d7c8c1f92f5..d47e5eb09ce 100644 --- a/tests/benchmark/include/Configuration.h +++ b/tests/benchmark/include/Configuration.h @@ -11,12 +11,14 @@ namespace Benchmark { std::string service; std::string api; long durationMillis; + size_t maxRepeats; bool shouldReportToCloudWatch; std::map dimensions; }; class Configuration { public: + static void PrintHelp(); static Configuration FromArgs(int argc, char *argv[]); inline RunConfiguration GetConfiguration() const { return this->runConfiguration; } private: diff --git a/tests/benchmark/src/Configuration.cpp b/tests/benchmark/src/Configuration.cpp index 19cf8dc098f..2b17001f7a4 100644 --- a/tests/benchmark/src/Configuration.cpp +++ b/tests/benchmark/src/Configuration.cpp @@ -7,15 +7,25 @@ #include #include #include +#include Benchmark::Configuration::Configuration(Benchmark::RunConfiguration runConfiguration) : runConfiguration(std::move(runConfiguration)) {} +void Benchmark::Configuration::PrintHelp() { + + std::cout << "usage: benchmark [--service SERVICE] [--api API] [--durationMillis DURATION_MS] [--withMetrics KEY1:VAL1,KEY2:VAL2]\n"; + std::cout << "\n"; + std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000\n"; + std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --dimensions BucketType:S3Express\n"; +} + Benchmark::Configuration Benchmark::Configuration::FromArgs(int argc, char *argv[]) { return Benchmark::Configuration({ Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--service"), Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--api"), std::stol(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--durationMillis")), + std::stoul(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--maxRepeats")), Benchmark::Configuration::CmdOptionExists(argv, argv + argc, "--withMetrics"), Benchmark::Configuration::GetCmdOptions(argv, argv + argc, "--dimensions") }); @@ -31,7 +41,7 @@ char *Benchmark::Configuration::GetCmdOption(char **begin, char **end, const std std::map Benchmark::Configuration::GetCmdOptions(char **begin, char **end, const std::string &option) { char **itr = std::find(begin, end, option); - auto value = ++itr; + auto value = itr != end ? ++itr : itr; if (itr != end && value != end) { //check to make sure the next entry is not another arg std::string nextArg(*value); diff --git a/tests/benchmark/src/service/S3GetObject.cpp b/tests/benchmark/src/service/S3GetObject.cpp index ed3d059d29b..cd59b69b265 100644 --- a/tests/benchmark/src/service/S3GetObject.cpp +++ b/tests/benchmark/src/service/S3GetObject.cpp @@ -40,7 +40,19 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { metricsEmitter->EmitMetricForOp("CreateBucket", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}), [&]() -> bool { - auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName)); + auto request = CreateBucketRequest() + .WithBucket(bucketName); + if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") { + request.WithCreateBucketConfiguration(CreateBucketConfiguration() + .WithLocation(LocationInfo() + .WithType(LocationType::AvailabilityZone) + .WithName("use1-az2")) + .WithBucket(BucketInfo() + .WithType(BucketType::Directory) + .WithDataRedundancy(DataRedundancy::SingleAvailabilityZone))); + } + + auto response = s3->CreateBucket(request); if (!response.IsSuccess()) { std::cout << "Create Bucket Failed With: " << response.GetError().GetMessage() @@ -64,7 +76,7 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Put Object Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); @@ -72,8 +84,10 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { // Run GetObject requests const auto timeToEnd = duration_cast(steady_clock::now().time_since_epoch()).count() + configuration.GetConfiguration().durationMillis; + size_t counter = 0; + size_t maxRepeats = configuration.GetConfiguration().maxRepeats; auto getObjectRequest = GetObjectRequest().WithBucket(bucketName).WithKey(testObjectKey); - while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { ; + while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { metricsEmitter->EmitMetricForOp( "GetObject", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "GetObject"}}), @@ -82,10 +96,14 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Get Object Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); + counter++; + if (maxRepeats && counter == maxRepeats) { + break; + } } // Clean up diff --git a/tests/benchmark/src/service/S3PutObject.cpp b/tests/benchmark/src/service/S3PutObject.cpp index 12d3a21c257..b48df4219e3 100644 --- a/tests/benchmark/src/service/S3PutObject.cpp +++ b/tests/benchmark/src/service/S3PutObject.cpp @@ -42,7 +42,19 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { metricsEmitter->EmitMetricForOp("CreateBucket", S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}), [&]() -> bool { - auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName)); + auto request = CreateBucketRequest() + .WithBucket(bucketName); + if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") { + request.WithCreateBucketConfiguration(CreateBucketConfiguration() + .WithLocation(LocationInfo() + .WithType(LocationType::AvailabilityZone) + .WithName("use1-az2")) + .WithBucket(BucketInfo() + .WithType(BucketType::Directory) + .WithDataRedundancy(DataRedundancy::SingleAvailabilityZone))); + } + + auto response = s3->CreateBucket(request); if (!response.IsSuccess()) { std::cout << "Create Bucket Failed With: " << response.GetError().GetMessage() @@ -56,6 +68,8 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { std::vector keysToDelete; const auto timeToEnd = duration_cast(steady_clock::now().time_since_epoch()).count() + configuration.GetConfiguration().durationMillis; + size_t counter = 0; + size_t maxRepeats = configuration.GetConfiguration().maxRepeats; while (duration_cast(steady_clock::now().time_since_epoch()).count() < timeToEnd) { auto key = RandomString(8); keysToDelete.push_back(key); @@ -71,10 +85,14 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() { if (!response.IsSuccess()) { std::cout << "Put Failed With: " << response.GetError().GetMessage() - << "\n";; + << "\n"; } return response.IsSuccess(); }); + counter++; + if (maxRepeats && counter == maxRepeats) { + break; + } } // Clean up