Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid excessive reply buffer copy in WinHTTP #2954

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace Aws
bool DoReceiveResponse(void* httpRequest) const override;
bool DoQueryHeaders(void* httpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const override;
bool DoSendRequest(void* httpRequest) const override;
bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override;
bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override;
void* GetClientModule() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace Aws
bool DoReceiveResponse(void* hHttpRequest) const override;
bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const override;
bool DoSendRequest(void* hHttpRequest) const override;
bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const override;
bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const override;
void* GetClientModule() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ namespace Aws
virtual bool DoReceiveResponse(void* hHttpRequest) const = 0;
virtual bool DoQueryHeaders(void* hHttpRequest, std::shared_ptr<Aws::Http::HttpResponse>& response, Aws::StringStream& ss, uint64_t& read) const = 0;
virtual bool DoSendRequest(void* hHttpRequest) const = 0;
virtual bool DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const = 0;
virtual bool DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const = 0;
virtual void* GetClientModule() const = 0;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@

/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#pragma once

#include <aws/core/Core_EXPORTS.h>
#include <aws/core/utils/Array.h>
#include <streambuf>
#include <functional>

namespace Aws
{
namespace Utils
{
namespace Stream
{
/**
* This is a wrapper to perform a hack to write directly to the put area of the underlying streambuf
*/
class AWS_CORE_API StreamBufProtectedWriter : public std::streambuf
{
public:
StreamBufProtectedWriter() = delete;

using WriterFunc = std::function<bool(char* dst, uint64_t dstSz, uint64_t& read)>;

static uint64_t WriteToBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc)
{
uint64_t totalRead = 0;

while (true)
{
StreamBufProtectedWriter* pBufferCasted = static_cast<StreamBufProtectedWriter*>(ioStream.rdbuf());
bool bufferPresent = pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() < pBufferCasted->epptr());
uint64_t read = 0;
bool success = false;
if (bufferPresent)
{
// have access to underlying put ptr.
success = WriteDirectlyToPtr(pBufferCasted, writerFunc, read);
}
else
{
// can't access underlying buffer, stream buffer maybe be customized to not use put ptr.
// or underlying put buffer is simply not initialized yet.
success = WriteWithHelperBuffer(ioStream, writerFunc, read);
}
totalRead += read;
if (!success)
{
break;
}

if (pBufferCasted && pBufferCasted->pptr() && (pBufferCasted->pptr() >= pBufferCasted->epptr()))
{
if(!ForceOverflow(ioStream, writerFunc))
{
break;
} else {
totalRead++;
}
}
}
return totalRead;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

under what conditions are you getting to this line? only way out of while is to return. is this expected or while is missing a break somewhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as dimitry also if the line is "unreachable in theory but we need to return something" we should assert use AWS_UNREACHABLE as a invariant

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to break + proper return

}
protected:
static bool ForceOverflow(Aws::IOStream& ioStream, const WriterFunc& writerFunc)
{
char dstChar;
uint64_t read = 0;
if (writerFunc(&dstChar, 1, read) && read > 0)
{
ioStream.write(&dstChar, 1);
if (ioStream.fail()) {
AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write 1 byte (eof: "
<< ioStream.eof() << ", bad: " << ioStream.bad() << ")");
return false;
}
return true;
}
return false;
}

static uint64_t WriteWithHelperBuffer(Aws::IOStream& ioStream, const WriterFunc& writerFunc, uint64_t& read)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the return type should be bool

{
char tmpBuf[1024];
uint64_t tmpBufSz = sizeof(tmpBuf);

if(writerFunc(tmpBuf, tmpBufSz, read) && read > 0)
{
ioStream.write(tmpBuf, read);
if (ioStream.fail()) {
AWS_LOGSTREAM_ERROR("StreamBufProtectedWriter", "Failed to write " << tmpBufSz
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shouldn't it be "Failed to write " << read?

<< " (eof: " << ioStream.eof() << ", bad: " << ioStream.bad() << ")");
return false;
}
return true;
}
return false;
}

static uint64_t WriteDirectlyToPtr(StreamBufProtectedWriter* pBuffer, const WriterFunc& writerFunc, uint64_t& read)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the return type should be bool

{
auto dstBegin = pBuffer->pptr();
uint64_t dstSz = pBuffer->epptr() - dstBegin;
if(writerFunc(dstBegin, dstSz, read) && read > 0)
{
assert(read <= dstSz);
pBuffer->pbump((int) read);
return true;
}
return false;
}
};
}
}
}
4 changes: 2 additions & 2 deletions src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ static size_t WriteData(char* ptr, size_t size, size_t nmemb, void* userdata)
return 0;
}

