Skip to content

Commit

Permalink
Simplify awsconfig loading (#50809)
Browse files Browse the repository at this point in the history
This replaces awsconfig.WithIntegrationCredentialProvider option with the
awsconfig.WithOIDCIntegrationClient option.
This solves a chicken/egg problem with AWS config loading - callers no
longer need to load AWS config (to create a credential provider) to load
AWS config.
The OIDCIntegrationClient interface is also much simpler to implement.

This also adds default option overrides when creating an awsconfig.Cache.
For now, this is used to add an OIDCIntegrationClient when creating the
cache so that dependent callers don't have to.
  • Loading branch information
GavinFrazar authored Jan 8, 2025
1 parent 62dbb2c commit 20828a2
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 172 deletions.
129 changes: 97 additions & 32 deletions lib/cloud/awsconfig/awsconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/gravitational/trace"
"go.opentelemetry.io/otel"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/modules"
)

Expand All @@ -43,12 +44,25 @@ const (
credentialsSourceIntegration
)

// IntegrationSessionProviderFunc defines a function that creates a credential provider from a region and an integration.
// This is used to generate aws configs for clients that must use an integration instead of ambient credentials.
type IntegrationCredentialProviderFunc func(ctx context.Context, region, integration string) (aws.CredentialsProvider, error)
// OIDCIntegrationClient is an interface that indicates which APIs are
// required to generate an AWS OIDC integration token.
type OIDCIntegrationClient interface {
// GetIntegration returns the specified integration resource.
GetIntegration(ctx context.Context, name string) (types.Integration, error)

// AssumeRoleClientProviderFunc provides an AWS STS assume role API client.
type AssumeRoleClientProviderFunc func(aws.Config) stscreds.AssumeRoleAPIClient
// GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC
// Integration action.
GenerateAWSOIDCToken(ctx context.Context, integrationName string) (string, error)
}

// STSClient is a subset of the AWS STS API.
type STSClient interface {
stscreds.AssumeRoleAPIClient
stscreds.AssumeRoleWithWebIdentityAPIClient
}

// STSClientProviderFunc provides an AWS STS assume role API client.
type STSClientProviderFunc func(aws.Config) STSClient

