Skip to content

Commit

Permalink
Avoid excessive reply buffer copy in WinHTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyRyabinin authored and Sergey Ryabinin committed May 15, 2024
1 parent 97fb759 commit 3b1f7a8
Show file tree
Hide file tree
Showing 9 changed files with 224 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@

/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/Array.h>
#include <cassert>
#include <climits>
#include <cstdint>
#include <functional>
#include <streambuf>

namespace Aws
{
namespace Utils
{
namespace Stream
{
/**
 * Lets a producer callback write straight into the put area of an arbitrary std::streambuf,
 * so that the payload does not have to be staged in an intermediate buffer first.
 *
 * The class is never instantiated (its constructor is deleted). It derives from std::streambuf
 * only so that its static members may name the protected members pptr(), epptr(), pbump() and
 * overflow(); the target buffer is reinterpreted as this type through a static_cast. This is a
 * deliberate hack: the cast target is not the dynamic type of the object.
 */
class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf
{
public:
    StreamBufProtectedWriter() = delete;

    /**
     * Producer callback.
     *  dst   - where the callback must store the produced bytes
     *  dstSz - maximum number of bytes that may be stored at dst
     *  read  - set by the callback to the number of bytes actually stored
     * Returns false when the producer failed; a successful call that sets read to 0 signals
     * the end of the data.
     */
    using WriterFunc = std::function<bool(char* dst, uint64_t dstSz, uint64_t& read)>;

    /**
     * Repeatedly invokes writerFunc and stores everything it produces into pBuffer until the
     * producer reports failure or end of data, or until the buffer refuses to accept more.
     *
     * @param pBuffer destination stream buffer; a null pointer results in 0 being returned.
     * @param writerFunc producer callback; an empty std::function results in 0 being returned.
     * @return the number of bytes that were stored in pBuffer. Bytes that were obtained from
     *         the producer but rejected by the buffer are not counted, so a caller comparing
     *         the result against an expected length can detect a truncated write.
     */
    static uint64_t WriteToBuffer(std::streambuf* pBuffer, const WriterFunc& writerFunc)
    {
        if (!pBuffer || !writerFunc)
        {
            return 0;
        }

        StreamBufProtectedWriter* pBufferCasted = static_cast<StreamBufProtectedWriter*>(pBuffer);
        uint64_t totalRead = 0;

        while (true)
        {
            const bool bufferPresent = pBufferCasted->pptr() && (pBufferCasted->pptr() < pBufferCasted->epptr());
            // Without a usable put area (the stream buffer may be customized to not use the put
            // pointer, or the put area is simply not initialized yet) fall back to a staging buffer.
            const uint64_t read = bufferPresent ? WriteDirectlyToPtr(pBufferCasted, writerFunc)
                                                : WriteWithHelperBuffer(pBufferCasted, writerFunc);
            if (!read)
            {
                return totalRead;
            }
            totalRead += read;

            // The put area is exhausted: push one more byte through overflow() so that the
            // buffer gets a chance to flush or grow before the next direct write.
            if (pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr()))
            {
                if (!ForceOverflow(pBufferCasted, writerFunc))
                {
                    return totalRead;
                }
                ++totalRead;
            }
        }
    }

protected:
    /**
     * Obtains a single byte from writerFunc and hands it to pBuffer->overflow().
     *
     * @return 1 when a byte was produced and accepted by the buffer, 0 when the producer had
     *         nothing to deliver or overflow() reported failure.
     */
    static size_t ForceOverflow(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc)
    {
        char dstChar = 0;
        uint64_t read = 0;
        if (!writerFunc(&dstChar, 1, read) || read == 0)
        {
            return 0;
        }

        // to_int_type() is required here: passing the char directly would turn the byte 0xFF
        // into the EOF value on platforms where char is signed, and overflow() would drop it.
        const int_type result = pBuffer->overflow(traits_type::to_int_type(dstChar));
        return traits_type::eq_int_type(result, traits_type::eof()) ? 0 : 1;
    }

    /**
     * Obtains up to 1024 bytes from writerFunc into a local staging buffer and copies them
     * into pBuffer with sputn().
     *
     * @return the number of bytes sputn() reported as written; 0 when the producer failed,
     *         had nothing to deliver, or the buffer accepted nothing.
     */
    static uint64_t WriteWithHelperBuffer(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc)
    {
        char tmpBuf[1024];
        const uint64_t tmpBufSz = sizeof(tmpBuf);
        uint64_t read = 0;

        if (!writerFunc(tmpBuf, tmpBufSz, read) || read == 0)
        {
            return 0;
        }
        assert(read <= tmpBufSz);

        // A short write is a runtime condition of the destination stream, not a programming
        // error, so it is reported through the return value.
        const std::streamsize actuallyWritten = pBuffer->sputn(tmpBuf, static_cast<std::streamsize>(read));
        return actuallyWritten > 0 ? static_cast<uint64_t>(actuallyWritten) : 0;
    }

