Skip to content

Commit

Permalink
Avoid excessive reply buffer copy in WinHTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyRyabinin authored and Sergey Ryabinin committed May 21, 2024
1 parent 2c07e1b commit 911c421
Show file tree
Hide file tree
Showing 15 changed files with 292 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace Aws
bool DoReceiveResponse(void* httpRequest) const override;
bool DoQueryHeaders(void* httpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const override;
bool DoSendRequest(void* httpRequest) const override;
bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override;
bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override;
void* GetClientModule() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace Aws
bool DoReceiveResponse(void* hHttpRequest) const override;
bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const override;
bool DoSendRequest(void* hHttpRequest) const override;
bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override;
bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override;
void* GetClientModule() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ namespace Aws
virtual bool DoReceiveResponse(void* hHttpRequest) const = 0;
virtual bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const = 0;
virtual bool DoSendRequest(void* hHttpRequest) const = 0;
virtual bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const = 0;
virtual bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const = 0;
virtual void* GetClientModule() const = 0;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/Array.h>
#include <streambuf>
#include <functional>

namespace Aws
{
namespace Utils
{
namespace Stream
{
/**
* This is a wrapper to perform a hack to write directly to the put area of the underlying streambuf
*/
class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf
{
public:
StreamBufProtectedWriter() = delete;

using WriterFunc = std::function<bool(char* dst, uint64_t dstSz, uint64_t& read)>;

static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc)
{
uint64_t totalRead = 0;

while (true)
{
StreamBufProtectedWriter* pBufferCasted = static_cast<StreamBufProtectedWriter*>(ioStream.rdbuf());
bool bufferPresent = pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() < pBufferCasted->epptr());
uint64_t read = 0;
bool success = false;
if (bufferPresent)
{
// have access to underlying put ptr.
success = WriteDirectlyToPtr(pBufferCasted, writerFunc, read);
}
else
{
// can't access underlying buffer, stream buffer maybe be customized to not use put ptr.
// or underlying put buffer is simply not initialized yet.
success = WriteWithHelperBuffer(ioStream, writerFunc, read);
}
totalRead += read;
if (!success)
{
return totalRead;
}

if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr()))
{
if(!ForceOverflow(ioStream, writerFunc))
{
return totalRead;
} else {
totalRead++;
}
}
}
return totalRead;
}
protected:
static bool ForceOverflow(Aws::IOStream& ioStream, const WriterFunc& writerFunc)
{
char dstChar;
uint64_t read = 0;
if (writerFunc(&dstChar, 1, read) && read > 0)
{
ioStream.write(&dstChar, 1);
if (ioStream.fail()) {
AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write 1 byte (eof: "
<< ioStream.eof() << ", bad: " << ioStream.bad() << ")");
return 0;
}
return 1;
}
return false;
}

static uint64_t WriteWithHelperBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc, uint64_t& read)
{
char tmpBuf[1024];
uint64_t tmpBufSz = sizeof(tmpBuf);

if(writerFunc(tmpBuf, tmpBufSz, read) && read > 0)
{
ioStream.write(tmpBuf, read);
if (ioStream.fail()) {
AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write " << tmpBufSz
<< " (eof: " << ioStream.eof() << ", bad: " << ioStream.bad() << ")");
return false;
}
return true;
}
return false;
}

static uint64_t WriteDirectlyToPtr(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc, uint64_t& read)
{
auto dstBegin = pBuffer->pptr();
uint64_t dstSz = pBuffer->epptr() - dstBegin;
std::cout << "dst size = " << dstSz << "\n";
if(writerFunc(dstBegin, dstSz, read) && read > 0)
{
assert(read <= dstSz);
pBuffer->pbump((int) read);
return true;
}
return false;
}
};
}
}
}
4 changes: 2 additions & 2 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata)
return 0;
}

size_t cur = response->GetResponseBody().tellp();
auto cur = response->GetResponseBody().tellp();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: "
Expand Down Expand Up @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = ioStream->readsome(ptr, amountToRead);
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ bool WinHttpSyncHttpClient::DoSendRequest(void* hHttpRequest) const
return (WinHttpSendRequest(hHttpRequest, NULL, NULL, 0, 0, 0, NULL) != 0);
}

bool WinHttpSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const
{
return (WinHttpQueryDataAvailable(hHttpRequest, (LPDWORD)&available) != 0);
}

