Skip to content

Commit

Permalink
Avoid excessive reply buffer copy in WinHTTP using a hack
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Ryabinin committed May 10, 2024
1 parent 2fedad7 commit bae5120
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@

/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/Array.h>

#include <climits>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <streambuf>

namespace Aws
{
namespace Utils
{
namespace Stream
{
/**
 * Wrapper used to perform a hack: it writes directly into the put area of an arbitrary
 * std::streambuf, so that a producer can fill the destination buffer in place instead of
 * going through an intermediate copy.
 *
 * The class is never instantiated (its constructor is deleted). A std::streambuf* is
 * static_cast to this type only to reach the protected members pptr()/epptr()/pbump()/
 * xsputn()/overflow().
 * NOTE(review): casting an object to a derived type it was not created as is not sanctioned
 * by the C++ standard; this relies on the class adding no data members or virtuals.
 */
class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf
{
public:
    StreamBufProtectedWriter() = delete;

    /**
     * Producer callback.
     * @param dst   destination memory the callback may fill.
     * @param dstSz number of bytes available at dst.
     * @param read  out: number of bytes the callback actually stored at dst.
     * @return false on failure; writing stops when it returns false or reports 0 bytes.
     */
    using WriterFunc = std::function<bool(char* dst, uint64_t dstSz, uint64_t& read)>;

    /**
     * Pulls a single byte from writerFunc and hands it to pBuffer->overflow(), giving the
     * stream buffer a chance to flush or grow its put area.
     *
     * @return 1 if a byte was produced and accepted by the buffer, 0 otherwise.
     */
    static size_t ForceOverflow(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc)
    {
        char dstChar = 0;
        uint64_t read = 0; // must match the uint64_t& in WriterFunc exactly
        if (writerFunc(&dstChar, 1, read) && read > 0)
        {
            // Convert through to_int_type: passing a raw (possibly signed) char would turn the
            // byte 0xFF into -1, which overflow() interprets as eof and silently discards.
            const int_type result = pBuffer->overflow(traits_type::to_int_type(dstChar));
            if (!traits_type::eq_int_type(result, traits_type::eof()))
            {
                return 1;
            }
        }
        return 0;
    }

    /**
     * Drains writerFunc into pBuffer until the callback fails, reports 0 bytes, or the buffer
     * stops accepting data.
     *
     * While the buffer exposes a non-empty put area, the callback writes straight into it.
     * When there is no put area (not yet allocated, or the buffer is unbuffered), data is
     * bounced through a small stack buffer and xsputn() instead. Whenever the put area is
     * exhausted, ForceOverflow() is used to let the buffer make more room.
     *
     * @return number of bytes stored in pBuffer; 0 if pBuffer is null.
     */
    static size_t WriteToBuffer(std::streambuf* pBuffer, const WriterFunc& writerFunc)
    {
        if (pBuffer == nullptr)
        {
            return 0;
        }

        StreamBufProtectedWriter* pBufferCasted = static_cast<StreamBufProtectedWriter*>(pBuffer);
        size_t totalRead = 0;

        for (;;)
        {
            char* dstBegin = pBufferCasted->pptr();
            const std::ptrdiff_t available = pBufferCasted->epptr() - dstBegin;
            uint64_t read = 0;

            if (available > 0)
            {
                // pbump() takes an int, so never offer more than INT_MAX bytes at once.
                uint64_t dstSz = static_cast<uint64_t>(available);
                if (dstSz > static_cast<uint64_t>(INT_MAX))
                {
                    dstSz = static_cast<uint64_t>(INT_MAX);
                }

                if (!writerFunc(dstBegin, dstSz, read) || read == 0)
                {
                    return totalRead;
                }
                pBufferCasted->pbump(static_cast<int>(read));
                totalRead += static_cast<size_t>(read);
            }
            else
            {
                // No put area to write into directly: go through a temporary buffer.
                char tmpBuf[1024];
                if (!writerFunc(tmpBuf, sizeof(tmpBuf), read) || read == 0)
                {
                    return totalRead;
                }

                const std::streamsize written = pBufferCasted->xsputn(tmpBuf, static_cast<std::streamsize>(read));
                if (written > 0)
                {
                    totalRead += static_cast<size_t>(written);
                }
                if (written < 0 || static_cast<uint64_t>(written) != read)
                {
                    // The buffer refused part of the data; report only what was stored.
                    return totalRead;
                }
            }

            if (pBufferCasted->pptr() >= pBufferCasted->epptr())
            {
                if (!ForceOverflow(pBufferCasted, writerFunc))
                {
                    return totalRead;
                }
                ++totalRead;
            }
        }
    }
};
}
}
}
4 changes: 2 additions & 2 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata)
return 0;
}