    /**
     * Lets writerFunc store its data directly at the current put position of pBuffer and then
     * advances the put pointer by the number of bytes produced.
     *
     * @return the number of bytes stored; 0 when the producer failed or had nothing to deliver.
     */
    static uint64_t WriteDirectlyToPtr(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc)
    {
        char* const dstBegin = pBuffer->pptr();
        uint64_t dstSz = static_cast<uint64_t>(pBuffer->epptr() - dstBegin);
        // pbump() takes an int, so never offer more room than a single pbump() call can commit.
        if (dstSz > static_cast<uint64_t>(INT_MAX))
        {
            dstSz = static_cast<uint64_t>(INT_MAX);
        }

        uint64_t read = 0;
        if (!writerFunc(dstBegin, dstSz, read) || read == 0)
        {
            return 0;
        }

        assert(read <= dstSz);
        pBuffer->pbump(static_cast<int>(read));
        return read;
    }
};
}
}
}
4 changes: 2 additions & 2 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata)
return 0;
}

size_t cur = response->GetResponseBody().tellp();
auto cur = response->GetResponseBody().tellp();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: "
Expand Down Expand Up @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = ioStream->readsome(ptr, amountToRead);
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
Expand Down
48 changes: 24 additions & 24 deletions src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/stream/StreamBufProtectedWriter.h>
#include <aws/core/client/ClientConfiguration.h>
#include <aws/core/http/windows/WinConnectionPoolMgr.h>
#include <aws/core/utils/memory/AWSMemory.h>
Expand Down Expand Up @@ -245,43 +246,42 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&

if (request->GetMethod() != HttpMethod::HTTP_HEAD)
{
char body[1024];
uint64_t bodySize = sizeof(body);
int64_t numBytesResponseReceived = 0;
read = 0;

bool success = ContinueRequest(*request);

while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success)
{
response->GetResponseBody().write(body, read);
if (read > 0)
auto writerFunc =
[this, hHttpRequest, &request, readLimiter, &response](char* dst, uint64_t dstSz, uint64_t& read) -> bool{
bool success = DoReadData(hHttpRequest, dst, dstSz, read);
if (success)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(body), static_cast<size_t>(read));
}
numBytesResponseReceived += read;
if (readLimiter != nullptr)
if (read > 0)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
}
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
}
return success;
};
uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody().rdbuf(), writerFunc);

success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();
}
success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();

