Skip to content

Fix s3 crt checksum calc #3075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
#include <aws/core/utils/DNS.h>
#include <aws/s3-crt/S3CrtServiceClientModel.h>
#include <aws/s3-crt/S3ExpressIdentityProvider.h>
#include <aws/s3-crt/S3CrtIdentityProviderAdapter.h>

struct aws_s3_client;
// TODO: temporary fix for naming conflicts on Windows.
@@ -6275,6 +6276,7 @@ namespace Aws
std::shared_ptr<Aws::Http::HttpRequest> request;
std::shared_ptr<Aws::Http::HttpResponse> response;
std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest;
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig;
};

Aws::Client::XmlOutcome GenerateXmlOutcome(const std::shared_ptr<Http::HttpResponse>& response) const;
@@ -6316,6 +6318,7 @@ namespace Aws
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> m_crtCredProvider;
std::shared_ptr<S3CrtEndpointProviderBase> m_endpointProvider;
std::shared_ptr<S3ExpressIdentityProvider> m_identityProvider;
S3CrtIdentityProviderUserData m_identityProviderUserData{m_identityProvider};
};

} // namespace S3Crt
Original file line number Diff line number Diff line change
@@ -10,6 +10,21 @@

namespace Aws {
namespace S3Crt {
/**
* Userdata class for ensuring lifetime of the identity provider and
* implementation pointers for crt.
*/
class S3CrtIdentityProviderUserData final {
public:
explicit S3CrtIdentityProviderUserData(std::shared_ptr<S3ExpressIdentityProvider> identity_provider);
std::shared_ptr<S3ExpressIdentityProvider> GetIdentityProvider() const { return m_identityProvider; }
std::shared_ptr<aws_s3express_credentials_provider_vtable> GetImpl() const { return m_impl; }

private:
std::shared_ptr<S3ExpressIdentityProvider> m_identityProvider;
std::shared_ptr<struct aws_s3express_credentials_provider_vtable> m_impl;
};

/**
* Factory for a CRT aws_s3express_credentials_provider. Cannot subclass or instantiate,
* only for building a crt provider to be used in crt configuration.
64 changes: 63 additions & 1 deletion generated/src/aws-cpp-sdk-s3-crt/source/S3CrtClient.cpp
Original file line number Diff line number Diff line change
@@ -473,7 +473,7 @@ void S3CrtClient::init(const S3Crt::ClientConfiguration& config,
s3CrtConfig.shutdown_callback = CrtClientShutdownCallback;
s3CrtConfig.shutdown_callback_user_data = static_cast<void*>(&m_wrappedData);
s3CrtConfig.enable_s3express = !config.disableS3ExpressAuth;
s3CrtConfig.factory_user_data = static_cast<void *>(m_identityProvider.get());
s3CrtConfig.factory_user_data = static_cast<void *>(&m_identityProviderUserData);
s3CrtConfig.s3express_provider_override_factory = S3CrtIdentityProviderAdapter::ProviderFactory;

m_s3CrtClient = aws_s3_client_new(Aws::get_aws_allocator(), &s3CrtConfig);
@@ -788,6 +788,37 @@ void S3CrtClient::CopyObjectAsync(const CopyObjectRequest& request, const CopyOb
}
options.signing_config = &signing_config_override;

const auto headers = request.GetHeaders();
const auto checksumHeader = std::find_if(headers.begin(),
headers.end(),
[](const Aws::Http::HeaderValuePair& header) -> bool { return header.first.find("x-amz-checksum-") != Aws::String::npos; });
if (request.ChecksumAlgorithmHasBeenSet() && checksumHeader == headers.end())
{
static std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm> crtChecksumMapping[]{
{ChecksumAlgorithm::CRC32, aws_s3_checksum_algorithm::AWS_SCA_CRC32},
{ChecksumAlgorithm::CRC32C, aws_s3_checksum_algorithm::AWS_SCA_CRC32C},
{ChecksumAlgorithm::SHA1, aws_s3_checksum_algorithm::AWS_SCA_SHA1},
{ChecksumAlgorithm::SHA256, aws_s3_checksum_algorithm::AWS_SCA_SHA256},
};

const auto checksumAlgorithm = request.GetChecksumAlgorithm();
const auto mapping = std::find_if(std::begin(crtChecksumMapping),
std::end(crtChecksumMapping),
[&checksumAlgorithm](const std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm>& mapping){ return mapping.first == checksumAlgorithm; });
if (mapping != std::end(crtChecksumMapping))
{
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig{Aws::New<struct aws_s3_checksum_config>(ALLOCATION_TAG)};
checksumConfig->checksum_algorithm = mapping->second;
checksumConfig->location = AWS_SCL_TRAILER;
userData->checksumConfig = std::move(checksumConfig);
options.checksum_config = userData->checksumConfig.get();
}
else
{
AWS_LOGSTREAM_ERROR("CopyObject", "Could not set CRT checksum for " << ChecksumAlgorithmMapper::GetNameForChecksumAlgorithm(checksumAlgorithm));
}
}

std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest = userData->request->ToCrtHttpRequest();
options.message= crtHttpRequest->GetUnderlyingMessage();
userData->crtHttpRequest = crtHttpRequest;
@@ -1030,6 +1061,37 @@ void S3CrtClient::PutObjectAsync(const PutObjectRequest& request, const PutObjec
}
options.signing_config = &signing_config_override;

const auto headers = request.GetHeaders();
const auto checksumHeader = std::find_if(headers.begin(),
headers.end(),
[](const Aws::Http::HeaderValuePair& header) -> bool { return header.first.find("x-amz-checksum-") != Aws::String::npos; });
if (request.ChecksumAlgorithmHasBeenSet() && checksumHeader == headers.end())
{
static std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm> crtChecksumMapping[]{
{ChecksumAlgorithm::CRC32, aws_s3_checksum_algorithm::AWS_SCA_CRC32},
{ChecksumAlgorithm::CRC32C, aws_s3_checksum_algorithm::AWS_SCA_CRC32C},
{ChecksumAlgorithm::SHA1, aws_s3_checksum_algorithm::AWS_SCA_SHA1},
{ChecksumAlgorithm::SHA256, aws_s3_checksum_algorithm::AWS_SCA_SHA256},
};

const auto checksumAlgorithm = request.GetChecksumAlgorithm();
const auto mapping = std::find_if(std::begin(crtChecksumMapping),
std::end(crtChecksumMapping),
[&checksumAlgorithm](const std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm>& mapping){ return mapping.first == checksumAlgorithm; });
if (mapping != std::end(crtChecksumMapping))
{
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig{Aws::New<struct aws_s3_checksum_config>(ALLOCATION_TAG)};
checksumConfig->checksum_algorithm = mapping->second;
checksumConfig->location = AWS_SCL_TRAILER;
userData->checksumConfig = std::move(checksumConfig);
options.checksum_config = userData->checksumConfig.get();
}
else
{
AWS_LOGSTREAM_ERROR("PutObject", "Could not set CRT checksum for " << ChecksumAlgorithmMapper::GetNameForChecksumAlgorithm(checksumAlgorithm));
}
}

std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest = userData->request->ToCrtHttpRequest();
options.message= crtHttpRequest->GetUnderlyingMessage();
userData->crtHttpRequest = crtHttpRequest;
Original file line number Diff line number Diff line change
@@ -10,24 +10,11 @@ using namespace Aws::S3Crt;

static const char* ALLOC_TAG = "S3CrtIdentityProviderAdapter";

aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactory(struct aws_allocator *allocator,
struct aws_s3_client *client,
aws_simple_completion_callback *on_provider_shutdown_callback,
void *shutdown_user_data,
void *factory_user_data)
S3CrtIdentityProviderUserData::S3CrtIdentityProviderUserData(std::shared_ptr<S3ExpressIdentityProvider> identity_provider):
m_identityProvider(identity_provider),
m_impl(Aws::MakeUnique<struct aws_s3express_credentials_provider_vtable>(ALLOC_TAG))
{
// We use our own client in our internal implementation.
AWS_UNREFERENCED_PARAM(client);
struct aws_s3express_credentials_provider* provider = nullptr;
provider = static_cast<aws_s3express_credentials_provider *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider)));

auto userData = static_cast<S3ExpressIdentityProvider*>(factory_user_data);
struct aws_s3express_credentials_provider_vtable* s_cpp_s3express_table =
static_cast<aws_s3express_credentials_provider_vtable *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider_vtable)));

s_cpp_s3express_table->get_credentials = [](struct aws_s3express_credentials_provider* provider,
m_impl->get_credentials = [](struct aws_s3express_credentials_provider* provider,
const struct aws_credentials* original_credentials,
const struct aws_credentials_properties_s3express* s3express_properties,
aws_on_get_credentials_callback_fn* callback,
@@ -38,8 +25,11 @@ aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactor

//Figure out service specific params
Aws::Map<Aws::String, Aws::String> params;
struct aws_string *str = aws_string_new_from_cursor(get_aws_allocator(), &s3express_properties->host);
Aws::String hostname(aws_string_c_str(str));
Aws::UniquePtr<aws_string, std::function<void (aws_string*)>> hostnameCStr{
aws_string_new_from_cursor(get_aws_allocator(), &s3express_properties->host),
aws_string_destroy
};
Aws::String hostname(aws_string_c_str(hostnameCStr.get()));
// This requires the hostname be virtually addressed and will fail if not. In theory express
// hostname at this point in theory will always be this way for express hosts.
auto bucketName = hostname.substr(0, hostname.find('.'));
@@ -59,28 +49,45 @@ aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactor
session_token_cursor = aws_byte_cursor_from_c_str(creds.getSessionToken().c_str());
}

struct aws_credentials* credentials = aws_credentials_new(get_aws_allocator(),
Aws::UniquePtr<aws_credentials, std::function<void (aws_credentials*)>> credentials{
aws_credentials_new(get_aws_allocator(),
access_key_id_cursor,
secret_access_key_cursor,
session_token_cursor,
creds.getExpiration().Seconds());
creds.getExpiration().Seconds()),
aws_credentials_release
};

callback(credentials, AWS_OP_SUCCESS, user_data);
callback(credentials.get(), AWS_OP_SUCCESS, user_data);
return AWS_OP_SUCCESS;
};

s_cpp_s3express_table->destroy = [](struct aws_s3express_credentials_provider* provider) -> void
m_impl->destroy = [](struct aws_s3express_credentials_provider* provider) -> void
{
aws_simple_completion_callback *callback = provider->shutdown_complete_callback;
void *user_data = provider->shutdown_user_data;
aws_mem_release(provider->allocator, provider);
callback(user_data);
};
}

aws_s3express_credentials_provider* S3CrtIdentityProviderAdapter::ProviderFactory(struct aws_allocator* allocator,
struct aws_s3_client* client,
aws_simple_completion_callback* on_provider_shutdown_callback,
void* shutdown_user_data,
void* factory_user_data)
{
// We use our own client in our internal implementation.
AWS_UNREFERENCED_PARAM(client);
struct aws_s3express_credentials_provider* provider = nullptr;
provider = static_cast<aws_s3express_credentials_provider *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider)));

auto userData = static_cast<S3CrtIdentityProviderUserData*>(factory_user_data);
aws_s3express_credentials_provider_init_base(provider,
allocator,
s_cpp_s3express_table,
userData);
userData->GetImpl().get(),
userData->GetIdentityProvider().get());
provider->shutdown_complete_callback = on_provider_shutdown_callback;
provider->shutdown_user_data = shutdown_user_data;
return provider;
Original file line number Diff line number Diff line change
@@ -551,15 +551,6 @@ namespace
SCOPED_TRACE(Aws::String("FullBucketName ") + fullBucketName);
CreateBucketRequest createBucketRequest;
createBucketRequest.SetBucket(fullBucketName);
createBucketRequest.SetACL(BucketCannedACL::private_);
{
CreateBucketConfiguration bucketConfiguration;
Aws::S3Crt::ClientConfiguration dummyClientConfig;
bucketConfiguration.SetLocationConstraint(
BucketLocationConstraintMapper::GetBucketLocationConstraintForName(dummyClientConfig.region));
createBucketRequest.SetCreateBucketConfiguration(bucketConfiguration);
}

CreateBucketOutcome createBucketOutcome = Client->CreateBucket(createBucketRequest);
AWS_ASSERT_SUCCESS(createBucketOutcome);
const CreateBucketResult& createBucketResult = createBucketOutcome.GetResult();
@@ -1383,7 +1374,6 @@ namespace
GetObjectOutcome outcome = Client->GetObject(getObjectRequest);

ASSERT_FALSE(outcome.IsSuccess());
ASSERT_EQ(outcome.GetError().GetErrorType(), Aws::S3Crt::S3CrtErrors::NETWORK_CONNECTION);
}

TEST_F(BucketAndObjectOperationTest, MissingCertificate) {
89 changes: 52 additions & 37 deletions tests/aws-cpp-sdk-s3-crt-integration-tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,37 +1,52 @@
add_project(aws-cpp-sdk-s3-crt-integration-tests
"Tests for the AWS S3 CRT C++ SDK"
aws-cpp-sdk-s3-crt
testing-resources
aws-cpp-sdk-core)

file(GLOB AWS_S3_CRT_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)

file(GLOB AWS_S3_CRT_INTEGRATION_TESTS_SRC
${AWS_S3_CRT_SRC}
)

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
else()
add_executable(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

if(MSVC AND BUILD_SHARED_LIBS)
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-s3-crt.dll /DELAYLOAD:aws-cpp-sdk-core.dll")
target_link_libraries(${PROJECT_NAME} delayimp.lib)
endif()
set(TEST_LIST "aws-cpp-sdk-s3-crt-integration-tests:RunTests.cpp")

# Only run memlimiter test if we are building with custom memory management
if (CUSTOM_MEMORY_MANAGEMENT)
list(APPEND TEST_LIST "aws-cpp-sdk-s3-crt-memory-checked-integration-tests:RunTestsWithMemTracer.cpp")
endif ()

foreach(TEST IN LISTS TEST_LIST)
string(REPLACE ":" ";" TEST_ITEMS "${TEST}")
list(GET TEST_ITEMS 0 TEST_PROJECT_NAME)
list(GET TEST_ITEMS 1 TEST_MAIN_FILE)

add_project("${TEST_PROJECT_NAME}"
"Tests for the AWS S3 CRT C++ SDK"
aws-cpp-sdk-s3-crt
testing-resources
aws-cpp-sdk-core)

file(GLOB AWS_S3_CRT_SRC
"${TEST_MAIN_FILE}"
"BucketAndObjectOperationTest.cpp"
"S3ExpressTest.cpp"
)

file(GLOB AWS_S3_CRT_INTEGRATION_TESTS_SRC
${AWS_S3_CRT_SRC}
)

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
else()
add_executable(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

if(MSVC AND BUILD_SHARED_LIBS)
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-s3-crt.dll /DELAYLOAD:aws-cpp-sdk-core.dll")
target_link_libraries(${PROJECT_NAME} delayimp.lib)
endif()
endforeach ()
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <gtest/gtest.h>
#include <aws/core/Aws.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/testing/MemoryTesting.h>

int main(int argc, char** argv) {
Aws::Testing::SetDefaultSigPipeHandler();
Aws::SDKOptions options;
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace;
CRTMemTracerMemorySystem memorySystem{};
options.memoryManagementOptions.memoryManager = &memorySystem;
Aws::Testing::InitPlatformTest(options);
Aws::Testing::ParseArgs(argc, argv);

Aws::InitAPI(options);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Begin() goes here and End() right after "ShutdownAPI"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is copy and pasted more or less from what we already do and how we use it in the crt test already

::testing::InitGoogleTest(&argc, argv);
int exitCode = RUN_ALL_TESTS();

Aws::ShutdownAPI(options);
memorySystem.AssertNoLeaks();
Aws::Testing::ShutdownPlatformTest(options);
return exitCode;
}
19 changes: 19 additions & 0 deletions tests/aws-cpp-sdk-s3-crt-integration-tests/S3ExpressTest.cpp
Original file line number Diff line number Diff line change
@@ -479,6 +479,25 @@ namespace {
}
}

TEST_F(S3ExpressTest, PutObjectChecksumWithoutAlgorithmValue) {
const auto bucketName = Testing::GetAwsResourcePrefix() + randomString() + S3_EXPRESS_SUFFIX;
const auto createOutcome = CreateBucket(bucketName);
AWS_EXPECT_SUCCESS(createOutcome);

auto request = PutObjectRequest()
.WithBucket(bucketName)
.WithKey("swingingparty")
.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32);

std::shared_ptr<IOStream> body = Aws::MakeShared<StringStream>(ALLOCATION_TAG,
"Bring your own lampshade, somewhere there's a party.",
std::ios_base::in | std::ios_base::binary);
request.SetBody(body);

const auto response = client->PutObject(request);
AWS_EXPECT_SUCCESS(response);
}

TEST_F(S3ExpressTest, PutObjectChecksumWithoutAlgorithm) {
const auto bucketName = Testing::GetAwsResourcePrefix() + randomString() + S3_EXPRESS_SUFFIX;
const auto createOutcome = CreateBucket(bucketName);
Loading
Loading