size_t cur = response->GetResponseBody().tellp();
auto cur = response->GetResponseBody().tellp();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: "
Expand Down Expand Up @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = ioStream->readsome(ptr, amountToRead);
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
Expand Down
36 changes: 34 additions & 2 deletions src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/stream/StreamBufProtectedWriter.h>
#include <aws/core/client/ClientConfiguration.h>
#include <aws/core/http/windows/WinConnectionPoolMgr.h>
#include <aws/core/utils/memory/AWSMemory.h>
Expand Down Expand Up @@ -245,13 +246,42 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&

if (request->GetMethod() != HttpMethod::HTTP_HEAD)
{
char body[1024];
uint64_t bodySize = sizeof(body);
// char body[1024];
// uint64_t bodySize = sizeof(body);
int64_t numBytesResponseReceived = 0;
read = 0;

bool success = ContinueRequest(*request);

auto writerFunc =
[this, hHttpRequest, &request, readLimiter, &response](char* dst, uint64_t dstSz, uint64_t& read) -> bool{
bool success = DoReadData(hHttpRequest, dst, dstSz, read);
if (success)
{
if (read > 0)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
}
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
}
return success;
};
numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody().rdbuf(), writerFunc);

success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();

#if 0
while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success)
{
response->GetResponseBody().write(body, read);
Expand All @@ -275,6 +305,8 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&

success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();
}
#endif


