Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix s3 crt checksum calc #3075

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <aws/core/utils/DNS.h>
#include <aws/s3-crt/S3CrtServiceClientModel.h>
#include <aws/s3-crt/S3ExpressIdentityProvider.h>
#include <aws/s3-crt/S3CrtIdentityProviderAdapter.h>

struct aws_s3_client;
// TODO: temporary fix for naming conflicts on Windows.
Expand Down Expand Up @@ -6275,6 +6276,7 @@ namespace Aws
std::shared_ptr<Aws::Http::HttpRequest> request;
std::shared_ptr<Aws::Http::HttpResponse> response;
std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest;
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig;
};

Aws::Client::XmlOutcome GenerateXmlOutcome(const std::shared_ptr<Http::HttpResponse>& response) const;
Expand Down Expand Up @@ -6316,6 +6318,7 @@ namespace Aws
std::shared_ptr<Aws::Crt::Auth::ICredentialsProvider> m_crtCredProvider;
std::shared_ptr<S3CrtEndpointProviderBase> m_endpointProvider;
std::shared_ptr<S3ExpressIdentityProvider> m_identityProvider;
S3CrtIdentityProviderUserData m_identityProviderUserData{m_identityProvider};
};

} // namespace S3Crt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@

namespace Aws {
namespace S3Crt {
/**
* Userdata class for ensuring lifetime of the identity provider and
* implementation pointers for crt.
*/
class S3CrtIdentityProviderUserData final {
public:
explicit S3CrtIdentityProviderUserData(std::shared_ptr<S3ExpressIdentityProvider> identity_provider);
std::shared_ptr<S3ExpressIdentityProvider> GetIdentityProvider() const { return m_identityProvider; }
std::shared_ptr<aws_s3express_credentials_provider_vtable> GetImpl() const { return m_impl; }

private:
std::shared_ptr<S3ExpressIdentityProvider> m_identityProvider;
std::shared_ptr<struct aws_s3express_credentials_provider_vtable> m_impl;
};

/**
* Factory for a CRT aws_s3express_credentials_provider. Cannot subclass or instantiate,
* only for building a crt provider to be used in crt configuration.
Expand Down
64 changes: 63 additions & 1 deletion generated/src/aws-cpp-sdk-s3-crt/source/S3CrtClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ void S3CrtClient::init(const S3Crt::ClientConfiguration& config,
s3CrtConfig.shutdown_callback = CrtClientShutdownCallback;
s3CrtConfig.shutdown_callback_user_data = static_cast<void*>(&m_wrappedData);
s3CrtConfig.enable_s3express = !config.disableS3ExpressAuth;
s3CrtConfig.factory_user_data = static_cast<void *>(m_identityProvider.get());
s3CrtConfig.factory_user_data = static_cast<void *>(&m_identityProviderUserData);
s3CrtConfig.s3express_provider_override_factory = S3CrtIdentityProviderAdapter::ProviderFactory;

m_s3CrtClient = aws_s3_client_new(Aws::get_aws_allocator(), &s3CrtConfig);
Expand Down Expand Up @@ -788,6 +788,37 @@ void S3CrtClient::CopyObjectAsync(const CopyObjectRequest& request, const CopyOb
}
options.signing_config = &signing_config_override;

const auto headers = request.GetHeaders();
const auto checksumHeader = std::find_if(headers.begin(),
headers.end(),
[](const Aws::Http::HeaderValuePair& header) -> bool { return header.first.find("x-amz-checksum-") != Aws::String::npos; });
if (request.ChecksumAlgorithmHasBeenSet() && checksumHeader == headers.end())
{
static std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm> crtChecksumMapping[]{
{ChecksumAlgorithm::CRC32, aws_s3_checksum_algorithm::AWS_SCA_CRC32},
{ChecksumAlgorithm::CRC32C, aws_s3_checksum_algorithm::AWS_SCA_CRC32C},
{ChecksumAlgorithm::SHA1, aws_s3_checksum_algorithm::AWS_SCA_SHA1},
{ChecksumAlgorithm::SHA256, aws_s3_checksum_algorithm::AWS_SCA_SHA256},
};

const auto checksumAlgorithm = request.GetChecksumAlgorithm();
const auto mapping = std::find_if(std::begin(crtChecksumMapping),
std::end(crtChecksumMapping),
[&checksumAlgorithm](const std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm>& mapping){ return mapping.first == checksumAlgorithm; });
if (mapping != std::end(crtChecksumMapping))
sbiscigl marked this conversation as resolved.
Show resolved Hide resolved
{
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig{Aws::New<struct aws_s3_checksum_config>(ALLOCATION_TAG)};
checksumConfig->checksum_algorithm = mapping->second;
checksumConfig->location = AWS_SCL_TRAILER;
userData->checksumConfig = std::move(checksumConfig);
options.checksum_config = userData->checksumConfig.get();
}
else
{
AWS_LOGSTREAM_ERROR("CopyObject", "Could not set CRT checksum for " << ChecksumAlgorithmMapper::GetNameForChecksumAlgorithm(checksumAlgorithm));
}
}

std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest = userData->request->ToCrtHttpRequest();
options.message= crtHttpRequest->GetUnderlyingMessage();
userData->crtHttpRequest = crtHttpRequest;
Expand Down Expand Up @@ -1030,6 +1061,37 @@ void S3CrtClient::PutObjectAsync(const PutObjectRequest& request, const PutObjec
}
options.signing_config = &signing_config_override;

const auto headers = request.GetHeaders();
const auto checksumHeader = std::find_if(headers.begin(),
headers.end(),
[](const Aws::Http::HeaderValuePair& header) -> bool { return header.first.find("x-amz-checksum-") != Aws::String::npos; });
if (request.ChecksumAlgorithmHasBeenSet() && checksumHeader == headers.end())
{
static std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm> crtChecksumMapping[]{
{ChecksumAlgorithm::CRC32, aws_s3_checksum_algorithm::AWS_SCA_CRC32},
{ChecksumAlgorithm::CRC32C, aws_s3_checksum_algorithm::AWS_SCA_CRC32C},
{ChecksumAlgorithm::SHA1, aws_s3_checksum_algorithm::AWS_SCA_SHA1},
{ChecksumAlgorithm::SHA256, aws_s3_checksum_algorithm::AWS_SCA_SHA256},
};

const auto checksumAlgorithm = request.GetChecksumAlgorithm();
const auto mapping = std::find_if(std::begin(crtChecksumMapping),
std::end(crtChecksumMapping),
[&checksumAlgorithm](const std::pair<ChecksumAlgorithm, aws_s3_checksum_algorithm>& mapping){ return mapping.first == checksumAlgorithm; });
if (mapping != std::end(crtChecksumMapping))
{
Aws::UniquePtr<struct aws_s3_checksum_config> checksumConfig{Aws::New<struct aws_s3_checksum_config>(ALLOCATION_TAG)};
checksumConfig->checksum_algorithm = mapping->second;
checksumConfig->location = AWS_SCL_TRAILER;
userData->checksumConfig = std::move(checksumConfig);
options.checksum_config = userData->checksumConfig.get();
}
else
{
AWS_LOGSTREAM_ERROR("PutObject", "Could not set CRT checksum for " << ChecksumAlgorithmMapper::GetNameForChecksumAlgorithm(checksumAlgorithm));
}
}

std::shared_ptr<Aws::Crt::Http::HttpRequest> crtHttpRequest = userData->request->ToCrtHttpRequest();
options.message= crtHttpRequest->GetUnderlyingMessage();
userData->crtHttpRequest = crtHttpRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,11 @@ using namespace Aws::S3Crt;

static const char* ALLOC_TAG = "S3CrtIdentityProviderAdapter";

aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactory(struct aws_allocator *allocator,
struct aws_s3_client *client,
aws_simple_completion_callback *on_provider_shutdown_callback,
void *shutdown_user_data,
void *factory_user_data)
S3CrtIdentityProviderUserData::S3CrtIdentityProviderUserData(std::shared_ptr<S3ExpressIdentityProvider> identity_provider):
m_identityProvider(identity_provider),
m_impl(Aws::MakeUnique<struct aws_s3express_credentials_provider_vtable>(ALLOC_TAG))
{
// We use our own client in our internal implementation.
AWS_UNREFERENCED_PARAM(client);
struct aws_s3express_credentials_provider* provider = nullptr;
provider = static_cast<aws_s3express_credentials_provider *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider)));

auto userData = static_cast<S3ExpressIdentityProvider*>(factory_user_data);
struct aws_s3express_credentials_provider_vtable* s_cpp_s3express_table =
static_cast<aws_s3express_credentials_provider_vtable *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider_vtable)));

s_cpp_s3express_table->get_credentials = [](struct aws_s3express_credentials_provider* provider,
m_impl->get_credentials = [](struct aws_s3express_credentials_provider* provider,
const struct aws_credentials* original_credentials,
const struct aws_credentials_properties_s3express* s3express_properties,
aws_on_get_credentials_callback_fn* callback,
Expand All @@ -38,8 +25,11 @@ aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactor

//Figure out service specific params
Aws::Map<Aws::String, Aws::String> params;
struct aws_string *str = aws_string_new_from_cursor(get_aws_allocator(), &s3express_properties->host);
Aws::String hostname(aws_string_c_str(str));
Aws::UniquePtr<aws_string, std::function<void (aws_string*)>> hostnameCStr{
aws_string_new_from_cursor(get_aws_allocator(), &s3express_properties->host),
aws_string_destroy
};
Aws::String hostname(aws_string_c_str(hostnameCStr.get()));
// This requires the hostname be virtually addressed and will fail if not. In theory express
// hostname at this point in theory will always be this way for express hosts.
auto bucketName = hostname.substr(0, hostname.find('.'));
Expand All @@ -59,28 +49,45 @@ aws_s3express_credentials_provider *S3CrtIdentityProviderAdapter::ProviderFactor
session_token_cursor = aws_byte_cursor_from_c_str(creds.getSessionToken().c_str());
}

struct aws_credentials* credentials = aws_credentials_new(get_aws_allocator(),
Aws::UniquePtr<aws_credentials, std::function<void (aws_credentials*)>> credentials{
aws_credentials_new(get_aws_allocator(),
access_key_id_cursor,
secret_access_key_cursor,
session_token_cursor,
creds.getExpiration().Seconds());
creds.getExpiration().Seconds()),
aws_credentials_release
};

callback(credentials, AWS_OP_SUCCESS, user_data);
callback(credentials.get(), AWS_OP_SUCCESS, user_data);
return AWS_OP_SUCCESS;
};

s_cpp_s3express_table->destroy = [](struct aws_s3express_credentials_provider* provider) -> void
m_impl->destroy = [](struct aws_s3express_credentials_provider* provider) -> void
{
aws_simple_completion_callback *callback = provider->shutdown_complete_callback;
void *user_data = provider->shutdown_user_data;
aws_mem_release(provider->allocator, provider);
callback(user_data);
};
}

aws_s3express_credentials_provider* S3CrtIdentityProviderAdapter::ProviderFactory(struct aws_allocator* allocator,
struct aws_s3_client* client,
aws_simple_completion_callback* on_provider_shutdown_callback,
void* shutdown_user_data,
void* factory_user_data)
{
// We use our own client in our internal implementation.
AWS_UNREFERENCED_PARAM(client);
struct aws_s3express_credentials_provider* provider = nullptr;
provider = static_cast<aws_s3express_credentials_provider *>(aws_mem_calloc(allocator, 1,
sizeof(struct aws_s3express_credentials_provider)));

auto userData = static_cast<S3CrtIdentityProviderUserData*>(factory_user_data);
aws_s3express_credentials_provider_init_base(provider,
allocator,
s_cpp_s3express_table,
userData);
userData->GetImpl().get(),
userData->GetIdentityProvider().get());
provider->shutdown_complete_callback = on_provider_shutdown_callback;
provider->shutdown_user_data = shutdown_user_data;
return provider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,6 @@ namespace
SCOPED_TRACE(Aws::String("FullBucketName ") + fullBucketName);
CreateBucketRequest createBucketRequest;
createBucketRequest.SetBucket(fullBucketName);
createBucketRequest.SetACL(BucketCannedACL::private_);
{
CreateBucketConfiguration bucketConfiguration;
Aws::S3Crt::ClientConfiguration dummyClientConfig;
bucketConfiguration.SetLocationConstraint(
BucketLocationConstraintMapper::GetBucketLocationConstraintForName(dummyClientConfig.region));
createBucketRequest.SetCreateBucketConfiguration(bucketConfiguration);
}

CreateBucketOutcome createBucketOutcome = Client->CreateBucket(createBucketRequest);
AWS_ASSERT_SUCCESS(createBucketOutcome);
const CreateBucketResult& createBucketResult = createBucketOutcome.GetResult();
Expand Down Expand Up @@ -1383,7 +1374,6 @@ namespace
GetObjectOutcome outcome = Client->GetObject(getObjectRequest);

ASSERT_FALSE(outcome.IsSuccess());
ASSERT_EQ(outcome.GetError().GetErrorType(), Aws::S3Crt::S3CrtErrors::NETWORK_CONNECTION);
}

TEST_F(BucketAndObjectOperationTest, MissingCertificate) {
Expand Down
89 changes: 52 additions & 37 deletions tests/aws-cpp-sdk-s3-crt-integration-tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,37 +1,52 @@
add_project(aws-cpp-sdk-s3-crt-integration-tests
"Tests for the AWS S3 CRT C++ SDK"
aws-cpp-sdk-s3-crt
testing-resources
aws-cpp-sdk-core)

file(GLOB AWS_S3_CRT_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)

file(GLOB AWS_S3_CRT_INTEGRATION_TESTS_SRC
${AWS_S3_CRT_SRC}
)

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
else()
add_executable(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

if(MSVC AND BUILD_SHARED_LIBS)
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-s3-crt.dll /DELAYLOAD:aws-cpp-sdk-core.dll")
target_link_libraries(${PROJECT_NAME} delayimp.lib)
endif()
set(TEST_LIST "aws-cpp-sdk-s3-crt-integration-tests:RunTests.cpp")

# Only run memlimiter test if we are building with custom memory management
if (CUSTOM_MEMORY_MANAGEMENT)
list(APPEND TEST_LIST "aws-cpp-sdk-s3-crt-memory-checked-integration-tests:RunTestsWithMemTracer.cpp")
endif ()

foreach(TEST IN LISTS TEST_LIST)
string(REPLACE ":" ";" TEST_ITEMS "${TEST}")
list(GET TEST_ITEMS 0 TEST_PROJECT_NAME)
list(GET TEST_ITEMS 1 TEST_MAIN_FILE)

add_project("${TEST_PROJECT_NAME}"
"Tests for the AWS S3 CRT C++ SDK"
aws-cpp-sdk-s3-crt
testing-resources
aws-cpp-sdk-core)

file(GLOB AWS_S3_CRT_SRC
"${TEST_MAIN_FILE}"
"BucketAndObjectOperationTest.cpp"
"S3ExpressTest.cpp"
)

file(GLOB AWS_S3_CRT_INTEGRATION_TESTS_SRC
${AWS_S3_CRT_SRC}
)

add_definitions(-DRESOURCES_DIR="${CMAKE_CURRENT_SOURCE_DIR}/resources")

if(MSVC AND BUILD_SHARED_LIBS)
add_definitions(-DGTEST_LINKED_AS_SHARED_LIBRARY=1)
endif()

enable_testing()

if(PLATFORM_ANDROID AND BUILD_SHARED_LIBS)
add_library(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
else()
add_executable(${PROJECT_NAME} ${AWS_S3_CRT_INTEGRATION_TESTS_SRC})
endif()

set_compiler_flags(${PROJECT_NAME})
set_compiler_warnings(${PROJECT_NAME})

target_link_libraries(${PROJECT_NAME} ${PROJECT_LIBS})

if(MSVC AND BUILD_SHARED_LIBS)
set_target_properties(${PROJECT_NAME} PROPERTIES LINK_FLAGS "/DELAYLOAD:aws-cpp-sdk-s3-crt.dll /DELAYLOAD:aws-cpp-sdk-core.dll")
target_link_libraries(${PROJECT_NAME} delayimp.lib)
endif()
endforeach ()
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/

#include <gtest/gtest.h>
#include <aws/core/Aws.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/testing/TestingEnvironment.h>
#include <aws/testing/MemoryTesting.h>

int main(int argc, char** argv) {
Aws::Testing::SetDefaultSigPipeHandler();
Aws::SDKOptions options;
options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Trace;
CRTMemTracerMemorySystem memorySystem{};
options.memoryManagementOptions.memoryManager = &memorySystem;
Aws::Testing::InitPlatformTest(options);
Aws::Testing::ParseArgs(argc, argv);

Aws::InitAPI(options);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Begin() goes here and End() right after "ShutdownAPI"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is copy and pasted more or less from what we already do and how we use it in the crt test already

::testing::InitGoogleTest(&argc, argv);
int exitCode = RUN_ALL_TESTS();

Aws::ShutdownAPI(options);
memorySystem.AssertNoLeaks();
Aws::Testing::ShutdownPlatformTest(options);
return exitCode;
}
19 changes: 19 additions & 0 deletions tests/aws-cpp-sdk-s3-crt-integration-tests/S3ExpressTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,25 @@ namespace {
}
}

TEST_F(S3ExpressTest, PutObjectChecksumWithoutAlgorithmValue) {
const auto bucketName = Testing::GetAwsResourcePrefix() + randomString() + S3_EXPRESS_SUFFIX;
const auto createOutcome = CreateBucket(bucketName);
AWS_EXPECT_SUCCESS(createOutcome);

auto request = PutObjectRequest()
.WithBucket(bucketName)
.WithKey("swingingparty")
.WithChecksumAlgorithm(ChecksumAlgorithm::CRC32);

std::shared_ptr<IOStream> body = Aws::MakeShared<StringStream>(ALLOCATION_TAG,
"Bring your own lampshade, somewhere there's a party.",
std::ios_base::in | std::ios_base::binary);
request.SetBody(body);

const auto response = client->PutObject(request);
AWS_EXPECT_SUCCESS(response);
}

TEST_F(S3ExpressTest, PutObjectChecksumWithoutAlgorithm) {
const auto bucketName = Testing::GetAwsResourcePrefix() + randomString() + S3_EXPRESS_SUFFIX;
const auto createOutcome = CreateBucket(bucketName);
Expand Down
Loading
Loading