bool WinHttpSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const
{
return (WinHttpReadData(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ bool WinINetSyncHttpClient::DoSendRequest(void* hHttpRequest) const
return (HttpSendRequestEx(hHttpRequest, NULL, NULL, 0, 0) != 0);
}

bool WinINetSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const
{
return (InternetQueryDataAvailable(hHttpRequest, (LPDWORD)&available, /*reserved*/ 0, /*reserved*/ 0) != 0);
}

bool WinINetSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const
{
return (InternetReadFile(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0);
Expand Down
116 changes: 83 additions & 33 deletions src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/stream/StreamBufProtectedWriter.h>
#include <aws/core/client/ClientConfiguration.h>
#include <aws/core/http/windows/WinConnectionPoolMgr.h>
#include <aws/core/utils/memory/AWSMemory.h>
Expand Down Expand Up @@ -245,55 +246,104 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&

if (request->GetMethod() != HttpMethod::HTTP_HEAD)
{
char body[1024];
uint64_t bodySize = sizeof(body);
int64_t numBytesResponseReceived = 0;
read = 0;
if(!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
response->SetClientErrorType(CoreErrors::USER_CANCELLED);
response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler.");
response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE);
return false;
}

bool success = ContinueRequest(*request);
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(GetLogTag(), "Response output stream is in a bad state (eof: " << ref.eof() << ", bad: " << ref.bad() << ")");
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Response output stream is in a bad state.");
return false;
}

while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success)
{
response->GetResponseBody().write(body, read);
if (read > 0)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(body), static_cast<size_t>(read));
}
numBytesResponseReceived += read;
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
bool connectionOpen = true;
auto writerFunc =
[this, hHttpRequest, &request, readLimiter, &response, &connectionOpen](char* dst, uint64_t dstSz, uint64_t& read) -> bool
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
bool success = true;
uint64_t available = 0;
connectionOpen = DoQueryDataAvailable(hHttpRequest, available);

success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();
if (connectionOpen && available)
{
dstSz = (std::min)(dstSz, available);
success = DoReadData(hHttpRequest, dst, dstSz, read);
if (success && read > 0)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
}
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
if (!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
return false;
}
}
return connectionOpen && success && ContinueRequest(*request) && IsRequestProcessingEnabled();
};
uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc);

if(!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
response->SetClientErrorType(CoreErrors::USER_CANCELLED);
response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler.");
response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE);
return false;
}

if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
if (response->GetResponseBody().fail()) {
Aws::OStringStream errorMsgStr;
errorMsgStr << "Failed to write received response (eof: "
<< response->GetResponseBody().eof() << ", bad: " << response->GetResponseBody().bad() << ")";

Aws::String errorMsg = errorMsgStr.str();
AWS_LOGSTREAM_ERROR(GetLogTag(), errorMsg);
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage(errorMsg);
return false;
}

if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE))
{
response->GetResponseBody().flush();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(GetLogTag(), "Failed to flush event response (eof: " << ref.eof() << ", bad: " << ref.bad() << ")");
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Failed to flush event stream event response");
return false;
}
}

if (response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
{
const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response content-length header: " << contentLength);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response body length: " << numBytesResponseReceived);
if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
if ((uint64_t) StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
{
success = false;
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Response body length doesn't match the content-length header.");
AWS_LOGSTREAM_ERROR(GetLogTag(), "Response body length doesn't match the content-length header.");
return false;
}
}

if(!success)
{
return false;
}
}

//go ahead and flush the response body.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ namespace
auto cognitoClient = Aws::MakeShared<Aws::CognitoIdentity::CognitoIdentityClient>(ALLOCATION_TAG, config);
Aws::AccessManagement::AccessManagementClient accessManagementClient(iamClient, cognitoClient);
accountId = accessManagementClient.GetAccountId();
assert(!accountId.empty()); // AccountId must be set for this test
}
m_accountId = accountId;
}
Expand Down
7 changes: 7 additions & 0 deletions tests/benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include <metric/CloudWatchMetrics.h>

int main(int argc, char *argv[]) {
if (1 == argc ||
2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" ))
{
Benchmark::Configuration::PrintHelp();
return 0;
}

Aws::SDKOptions options;
Aws::InitAPI(options);
{
Expand Down
2 changes: 2 additions & 0 deletions tests/benchmark/include/Configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ namespace Benchmark {
std::string service;
std::string api;
long durationMillis;
size_t maxRepeats;
bool shouldReportToCloudWatch;
std::map<std::string, std::string> dimensions;
};

class Configuration {
public:
static void PrintHelp();
static Configuration FromArgs(int argc, char *argv[]);
inline RunConfiguration GetConfiguration() const { return this->runConfiguration; }
private:
Expand Down
Loading

0 comments on commit 911c421

Please sign in to comment.