diff --git a/README.md b/README.md index c54692e..4ac1bd2 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,50 @@ $ aws-spiffe-workload-helper x509-credential-process \ | session-duration | No | The duration, in seconds, of the resulting session. Optional. Can range from 15 minutes (900) to 12 hours (43200). | `3600` | | workload-api-addr | No | Overrides the address of the Workload API endpoint that will be use to fetch the X509 SVID. If unspecified, the value from the SPIFFE_ENDPOINT_SOCKET environment variable will be used. | `unix:///opt/my/path/workload.sock` | +#### `x509-credential-file` -### Configuring AWS SDKs and CLIs +The `x509-credential-file` command starts a long-lived daemon which exchanges +an X509 SVID for a short-lived set of AWS credentials using the AWS Roles +Anywhere API. It writes the credentials to a specified file in the format +supported by AWS SDKs and CLIs as a "credential file". + +It repeats this exchange process when the AWS credentials are more than 50% of +the way through their lifetime, ensuring that a fresh set of credentials are +always available. + +Whilst the `x509-credentials-process` flow should be preferred as it does not +cause credentials to be written to the filesystem, the `x509-credentials-file` +flow may be useful in scenarios where you need to provide credentials to legacy +SDKs or CLIs that do not support the `credential_process` configuration. + +The command fetches the X509-SVID from the SPIFFE Workload API. The location of +the SPIFFE Workload API endpoint should be specified using the +`SPIFFE_ENDPOINT_SOCKET` environment variable or the `--workload-api-addr` flag. + +```sh +$ aws-spiffe-workload-helper x509-credential-file \ + --trust-anchor-arn arn:aws:rolesanywhere:us-east-1:123456789012:trust-anchor/0000000-0000-0000-0000-000000000000 \ + --profile-arn arn:aws:rolesanywhere:us-east-1:123456789012:profile/0000000-0000-0000-0000-000000000000 \ + --role-arn arn:aws:iam::123456789012:role/example-role \ + --workload-api-addr unix:///opt/workload-api.sock \ + --aws-credentials-file /opt/my-aws-credentials-file +``` + +###### Reference + +| Flag | Required | Description | Example | +|----------------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------| +| role-arn | Yes | The ARN of the role to assume. Required. | `arn:aws:iam::123456789012:role/example-role` | +| profile-arn | Yes | The ARN of the Roles Anywhere profile to use. Required. | `arn:aws:rolesanywhere:us-east-1:123456789012:profile/0000000-0000-0000-0000-00000000000` | +| trust-anchor-arn | Yes | The ARN of the Roles Anywhere trust anchor to use. Required. | `arn:aws:rolesanywhere:us-east-1:123456789012:trust-anchor/0000000-0000-0000-0000-000000000000` | +| region | No | Overrides AWS region to use when exchanging the SVID for AWS credentials. Optional. | `us-east-1` | +| session-duration | No | The duration, in seconds, of the resulting session. Optional. Can range from 15 minutes (900) to 12 hours (43200). | `3600` | +| workload-api-addr | No | Overrides the address of the Workload API endpoint that will be use to fetch the X509 SVID. If unspecified, the value from the SPIFFE_ENDPOINT_SOCKET environment variable will be used. | `unix:///opt/my/path/workload.sock` | +| aws-credentials-path | Yes | The path to the AWS credentials file to write. | `/opt/my-aws-credentials-file | +| force | No | If set, failures loading the existing AWS credentials file will be ignored and the contents overwritten. | | +| replace | No | If set, the AWS credentials file will be replaced if it exists. This will remove any profiles not written by this tool. | | + +## Configuring AWS SDKs and CLIs To configure AWS SDKs and CLIs to use Roles Anywhere and SPIFFE for authentication, you will modify the AWS configuration file. diff --git a/cmd/credential_file.go b/cmd/credential_file.go new file mode 100644 index 0000000..6788822 --- /dev/null +++ b/cmd/credential_file.go @@ -0,0 +1,251 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/spf13/cobra" + "github.com/spiffe/aws-spiffe-workload-helper/internal" + "github.com/spiffe/go-spiffe/v2/workloadapi" +) + +func newX509CredentialFileOneshotCmd() (*cobra.Command, error) { + force := false + replace := false + awsCredentialsPath := "" + sf := &sharedFlags{} + cmd := &cobra.Command{ + Use: "x509-credential-file-oneshot", + Short: `Exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Writes the credentials to a file in the 'credential file' format expected by the AWS CLI and SDKs.`, + Long: `Exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Writes the credentials to a file in the 'credential file' format expected by the AWS CLI and SDKs.`, + RunE: func(cmd *cobra.Command, args []string) error { + return oneshotX509CredentialFile( + cmd.Context(), force, replace, awsCredentialsPath, sf, + ) + }, + } + if err := sf.addFlags(cmd); err != nil { + return nil, fmt.Errorf("adding shared flags: %w", err) + } + cmd.Flags().StringVar(&awsCredentialsPath, "aws-credentials-path", "", "The path to the AWS credentials file to write.") + if err := cmd.MarkFlagRequired("aws-credentials-path"); err != nil { + return nil, fmt.Errorf("marking aws-credentials-path flag as required: %w", err) + } + cmd.Flags().BoolVar(&force, "force", false, "If set, failures loading the existing AWS credentials file will be ignored and the contents overwritten.") + cmd.Flags().BoolVar(&replace, "replace", false, "If set, the AWS credentials file will be replaced if it exists. This will remove any profiles not written by this tool.") + + return cmd, nil +} + +func oneshotX509CredentialFile( + ctx context.Context, + force bool, + replace bool, + awsCredentialsPath string, + sf *sharedFlags, +) error { + client, err := workloadapi.New( + ctx, + workloadapi.WithAddr(sf.workloadAPIAddr), + ) + if err != nil { + return fmt.Errorf("creating workload api client: %w", err) + } + defer func() { + if err := client.Close(); err != nil { + slog.Warn("Failed to close workload API client", "error", err) + } + }() + + x509Ctx, err := client.FetchX509Context(ctx) + if err != nil { + return fmt.Errorf("fetching x509 context: %w", err) + } + svid := x509Ctx.DefaultSVID() + slog.Info( + "Fetched X509 SVID", + "svid", svidValue(svid), + ) + + credentials, err := exchangeX509SVIDForAWSCredentials(sf, svid) + if err != nil { + return fmt.Errorf("exchanging X509 SVID for AWS credentials: %w", err) + } + + expiresAt, err := time.Parse(time.RFC3339, credentials.Expiration) + if err != nil { + return fmt.Errorf("parsing expiration time: %w", err) + } + + // Now we write this to disk in the format that the AWS CLI/SDK + // expects for a credentials file. + err = internal.UpsertAWSCredentialsFileProfile( + slog.Default(), + internal.AWSCredentialsFileConfig{ + Path: awsCredentialsPath, + Force: force, + ReplaceFile: replace, + }, + internal.AWSCredentialsFileProfile{ + AWSAccessKeyID: credentials.AccessKeyId, + AWSSecretAccessKey: credentials.SecretAccessKey, + AWSSessionToken: credentials.SessionToken, + }, + ) + if err != nil { + return fmt.Errorf("writing credentials to file: %w", err) + } + slog.Info( + "Wrote AWS credential to file", + "path", awsCredentialsPath, + "aws_expires_at", expiresAt, + ) + return nil +} + +func newX509CredentialFileCmd() (*cobra.Command, error) { + force := false + replace := false + awsCredentialsPath := "" + sf := &sharedFlags{} + cmd := &cobra.Command{ + Use: "x509-credential-file", + Short: `On a regular basis, this daemon exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Writes the credentials to a file in the 'credential file' format expected by the AWS CLI and SDKs.`, + Long: `On a regular basis, this daemon exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Writes the credentials to a file in the 'credential file' format expected by the AWS CLI and SDKs.`, + RunE: func(cmd *cobra.Command, args []string) error { + return daemonX509CredentialFile( + cmd.Context(), force, replace, awsCredentialsPath, sf, + ) + }, + } + if err := sf.addFlags(cmd); err != nil { + return nil, fmt.Errorf("adding shared flags: %w", err) + } + cmd.Flags().StringVar(&awsCredentialsPath, "aws-credentials-path", "", "The path to the AWS credentials file to write.") + if err := cmd.MarkFlagRequired("aws-credentials-path"); err != nil { + return nil, fmt.Errorf("marking aws-credentials-path flag as required: %w", err) + } + cmd.Flags().BoolVar(&force, "force", false, "If set, failures loading the existing AWS credentials file will be ignored and the contents overwritten.") + cmd.Flags().BoolVar(&replace, "replace", false, "If set, the AWS credentials file will be replaced if it exists. This will remove any profiles not written by this tool.") + + return cmd, nil +} + +func daemonX509CredentialFile( + ctx context.Context, + force bool, + replace bool, + awsCredentialsPath string, + sf *sharedFlags, +) error { + slog.Info("Starting AWS credential file daemon") + client, err := workloadapi.New( + ctx, + workloadapi.WithAddr(sf.workloadAPIAddr), + ) + if err != nil { + return fmt.Errorf("creating workload api client: %w", err) + } + defer func() { + if err := client.Close(); err != nil { + slog.Warn("Failed to close workload API client", "error", err) + } + }() + + slog.Debug("Fetching initial X509 SVID") + x509Source, err := workloadapi.NewX509Source(ctx, workloadapi.WithClient(client)) + if err != nil { + return fmt.Errorf("creating x509 source: %w", err) + } + defer func() { + if err := x509Source.Close(); err != nil { + slog.Warn("Failed to close x509 source", "error", err) + } + }() + + svidUpdate := x509Source.Updated() + svid, err := x509Source.GetX509SVID() + if err != nil { + return fmt.Errorf("fetching initial X509 SVID: %w", err) + } + slog.Info("Fetched initial X509 SVID", "svid", svidValue(svid)) + + for { + slog.Debug( + "Exchanging X509 SVID for AWS credentials", + "svid", svidValue(svid), + ) + credentials, err := exchangeX509SVIDForAWSCredentials(sf, svid) + if err != nil { + return fmt.Errorf("exchanging X509 SVID for AWS credentials: %w", err) + } + slog.Info( + "Successfully exchanged X509 SVID for AWS credentials", + "svid", svidValue(svid), + ) + + expiresAt, err := time.Parse(time.RFC3339, credentials.Expiration) + if err != nil { + return fmt.Errorf("parsing expiration time: %w", err) + } + + slog.Debug("Writing AWS credentials to file", "path", awsCredentialsPath) + err = internal.UpsertAWSCredentialsFileProfile( + slog.Default(), + internal.AWSCredentialsFileConfig{ + Path: awsCredentialsPath, + Force: force, + ReplaceFile: replace, + }, + internal.AWSCredentialsFileProfile{ + AWSAccessKeyID: credentials.AccessKeyId, + AWSSecretAccessKey: credentials.SecretAccessKey, + AWSSessionToken: credentials.SessionToken, + }, + ) + if err != nil { + return fmt.Errorf("writing credentials to file: %w", err) + } + slog.Info("Wrote AWS credentials to file", "path", awsCredentialsPath) + + // Calculate next renewal time as 50% of the remaining time left on the + // AWS credentials. + // TODO(noah): This is a little crude, it may make more sense to just + // renew on a fixed basis (e.g every minute?). We'll go with this + // for now, and speak to consumers once it's in use to see if a + // different mechanism may be more suitable. + now := time.Now() + awsTTL := expiresAt.Sub(now) + renewIn := awsTTL / 2 + awsRenewAt := now.Add(renewIn) + + slog.Info( + "Sleeping until a new X509 SVID is received or the AWS credentials are close to expiry", + "aws_expires_at", expiresAt, + "aws_ttl", awsTTL, + "aws_renews_at", awsRenewAt, + "svid_expires_at", svid.Certificates[0].NotAfter, + "svid_ttl", svid.Certificates[0].NotAfter.Sub(now), + ) + + select { + case <-time.After(time.Until(awsRenewAt)): + slog.Info("Triggering renewal as AWS credentials are close to expiry") + case <-svidUpdate: + slog.Debug("Received potential X509 SVID update") + newSVID, err := x509Source.GetX509SVID() + if err != nil { + return fmt.Errorf("fetching updated X509 SVID: %w", err) + } + slog.Info( + "Received new X509 SVID from Workload API, will update AWS credentials", + "svid", svidValue(svid), + ) + svid = newSVID + case <-ctx.Done(): + return nil + } + } +} diff --git a/cmd/credential_process.go b/cmd/credential_process.go new file mode 100644 index 0000000..79e762d --- /dev/null +++ b/cmd/credential_process.go @@ -0,0 +1,64 @@ +package main + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + + "github.com/spf13/cobra" + "github.com/spiffe/go-spiffe/v2/workloadapi" +) + +func newX509CredentialProcessCmd() (*cobra.Command, error) { + sf := &sharedFlags{} + cmd := &cobra.Command{ + Use: "x509-credential-process", + Short: `Exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Compatible with the AWS credential process functionality.`, + Long: `Exchanges an X509 SVID for a short-lived set of AWS credentials using the AWS Roles Anywhere API. It returns the credentials to STDOUT, in the format expected by AWS SDKs and CLIs when invoking an external credential process.`, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + client, err := workloadapi.New( + ctx, + workloadapi.WithAddr(sf.workloadAPIAddr), + ) + if err != nil { + return fmt.Errorf("creating workload api client: %w", err) + } + defer func() { + if err := client.Close(); err != nil { + slog.Warn("Failed to close workload API client", "error", err) + } + }() + + x509Ctx, err := client.FetchX509Context(ctx) + if err != nil { + return fmt.Errorf("fetching x509 context: %w", err) + } + // TODO(strideynet): Implement SVID selection mechanism, for now, + // we'll just use the first returned SVID (a.k.a the default). + svid := x509Ctx.DefaultSVID() + slog.Debug("Fetched X509 SVID", "svid", svidValue(svid)) + + credentials, err := exchangeX509SVIDForAWSCredentials(sf, svid) + if err != nil { + return fmt.Errorf("exchanging X509 SVID for AWS credentials: %w", err) + } + + out, err := json.Marshal(credentials) + if err != nil { + return fmt.Errorf("marshalling credentials: %w", err) + } + _, err = os.Stdout.Write(out) + if err != nil { + return fmt.Errorf("writing credentials to stdout: %w", err) + } + return nil + }, + } + if err := sf.addFlags(cmd); err != nil { + return nil, fmt.Errorf("adding shared flags: %w", err) + } + + return cmd, nil +} diff --git a/cmd/main.go b/cmd/main.go index 36eb522..1e4fb83 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,7 +1,6 @@ package main import ( - "encoding/json" "fmt" "log/slog" "os" @@ -9,7 +8,7 @@ import ( "github.com/spf13/cobra" awsspiffe "github.com/spiffe/aws-spiffe-workload-helper" "github.com/spiffe/aws-spiffe-workload-helper/internal/vendoredaws" - "github.com/spiffe/go-spiffe/v2/workloadapi" + "github.com/spiffe/go-spiffe/v2/svid/x509svid" ) var ( @@ -50,97 +49,84 @@ func newRootCmd() (*cobra.Command, error) { } rootCmd.AddCommand(x509CredentialProcessCmd) - return rootCmd, nil -} + x509CredentialFileCmd, err := newX509CredentialFileCmd() + if err != nil { + return nil, fmt.Errorf("initializing x509-credential-file command: %w", err) + } + rootCmd.AddCommand(x509CredentialFileCmd) -func newX509CredentialProcessCmd() (*cobra.Command, error) { - var ( - roleARN string - region string - profileARN string - sessionDuration int - trustAnchorARN string - roleSessionName string - workloadAPIAddr string - ) - cmd := &cobra.Command{ - Use: "x509-credential-process", - Short: `Exchanges an X509 SVID for a short-lived set of AWS credentials using AWS Roles Anywhere. Compatible with the AWS credential process functionality.`, - Long: `Exchanges an X509 SVID for a short-lived set of AWS credentials using the AWS Roles Anywhere API. It returns the credentials to STDOUT, in the format expected by AWS SDKs and CLIs when invoking an external credential process.`, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - client, err := workloadapi.New( - ctx, - workloadapi.WithAddr(workloadAPIAddr), - ) - if err != nil { - return fmt.Errorf("creating workload api client: %w", err) - } + x509CredentialFileOneshotCmd, err := newX509CredentialFileOneshotCmd() + if err != nil { + return nil, fmt.Errorf("initializing x509-credential-file-oneshot command: %w", err) + } + rootCmd.AddCommand(x509CredentialFileOneshotCmd) - x509Ctx, err := client.FetchX509Context(ctx) - if err != nil { - return fmt.Errorf("fetching x509 context: %w", err) - } - // TODO(strideynet): Implement SVID selection mechanism, for now, - // we'll just use the first returned SVID (a.k.a the default). - svid := x509Ctx.DefaultSVID() - slog.Debug( - "Fetched X509 SVID", - slog.Group("svid", - "spiffe_id", svid.ID, - "hint", svid.Hint, - ), - ) + return rootCmd, nil +} - signer := &awsspiffe.X509SVIDSigner{ - SVID: svid, - } - signatureAlgorithm, err := signer.SignatureAlgorithm() - if err != nil { - return fmt.Errorf("getting signature algorithm: %w", err) - } - credentials, err := vendoredaws.GenerateCredentials(&vendoredaws.CredentialsOpts{ - RoleArn: roleARN, - ProfileArnStr: profileARN, - Region: region, - RoleSessionName: roleSessionName, - TrustAnchorArnStr: trustAnchorARN, - SessionDuration: sessionDuration, - }, signer, signatureAlgorithm) - if err != nil { - return fmt.Errorf("generating credentials: %w", err) - } - slog.Debug( - "Generated AWS credentials", - "expiration", credentials.Expiration, - ) +type sharedFlags struct { + roleARN string + region string + profileARN string + sessionDuration int + trustAnchorARN string + roleSessionName string + workloadAPIAddr string +} - out, err := json.Marshal(credentials) - if err != nil { - return fmt.Errorf("marshalling credentials: %w", err) - } - _, err = os.Stdout.Write(out) - if err != nil { - return fmt.Errorf("writing credentials to stdout: %w", err) - } - return nil - }, - } - cmd.Flags().StringVar(&roleARN, "role-arn", "", "The ARN of the role to assume. Required.") +func (f *sharedFlags) addFlags(cmd *cobra.Command) error { + cmd.Flags().StringVar(&f.roleARN, "role-arn", "", "The ARN of the role to assume. Required.") if err := cmd.MarkFlagRequired("role-arn"); err != nil { - return nil, fmt.Errorf("marking role-arn flag as required: %w", err) + return fmt.Errorf("marking role-arn flag as required: %w", err) } - cmd.Flags().StringVar(®ion, "region", "", "Overrides AWS region to use when exchanging the SVID for AWS credentials. Optional.") - cmd.Flags().StringVar(&profileARN, "profile-arn", "", "The ARN of the Roles Anywhere profile to use. Required.") + cmd.Flags().StringVar(&f.region, "region", "", "Overrides AWS region to use when exchanging the SVID for AWS credentials. Optional.") + cmd.Flags().StringVar(&f.profileARN, "profile-arn", "", "The ARN of the Roles Anywhere profile to use. Required.") if err := cmd.MarkFlagRequired("profile-arn"); err != nil { - return nil, fmt.Errorf("marking profile-arn flag as required: %w", err) + return fmt.Errorf("marking profile-arn flag as required: %w", err) } - cmd.Flags().IntVar(&sessionDuration, "session-duration", 3600, "The duration, in seconds, of the resulting session. Optional. Can range from 15 minutes (900) to 12 hours (43200).") - cmd.Flags().StringVar(&trustAnchorARN, "trust-anchor-arn", "", "The ARN of the Roles Anywhere trust anchor to use. Required.") + cmd.Flags().IntVar(&f.sessionDuration, "session-duration", 3600, "The duration, in seconds, of the resulting session. Optional. Can range from 15 minutes (900) to 12 hours (43200).") + cmd.Flags().StringVar(&f.trustAnchorARN, "trust-anchor-arn", "", "The ARN of the Roles Anywhere trust anchor to use. Required.") if err := cmd.MarkFlagRequired("trust-anchor-arn"); err != nil { - return nil, fmt.Errorf("marking trust-anchor-arn flag as required: %w", err) + return fmt.Errorf("marking trust-anchor-arn flag as required: %w", err) } - cmd.Flags().StringVar(&roleSessionName, "role-session-name", "", "The identifier for the role session. Optional.") - cmd.Flags().StringVar(&workloadAPIAddr, "workload-api-addr", "", "Overrides the address of the Workload API endpoint that will be use to fetch the X509 SVID. If unspecified, the value from the SPIFFE_ENDPOINT_SOCKET environment variable will be used.") - return cmd, nil + cmd.Flags().StringVar(&f.roleSessionName, "role-session-name", "", "The identifier for the role session. Optional.") + cmd.Flags().StringVar(&f.workloadAPIAddr, "workload-api-addr", "", "Overrides the address of the Workload API endpoint that will be use to fetch the X509 SVID. If unspecified, the value from the SPIFFE_ENDPOINT_SOCKET environment variable will be used.") + return nil +} + +func exchangeX509SVIDForAWSCredentials( + sf *sharedFlags, + svid *x509svid.SVID, +) (vendoredaws.CredentialProcessOutput, error) { + signer := &awsspiffe.X509SVIDSigner{ + SVID: svid, + } + signatureAlgorithm, err := signer.SignatureAlgorithm() + if err != nil { + return vendoredaws.CredentialProcessOutput{}, fmt.Errorf("getting signature algorithm: %w", err) + } + credentials, err := vendoredaws.GenerateCredentials(&vendoredaws.CredentialsOpts{ + RoleArn: sf.roleARN, + ProfileArnStr: sf.profileARN, + Region: sf.region, + RoleSessionName: sf.roleSessionName, + TrustAnchorArnStr: sf.trustAnchorARN, + SessionDuration: sf.sessionDuration, + }, signer, signatureAlgorithm) + if err != nil { + return vendoredaws.CredentialProcessOutput{}, fmt.Errorf("generating credentials: %w", err) + } + slog.Debug( + "Generated AWS credentials", + "expiration", credentials.Expiration, + ) + return credentials, nil +} + +func svidValue(svid *x509svid.SVID) slog.Value { + return slog.GroupValue( + slog.String("id", svid.ID.String()), + slog.String("hint", svid.Hint), + slog.Time("expires_at", svid.Certificates[0].NotAfter), + ) } diff --git a/go.mod b/go.mod index 50f3d03..c1a96f7 100644 --- a/go.mod +++ b/go.mod @@ -11,12 +11,16 @@ require ( require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/aws/aws-sdk-go v1.55.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-jose/go-jose/v4 v4.0.4 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/miekg/pkcs11 v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 // indirect + github.com/stretchr/testify v1.10.0 // indirect github.com/zeebo/errs v1.3.0 // indirect golang.org/x/crypto v0.28.0 // indirect golang.org/x/net v0.30.0 // indirect @@ -26,4 +30,6 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 // indirect google.golang.org/grpc v1.67.1 // indirect google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 06adbdd..390510f 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= @@ -53,6 +55,8 @@ google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFN google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/aws_credentials_file.go b/internal/aws_credentials_file.go new file mode 100644 index 0000000..68eb471 --- /dev/null +++ b/internal/aws_credentials_file.go @@ -0,0 +1,103 @@ +package internal + +import ( + "fmt" + "log/slog" + "os" + "path/filepath" + + "gopkg.in/ini.v1" +) + +type AWSCredentialsFileConfig struct { + Path string + ProfileName string + Force bool + ReplaceFile bool +} + +type AWSCredentialsFileProfile struct { + AWSAccessKeyID string + AWSSecretAccessKey string + AWSSessionToken string +} + +func ensureDirectory(fpath string) error { + dpath := filepath.Dir(fpath) + if _, err := os.Stat(dpath); os.IsNotExist(err) { + if err := os.MkdirAll(dpath, 0700); err != nil { + return fmt.Errorf("creating directory (%s): %w", dpath, err) + } + } + return nil +} + +func loadAWSCredentialsFile( + log *slog.Logger, + cfg AWSCredentialsFileConfig, +) (*ini.File, error) { + if cfg.ReplaceFile { + return ini.Empty(), nil + } + + // Create parent directory for file, should it be needed. + if err := ensureDirectory(cfg.Path); err != nil { + return nil, fmt.Errorf("ensuring parent directory: %w", err) + } + + f, err := ini.Load(cfg.Path) + if err == nil { + return f, nil + } + + // If it doesn't exist, we can "create" it. + if os.IsNotExist(err) { + return ini.Empty(), nil + } + + // If force mode is enabled, ignore the error and return an empty file. + if cfg.Force { + log.Warn( + "When loading the existing AWS credentials file, an error occurred. As --force is set, the file will be overwritten.", + "error", err, + "path", cfg.Path, + ) + return ini.Empty(), nil + } + + // Otherwise, fail... + log.Error( + "When loading the existing AWS credentials file, an error occurred. Use --force to ignore errors and attempt to overwrite.", + "error", err, + "path", cfg.Path, + ) + return nil, fmt.Errorf("loading existing aws credentials file: %w", err) +} + +// UpsertAWSCredentialsFileProfile writes the provided AWS credentials profile to the AWS credentials file. +// See https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-files.html +func UpsertAWSCredentialsFileProfile( + log *slog.Logger, + cfg AWSCredentialsFileConfig, + p AWSCredentialsFileProfile, +) error { + f, err := loadAWSCredentialsFile(log, cfg) + if err != nil { + return fmt.Errorf("loading existing aws credentials: %w", err) + } + + sectionName := "default" + if cfg.ProfileName != "" { + sectionName = cfg.ProfileName + } + sec := f.Section(sectionName) + + sec.Key("aws_secret_access_key").SetValue(p.AWSSecretAccessKey) + sec.Key("aws_access_key_id").SetValue(p.AWSAccessKeyID) + sec.Key("aws_session_token").SetValue(p.AWSSessionToken) + + if err := f.SaveTo(cfg.Path); err != nil { + return fmt.Errorf("saving aws credentials file: %w", err) + } + return nil +} diff --git a/internal/aws_credentials_file_test.go b/internal/aws_credentials_file_test.go new file mode 100644 index 0000000..940f5af --- /dev/null +++ b/internal/aws_credentials_file_test.go @@ -0,0 +1,148 @@ +package internal + +import ( + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAWSCredentialsFile_Write(t *testing.T) { + log := slog.Default() + + defaultProfile := AWSCredentialsFileProfile{ + AWSAccessKeyID: "1234567890", + AWSSecretAccessKey: "abcdefgh", + AWSSessionToken: "ijklmnop", + } + + preExistingContents := []byte(`[pre-existing] +aws_secret_access_key = foo +aws_access_key_id = bar +aws_session_token = bizz +`) + + tests := []struct { + name string + existingFileContents []byte + config AWSCredentialsFileConfig + profile AWSCredentialsFileProfile + want []byte + wantErr string + }{ + { + name: "no pre-existing file - default profile", + config: AWSCredentialsFileConfig{}, + profile: defaultProfile, + want: []byte(`[default] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + { + name: "no pre-existing file - named profile", + config: AWSCredentialsFileConfig{ + ProfileName: "my-profile", + }, + profile: defaultProfile, + want: []byte(`[my-profile] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + { + name: "pre-existing file, no profile name clash - default profile", + config: AWSCredentialsFileConfig{}, + profile: defaultProfile, + existingFileContents: preExistingContents, + want: []byte(`[pre-existing] +aws_secret_access_key = foo +aws_access_key_id = bar +aws_session_token = bizz + +[default] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + { + name: "pre-existing file, no profile name clash - default profile with replace mode", + config: AWSCredentialsFileConfig{ + ReplaceFile: true, + }, + profile: defaultProfile, + existingFileContents: preExistingContents, + want: []byte(`[default] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + { + name: "pre-existing file, profile name clash", + config: AWSCredentialsFileConfig{ + ProfileName: "pre-existing", + }, + profile: defaultProfile, + existingFileContents: preExistingContents, + want: []byte(`[pre-existing] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + { + name: "pre-existing file with garbage", + config: AWSCredentialsFileConfig{}, + profile: defaultProfile, + existingFileContents: []byte(`dduhufd`), + wantErr: "key-value delimiter not found", + }, + { + name: "pre-existing file with garbage, --force", + config: AWSCredentialsFileConfig{ + Force: true, + }, + profile: defaultProfile, + existingFileContents: []byte(`dduhufd`), + want: []byte(`[default] +aws_secret_access_key = abcdefgh +aws_access_key_id = 1234567890 +aws_session_token = ijklmnop +`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmp := t.TempDir() + credentialPath := filepath.Join(tmp, "credentials") + cfg := tt.config + cfg.Path = credentialPath + + if tt.existingFileContents != nil { + require.NoError(t, os.WriteFile(credentialPath, tt.existingFileContents, 0600)) + } + + err := UpsertAWSCredentialsFileProfile( + log, + cfg, + tt.profile, + ) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + + got, err := os.ReadFile(credentialPath) + require.NoError(t, err) + + require.Equal(t, string(tt.want), string(got)) + }) + } +}