if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
{
Expand Down
7 changes: 7 additions & 0 deletions tests/benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include <metric/CloudWatchMetrics.h>

int main(int argc, char *argv[]) {
if (1 == argc ||
2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" ))
{
Benchmark::Configuration::PrintHelp();
return 0;
}

Aws::SDKOptions options;
Aws::InitAPI(options);
{
Expand Down
2 changes: 2 additions & 0 deletions tests/benchmark/include/Configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ namespace Benchmark {
std::string service;
std::string api;
long durationMillis;
size_t maxRepeats;
bool shouldReportToCloudWatch;
std::map<std::string, std::string> dimensions;
};

class Configuration {
public:
static void PrintHelp();
static Configuration FromArgs(int argc, char *argv[]);
inline RunConfiguration GetConfiguration() const { return this->runConfiguration; }
private:
Expand Down
12 changes: 11 additions & 1 deletion tests/benchmark/src/Configuration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,25 @@
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>

// Builds a Configuration from an already-populated RunConfiguration.
// The argument is taken by value and moved into the runConfiguration member.
Benchmark::Configuration::Configuration(Benchmark::RunConfiguration runConfiguration) :
runConfiguration(std::move(runConfiguration)) {}

// Prints command-line usage for the benchmark executable to stdout.
//
// The text mirrors how FromArgs reads the arguments:
//  - --durationMillis and --maxRepeats are passed unconditionally to std::stol / std::stoul,
//    so they are listed as mandatory;
//  - --withMetrics is only checked for presence (CmdOptionExists), it takes no value;
//  - the KEY:VAL list belongs to --dimensions (GetCmdOptions).
void Benchmark::Configuration::PrintHelp() {
    std::cout
        << "usage: benchmark [--service SERVICE] [--api API] --durationMillis DURATION_MS --maxRepeats MAX_REPEATS"
           " [--withMetrics] [--dimensions KEY1:VAL1,KEY2:VAL2]\n"
        << "\n"
        << "  --maxRepeats MAX_REPEATS  stop after this many requests; 0 means no request limit\n"
        << "\n"
        << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --maxRepeats 0\n"
        << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --maxRepeats 0"
           " --dimensions BucketType:S3Express\n";
}

Benchmark::Configuration Benchmark::Configuration::FromArgs(int argc, char *argv[]) {
return Benchmark::Configuration({
Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--service"),
Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--api"),
std::stol(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--durationMillis")),
std::stoul(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--maxRepeats")),
Benchmark::Configuration::CmdOptionExists(argv, argv + argc, "--withMetrics"),
Benchmark::Configuration::GetCmdOptions(argv, argv + argc, "--dimensions")
});
Expand All @@ -31,7 +41,7 @@ char *Benchmark::Configuration::GetCmdOption(char **begin, char **end, const std

std::map<std::string, std::string> Benchmark::Configuration::GetCmdOptions(char **begin, char **end, const std::string &option) {
char **itr = std::find(begin, end, option);
auto value = ++itr;
auto value = itr != end ? ++itr : itr;
if (itr != end && value != end) {
//check to make sure the next entry is not another arg
std::string nextArg(*value);
Expand Down
29 changes: 25 additions & 4 deletions tests/benchmark/src/service/S3GetObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,20 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
metricsEmitter->EmitMetricForOp("CreateBucket",
S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}),
[&]() -> bool {
auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName));
auto request = CreateBucketRequest()
.WithBucket(bucketName);
if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") {
request.WithCreateBucketConfiguration(CreateBucketConfiguration()
.WithLocation(LocationInfo()
.WithType(LocationType::AvailabilityZone)
.WithName("use1-az6"))
.WithBucket(BucketInfo()
.WithType(BucketType::Directory)
.WithDataRedundancy(DataRedundancy::SingleAvailabilityZone)));
}


auto response = s3->CreateBucket(request);
if (!response.IsSuccess()) {
std::cout << "Create Bucket Failed With: "
<< response.GetError().GetMessage()
Expand All @@ -49,6 +62,8 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
return response.IsSuccess();
});

std::this_thread::sleep_for(std::chrono::seconds(2));

// Setup object to get
const auto testObjectKey = "BenchmarkTestObjectKey";
const auto randomBody64K = RandomString(64000);
Expand All @@ -64,16 +79,18 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Put Object Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});

// Run GetObject requests
const auto timeToEnd = duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() +
configuration.GetConfiguration().durationMillis;
size_t counter = 0;
size_t maxRepeats = configuration.GetConfiguration().maxRepeats;
auto getObjectRequest = GetObjectRequest().WithBucket(bucketName).WithKey(testObjectKey);
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) { ;
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) {
metricsEmitter->EmitMetricForOp(
"GetObject",
S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "GetObject"}}),
Expand All @@ -82,10 +99,14 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Get Object Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});
counter++;
if (maxRepeats && counter == maxRepeats) {
break;
}
}

// Clean up
Expand Down
8 changes: 7 additions & 1 deletion tests/benchmark/src/service/S3PutObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() {
std::vector<std::string> keysToDelete;
const auto timeToEnd = duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() +
configuration.GetConfiguration().durationMillis;
size_t counter = 0;
size_t maxRepeats = configuration.GetConfiguration().maxRepeats;
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) {
auto key = RandomString(8);
keysToDelete.push_back(key);
Expand All @@ -71,10 +73,14 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Put Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});
counter++;
if (maxRepeats && counter == maxRepeats) {
break;
}
}

// Clean up
Expand Down

0 comments on commit bae5120

Please sign in to comment.