if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
{
const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response content-length header: " << contentLength);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response body length: " << numBytesResponseReceived);
if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
if ((uint64_t) StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
{
success = false;
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
Expand Down
7 changes: 7 additions & 0 deletions tests/benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include <metric/CloudWatchMetrics.h>

int main(int argc, char *argv[]) {
if (1 == argc ||
2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" ))
{
Benchmark::Configuration::PrintHelp();
return 0;
}

Aws::SDKOptions options;
Aws::InitAPI(options);
{
Expand Down
2 changes: 2 additions & 0 deletions tests/benchmark/include/Configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ namespace Benchmark {
std::string service;
std::string api;
long durationMillis;
size_t maxRepeats;
bool shouldReportToCloudWatch;
std::map<std::string, std::string> dimensions;
};

class Configuration {
public:
static void PrintHelp();
static Configuration FromArgs(int argc, char *argv[]);
inline RunConfiguration GetConfiguration() const { return this->runConfiguration; }
private:
Expand Down
12 changes: 11 additions & 1 deletion tests/benchmark/src/Configuration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,25 @@
#include <algorithm>
#include <utility>
#include <vector>
#include <iostream>

// Builds a Configuration that takes ownership of the supplied run settings by moving them
// into the runConfiguration member.
Benchmark::Configuration::Configuration(Benchmark::RunConfiguration runConfiguration)
    : runConfiguration(std::move(runConfiguration))
{
}

// Prints the command-line usage of the benchmark tool to standard output.
// The text mirrors the options read by FromArgs: --withMetrics is a presence-only flag,
// --dimensions carries the key/value list, and --durationMillis / --maxRepeats are always
// parsed as numbers and therefore have to be supplied.
// NOTE(review): the KEY:VAL,KEY:VAL notation for --dimensions is carried over from the previous
// usage text; confirm it against the parser in GetCmdOptions.
void Benchmark::Configuration::PrintHelp() {
    std::cout << "usage: benchmark [--service SERVICE] [--api API] --durationMillis DURATION_MS --maxRepeats MAX_REPEATS"
                 " [--withMetrics] [--dimensions KEY1:VAL1,KEY2:VAL2]\n";
    std::cout << "\n";
    // The benchmark loops only stop on the counter when maxRepeats is non-zero.
    std::cout << "--maxRepeats 0 disables the repeat limit; the run is then bounded by --durationMillis only\n";
    std::cout << "\n";
    std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --maxRepeats 0\n";
    std::cout << "example: benchmark --service s3 --api PutObject --durationMillis 1000 --maxRepeats 0 --dimensions BucketType:S3Express\n";
}

Benchmark::Configuration Benchmark::Configuration::FromArgs(int argc, char *argv[]) {
return Benchmark::Configuration({
Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--service"),
Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--api"),
std::stol(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--durationMillis")),
std::stoul(Benchmark::Configuration::GetCmdOption(argv, argv + argc, "--maxRepeats")),
Benchmark::Configuration::CmdOptionExists(argv, argv + argc, "--withMetrics"),
Benchmark::Configuration::GetCmdOptions(argv, argv + argc, "--dimensions")
});
Expand All @@ -31,7 +41,7 @@ char *Benchmark::Configuration::GetCmdOption(char **begin, char **end, const std

std::map<std::string, std::string> Benchmark::Configuration::GetCmdOptions(char **begin, char **end, const std::string &option) {
char **itr = std::find(begin, end, option);
auto value = ++itr;
auto value = itr != end ? ++itr : itr;
if (itr != end && value != end) {
//check to make sure the next entry is not another arg
std::string nextArg(*value);
Expand Down
26 changes: 22 additions & 4 deletions tests/benchmark/src/service/S3GetObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,19 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
metricsEmitter->EmitMetricForOp("CreateBucket",
S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}),
[&]() -> bool {
auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName));
auto request = CreateBucketRequest()
.WithBucket(bucketName);
if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") {
request.WithCreateBucketConfiguration(CreateBucketConfiguration()
.WithLocation(LocationInfo()
.WithType(LocationType::AvailabilityZone)
.WithName("use1-az2"))
.WithBucket(BucketInfo()
.WithType(BucketType::Directory)
.WithDataRedundancy(DataRedundancy::SingleAvailabilityZone)));
}

auto response = s3->CreateBucket(request);
if (!response.IsSuccess()) {
std::cout << "Create Bucket Failed With: "
<< response.GetError().GetMessage()
Expand All @@ -64,16 +76,18 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Put Object Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});

// Run GetObject requests
const auto timeToEnd = duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() +
configuration.GetConfiguration().durationMillis;
size_t counter = 0;
size_t maxRepeats = configuration.GetConfiguration().maxRepeats;
auto getObjectRequest = GetObjectRequest().WithBucket(bucketName).WithKey(testObjectKey);
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) { ;
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) {
metricsEmitter->EmitMetricForOp(
"GetObject",
S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "GetObject"}}),
Expand All @@ -82,10 +96,14 @@ Benchmark::TestFunction Benchmark::S3GetObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Get Object Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});
counter++;
if (maxRepeats && counter == maxRepeats) {
break;
}
}

// Clean up
Expand Down
22 changes: 20 additions & 2 deletions tests/benchmark/src/service/S3PutObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,19 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() {
metricsEmitter->EmitMetricForOp("CreateBucket",
S3Utils::getMetricDimensions(dimensions, {{"Service", "S3"}, {"Operation", "CreateBucket"}}),
[&]() -> bool {
auto response = s3->CreateBucket(CreateBucketRequest().WithBucket(bucketName));
auto request = CreateBucketRequest()
.WithBucket(bucketName);
if (dimensions.find("BucketType") != dimensions.end() && dimensions.at("BucketType") == "S3Express") {
request.WithCreateBucketConfiguration(CreateBucketConfiguration()
.WithLocation(LocationInfo()
.WithType(LocationType::AvailabilityZone)
.WithName("use1-az2"))
.WithBucket(BucketInfo()
.WithType(BucketType::Directory)
.WithDataRedundancy(DataRedundancy::SingleAvailabilityZone)));
}

auto response = s3->CreateBucket(request);
if (!response.IsSuccess()) {
std::cout << "Create Bucket Failed With: "
<< response.GetError().GetMessage()
Expand All @@ -56,6 +68,8 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() {
std::vector<std::string> keysToDelete;
const auto timeToEnd = duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() +
configuration.GetConfiguration().durationMillis;
size_t counter = 0;
size_t maxRepeats = configuration.GetConfiguration().maxRepeats;
while (duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count() < timeToEnd) {
auto key = RandomString(8);
keysToDelete.push_back(key);
Expand All @@ -71,10 +85,14 @@ Benchmark::TestFunction Benchmark::S3PutObject::CreateTestFunction() {
if (!response.IsSuccess()) {
std::cout << "Put Failed With: "
<< response.GetError().GetMessage()
<< "\n";;
<< "\n";
}
return response.IsSuccess();
});
counter++;
if (maxRepeats && counter == maxRepeats) {
break;
}
}

// Clean up
Expand Down
Loading

0 comments on commit 3b1f7a8

Please sign in to comment.