size_t cur = response->GetResponseBody().tellp();
auto cur = response->GetResponseBody().tellp();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(CURL_HTTP_CLIENT_TAG, "Unable to query response output position (eof: "
Expand Down Expand Up @@ -302,7 +302,7 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo
{
if (!ioStream->eof() && ioStream->peek() != EOF)
{
amountRead = ioStream->readsome(ptr, amountToRead);
amountRead = (size_t) ioStream->readsome(ptr, amountToRead);
}
if (amountRead == 0 && !ioStream->eof())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ bool WinHttpSyncHttpClient::DoSendRequest(void* hHttpRequest) const
return (WinHttpSendRequest(hHttpRequest, NULL, NULL, 0, 0, 0, NULL) != 0);
}

bool WinHttpSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const
{
return (WinHttpQueryDataAvailable(hHttpRequest, (LPDWORD)&available) != 0);
}

bool WinHttpSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const
{
return (WinHttpReadData(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ bool WinINetSyncHttpClient::DoSendRequest(void* hHttpRequest) const
return (HttpSendRequestEx(hHttpRequest, NULL, NULL, 0, 0) != 0);
}

bool WinINetSyncHttpClient::DoQueryDataAvailable(void* hHttpRequest, uint64_t& available) const
{
return (InternetQueryDataAvailable(hHttpRequest, (LPDWORD)&available, /*reserved*/ 0, /*reserved*/ 0) != 0);
}

bool WinINetSyncHttpClient::DoReadData(void* hHttpRequest, char* body, uint64_t size, uint64_t& read) const
{
return (InternetReadFile(hHttpRequest, body, (DWORD)size, (LPDWORD)&read) != 0);
Expand Down
116 changes: 83 additions & 33 deletions src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/HashingUtils.h>
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/core/utils/stream/StreamBufProtectedWriter.h>
#include <aws/core/client/ClientConfiguration.h>
#include <aws/core/http/windows/WinConnectionPoolMgr.h>
#include <aws/core/utils/memory/AWSMemory.h>
Expand Down Expand Up @@ -245,55 +246,104 @@ bool WinSyncHttpClient::BuildSuccessResponse(const std::shared_ptr<HttpRequest>&

if (request->GetMethod() != HttpMethod::HTTP_HEAD)
{
char body[1024];
uint64_t bodySize = sizeof(body);
int64_t numBytesResponseReceived = 0;
read = 0;
if(!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
response->SetClientErrorType(CoreErrors::USER_CANCELLED);
response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler.");
response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE);
return false;
}

bool success = ContinueRequest(*request);
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(GetLogTag(), "Response output stream is in a bad state (eof: " << ref.eof() << ", bad: " << ref.bad() << ")");
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Response output stream is in a bad state.");
return false;
}

while (DoReadData(hHttpRequest, body, bodySize, read) && read > 0 && success)
{
response->GetResponseBody().write(body, read);
if (read > 0)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(body), static_cast<size_t>(read));
}
numBytesResponseReceived += read;
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
bool connectionOpen = true;
auto writerFunc =
[this, hHttpRequest, &request, readLimiter, &response, &connectionOpen](char* dst, uint64_t dstSz, uint64_t& read) -> bool
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
bool success = true;
uint64_t available = 0;
connectionOpen = DoQueryDataAvailable(hHttpRequest, available);

success = success && ContinueRequest(*request) && IsRequestProcessingEnabled();
if (connectionOpen && available)
{
dstSz = (std::min)(dstSz, available);
success = DoReadData(hHttpRequest, dst, dstSz, read);
if (success && read > 0)
{
for (const auto& hashIterator : request->GetResponseValidationHashes())
{
hashIterator.second->Update(reinterpret_cast<unsigned char*>(dst), static_cast<size_t>(read));
}
if (readLimiter != nullptr)
{
readLimiter->ApplyAndPayForCost(read);
}
auto& receivedHandler = request->GetDataReceivedEventHandler();
if (receivedHandler)
{
receivedHandler(request.get(), response.get(), (long long)read);
}
}
if (!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
return false;
}
}
return connectionOpen && success && ContinueRequest(*request) && IsRequestProcessingEnabled();
};
uint64_t numBytesResponseReceived = Aws::Utils::Stream::StreamBufProtectedWriter::WriteToBuffer(response->GetResponseBody(), writerFunc);

if(!ContinueRequest(*request) || !IsRequestProcessingEnabled())
{
response->SetClientErrorType(CoreErrors::USER_CANCELLED);
response->SetClientErrorMessage("Request processing disabled or continuation cancelled by user's continuation handler.");
response->SetResponseCode(Aws::Http::HttpResponseCode::NO_RESPONSE);
return false;
}

if (success && response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
if (response->GetResponseBody().fail()) {
Aws::OStringStream errorMsgStr;
errorMsgStr << "Failed to write received response (eof: "
<< response->GetResponseBody().eof() << ", bad: " << response->GetResponseBody().bad() << ")";

Aws::String errorMsg = errorMsgStr.str();
AWS_LOGSTREAM_ERROR(GetLogTag(), errorMsg);
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage(errorMsg);
return false;
}

if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE))
{
response->GetResponseBody().flush();
if (response->GetResponseBody().fail()) {
const auto& ref = response->GetResponseBody();
AWS_LOGSTREAM_ERROR(GetLogTag(), "Failed to flush event response (eof: " << ref.eof() << ", bad: " << ref.bad() << ")");
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Failed to flush event stream event response");
return false;
}
}

if (response->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER))
{
const Aws::String& contentLength = response->GetHeader(Aws::Http::CONTENT_LENGTH_HEADER);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response content-length header: " << contentLength);
AWS_LOGSTREAM_TRACE(GetLogTag(), "Response body length: " << numBytesResponseReceived);
if (StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
if ((uint64_t) StringUtils::ConvertToInt64(contentLength.c_str()) != numBytesResponseReceived)
{
success = false;
response->SetClientErrorType(CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage("Response body length doesn't match the content-length header.");
AWS_LOGSTREAM_ERROR(GetLogTag(), "Response body length doesn't match the content-length header.");
return false;
}
}

if(!success)
{
return false;
}
}

//go ahead and flush the response body.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ namespace
auto cognitoClient = Aws::MakeShared<Aws::CognitoIdentity::CognitoIdentityClient>(ALLOCATION_TAG, config);
Aws::AccessManagement::AccessManagementClient accessManagementClient(iamClient, cognitoClient);
accountId = accessManagementClient.GetAccountId();
assert(!accountId.empty()); // AccountId must be set for this test
}
m_accountId = accountId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ unsigned int LevenshteinDistance(Aws::String s1, Aws::String s2)
TEST_F(TranscribeStreamingTests, TranscribeStreamingCppSdkSample)
{
const Aws::Vector<Aws::String> EXPECTED_ALTERNATIVES = {"This is a C plus plus test sample", "This is a C++ test sample"};
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 25)
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50)
{
m_testTraces.clear();
TestTrace(Aws::String("### Starting TranscribeStreamingCppSdkSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##");
Expand All @@ -582,7 +582,7 @@ TEST_F(TranscribeStreamingTests, TranscribeStreamingKantSample)
static const char expected[] = "Categorical imperative: Act only according to that maxim whereby you can at the same time will that it should become a universal law. "
"Two things fill the mind with ever-increasing wonder and awe, the more often and the more intensely the mind of thought is drawn to them: "
"the starry heavens above me and the moral law within me.";
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 25)
for(size_t chunkDuration = 50; chunkDuration <= 200; chunkDuration += 50)
{
m_testTraces.clear();
TestTrace(Aws::String("### Starting TranscribeStreamingKantSample with chunks of ") + Aws::Utils::StringUtils::to_string(chunkDuration) + " ms ##");
Expand Down
7 changes: 7 additions & 0 deletions tests/benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
#include <metric/CloudWatchMetrics.h>

int main(int argc, char *argv[]) {
if (1 == argc ||
2 == argc && (std::string(argv[1]) == "-h" || std::string(argv[1]) == "--help" ))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NICE!

{
Benchmark::Configuration::PrintHelp();
return 0;
}

Aws::SDKOptions options;
Aws::InitAPI(options);
{
Expand Down
2 changes: 2 additions & 0 deletions tests/benchmark/include/Configuration.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ namespace Benchmark {
std::string service;
std::string api;
long durationMillis;
size_t maxRepeats;
bool shouldReportToCloudWatch;
std::map<std::string, std::string> dimensions;
};

class Configuration {
public:
static void PrintHelp();
static Configuration FromArgs(int argc, char *argv[]);
inline RunConfiguration GetConfiguration() const { return this->runConfiguration; }
private:
Expand Down
Loading
Loading