diff --git a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h index 4b90bb01ec7..524554bf14a 100644 --- a/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h +++ b/src/aws-cpp-sdk-identity-management/include/aws/identity-management/auth/STSProfileCredentialsProvider.h @@ -52,8 +52,47 @@ namespace Aws */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration = std::chrono::minutes(60)); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * Using the overload where the function returns a shared_ptr is preferred. + * + */ STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory); + /** + * Use the provided profile name from the shared configuration file and a custom STS client. + * + * @param profileName The name of the profile in the shared configuration file. + * @param duration The duration, in minutes, of the role session, after which the credentials are expired. + * The value can range from 15 minutes up to the maximum session duration setting for the role. By default, + * the duration is set to 1 hour. + * Note: This credential provider refreshes the credentials 5 minutes before their expiration time. That + * ensures the credentials do not expire between the time they're checked and the time they're returned to + * the user. + * If the duration for the credentials is 5 minutes or less, the provider will refresh the credentials only + * when they expire. + * @param stsClientFactory A factory function that creates an STSClient with specific credentials. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function(const AWSCredentials&)> &stsClientFactory); + + /** + * Compatibility constructor to assist with overload resolution when passing nullptr for the client factory. + * + */ + STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t); + /** * Fetches the credentials set from STS following the rules defined in the shared configuration file. */ @@ -74,7 +113,7 @@ namespace Aws AWSCredentials m_credentials; const std::chrono::minutes m_duration; const std::chrono::milliseconds m_reloadFrequency; - std::function m_stsClientFactory; + std::function(const AWSCredentials&)> m_stsClientFactory; }; } } diff --git a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp index fd82b678fba..a362eccd541 100644 --- a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp @@ -17,6 +17,12 @@ using namespace Aws::Auth; constexpr char CLASS_TAG[] = "STSProfileCredentialsProvider"; +template +struct NoOpDeleter +{ + void operator()(T*) {} +}; + STSProfileCredentialsProvider::STSProfileCredentialsProvider() : STSProfileCredentialsProvider(GetConfigProfileName(), std::chrono::minutes(60)/*duration*/, nullptr/*stsClientFactory*/) { @@ -27,8 +33,24 @@ STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& { } +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, std::nullptr_t) + : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory(nullptr) +{ +} + STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function &stsClientFactory) : m_profileName(profileName), + m_duration(duration), + m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), + m_stsClientFactory([=](const auto& credentials) {return std::shared_ptr(stsClientFactory(credentials), NoOpDeleter()); }) +{ +} + +STSProfileCredentialsProvider::STSProfileCredentialsProvider(const Aws::String& profileName, std::chrono::minutes duration, const std::function (const AWSCredentials&)>& stsClientFactory) + : m_profileName(profileName), m_duration(duration), m_reloadFrequency(std::chrono::minutes(std::max(int64_t(5), static_cast(duration.count()))) - std::chrono::minutes(5)), m_stsClientFactory(stsClientFactory) @@ -337,7 +359,8 @@ AWSCredentials STSProfileCredentialsProvider::GetCredentialsFromSTS(const AWSCre { using namespace Aws::STS::Model; if (m_stsClientFactory) { - return GetCredentialsFromSTSInternal(roleArn, m_stsClientFactory(credentials)); + auto client = m_stsClientFactory(credentials); + return GetCredentialsFromSTSInternal(roleArn, client.get()); } Aws::STS::STSClient stsClient {credentials}; diff --git a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp index 197535a6a2e..cf234107db0 100644 --- a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp @@ -313,7 +313,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutRoleARN) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -383,7 +383,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithoutSourceProfile) STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -409,7 +409,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleWithNonExistentSourceProfile STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -556,7 +556,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile Model::AssumeRoleResult mockResult; mockResult.SetCredentials(stsCredentials); - Aws::UniquePtr stsClient; + std::shared_ptr stsClient; int stsCallCounter = 0; @@ -572,9 +572,9 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencingSourceProfile EXPECT_STREQ(ACCESS_KEY_ID_2, creds.GetAWSAccessKeyId().c_str()); EXPECT_STREQ(SECRET_ACCESS_KEY_ID_2, creds.GetAWSSecretKey().c_str()); } - stsClient = Aws::MakeUnique(CLASS_TAG, creds); + stsClient = Aws::MakeShared(CLASS_TAG, creds); stsClient->MockAssumeRole(mockResult); - return stsClient.get(); + return stsClient; }); auto actualCredentials = credsProvider.GetAWSCredentials(); @@ -614,7 +614,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference STSProfileCredentialsProvider credsProvider("default", roleSessionDuration, [](const AWSCredentials&) { ADD_FAILURE() << "STS Service client should not be used in this scenario."; - return nullptr; + return (STSClient*)nullptr; }); auto actualCredentials = credsProvider.GetAWSCredentials();