// AssumeRole is an AWS role to assume, optionally with an external ID.
type AssumeRole struct {
Expand All @@ -68,14 +82,16 @@ type options struct {
credentialsSource credentialsSource
// integration is the name of the integration to be used to fetch the credentials.
integration string
// integrationCredentialsProvider is the integration credential provider to use.
integrationCredentialsProvider IntegrationCredentialProviderFunc
// oidcIntegrationClient provides APIs to generate AWS OIDC tokens, which
// can then be exchanged for IAM credentials.
// Required if integration credentials are requested.
oidcIntegrationClient OIDCIntegrationClient
// customRetryer is a custom retryer to use for the config.
customRetryer func() aws.Retryer
// maxRetries is the maximum number of retries to use for the config.
maxRetries *int
// assumeRoleClientProvider sets the STS assume role client provider func.
assumeRoleClientProvider AssumeRoleClientProviderFunc
// stsClientProvider sets the STS assume role client provider func.
stsClientProvider STSClientProviderFunc
}

func buildOptions(optFns ...OptionsFn) (*options, error) {
Expand All @@ -99,15 +115,18 @@ func (o *options) checkAndSetDefaults() error {
if o.integration == "" {
return trace.BadParameter("missing integration name")
}
if o.oidcIntegrationClient == nil {
return trace.BadParameter("missing AWS OIDC integration client")
}
default:
return trace.BadParameter("missing credentials source (ambient or integration)")
}
if len(o.assumeRoles) > 2 {
return trace.BadParameter("role chain contains more than 2 roles")
}

if o.assumeRoleClientProvider == nil {
o.assumeRoleClientProvider = func(cfg aws.Config) stscreds.AssumeRoleAPIClient {
if o.stsClientProvider == nil {
o.stsClientProvider = func(cfg aws.Config) STSClient {
return sts.NewFromConfig(cfg, func(o *sts.Options) {
o.TracerProvider = smithyoteltracing.Adapt(otel.GetTracerProvider())
})
Expand Down Expand Up @@ -175,18 +194,17 @@ func WithAmbientCredentials() OptionsFn {
}
}

// WithIntegrationCredentialProvider sets the integration credential provider.
func WithIntegrationCredentialProvider(cred IntegrationCredentialProviderFunc) OptionsFn {
// WithSTSClientProvider sets the STS API client factory func.
func WithSTSClientProvider(fn STSClientProviderFunc) OptionsFn {
return func(options *options) {
options.integrationCredentialsProvider = cred
options.stsClientProvider = fn
}
}

// WithAssumeRoleClientProviderFunc sets the STS API client factory func used to
// assume roles.
func WithAssumeRoleClientProviderFunc(fn AssumeRoleClientProviderFunc) OptionsFn {
// WithOIDCIntegrationClient sets the OIDC integration client.
func WithOIDCIntegrationClient(c OIDCIntegrationClient) OptionsFn {
return func(options *options) {
options.assumeRoleClientProvider = fn
options.oidcIntegrationClient = c
}
}

Expand All @@ -202,7 +220,7 @@ func GetConfig(ctx context.Context, region string, optFns ...OptionsFn) (aws.Con
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.assumeRoleClientProvider)
return getConfigForRoleChain(ctx, cfg, opts.assumeRoles, opts.stsClientProvider)
}

// loadDefaultConfig loads a new config.
Expand All @@ -217,6 +235,7 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio
config.WithDefaultRegion(defaultRegion),
config.WithRegion(region),
config.WithCredentialsProvider(cred),
config.WithCredentialsCacheOptions(awsCredentialsCacheOptions),
}
if modules.GetModules().IsBoringBinary() {
configOpts = append(configOpts, config.WithUseFIPSEndpoint(aws.FIPSEndpointStateEnabled))
Expand All @@ -232,27 +251,35 @@ func buildConfigOptions(region string, cred aws.CredentialsProvider, opts *optio

// getBaseConfig returns an AWS config without assuming any roles.
func getBaseConfig(ctx context.Context, region string, opts *options) (aws.Config, error) {
var cred aws.CredentialsProvider
slog.DebugContext(ctx, "Initializing AWS config from default credential chain",
"region", region,
)
cfg, err := loadDefaultConfig(ctx, region, nil, opts)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}

if opts.credentialsSource == credentialsSourceIntegration {
if opts.integrationCredentialsProvider == nil {
return aws.Config{}, trace.BadParameter("missing aws integration credential provider")
slog.DebugContext(ctx, "Initializing AWS config with OIDC integration credentials",
"region", region,
"integration", opts.integration,
)
provider := &integrationCredentialsProvider{
OIDCIntegrationClient: opts.oidcIntegrationClient,
stsClt: opts.stsClientProvider(cfg),
integrationName: opts.integration,
}

slog.DebugContext(ctx, "Initializing AWS config with integration", "region", region, "integration", opts.integration)
var err error
cred, err = opts.integrationCredentialsProvider(ctx, region, opts.integration)
cc := aws.NewCredentialsCache(provider, awsCredentialsCacheOptions)
_, err := cc.Retrieve(ctx)
if err != nil {
return aws.Config{}, trace.Wrap(err)
}
} else {
slog.DebugContext(ctx, "Initializing AWS config from default credential chain", "region", region)
cfg.Credentials = cc
}

cfg, err := loadDefaultConfig(ctx, region, cred, opts)
return cfg, trace.Wrap(err)
return cfg, nil
}

func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn AssumeRoleClientProviderFunc) (aws.Config, error) {
func getConfigForRoleChain(ctx context.Context, cfg aws.Config, roles []AssumeRole, newCltFn STSClientProviderFunc) (aws.Config, error) {
for _, r := range roles {
cfg.Credentials = getAssumeRoleProvider(ctx, newCltFn(cfg), r)
}
Expand All @@ -277,3 +304,41 @@ func getAssumeRoleProvider(ctx context.Context, clt stscreds.AssumeRoleAPIClient
}
})
}

// staticIdentityToken provides itself as a JWT []byte token to implement
// [stscreds.IdentityTokenRetriever].
type staticIdentityToken string

// GetIdentityToken retrieves the JWT token.
func (t staticIdentityToken) GetIdentityToken() ([]byte, error) {
return []byte(t), nil
}

// integrationCredentialsProvider provides AWS OIDC integration credentials.
type integrationCredentialsProvider struct {
OIDCIntegrationClient
stsClt STSClient
integrationName string
}

// Retrieve provides [aws.Credentials] for an AWS OIDC integration.
func (p *integrationCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
integration, err := p.GetIntegration(ctx, p.integrationName)
if err != nil {
return aws.Credentials{}, trace.Wrap(err)
}
spec := integration.GetAWSOIDCIntegrationSpec()
if spec == nil {
return aws.Credentials{}, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind())
}
token, err := p.GenerateAWSOIDCToken(ctx, p.integrationName)
if err != nil {
return aws.Credentials{}, trace.Wrap(err)
}
cred, err := stscreds.NewWebIdentityRoleProvider(
p.stsClt,
spec.RoleARN,
staticIdentityToken(token),
).Retrieve(ctx)
return cred, trace.Wrap(err)
}
Loading

0 comments on commit 20828a2

Please sign in to comment.