diff --git a/Makefile b/Makefile index 6f53225..114ed07 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ NAME := aws-cli-oidc -VERSION := v0.3.0 +VERSION := v0.4.0 REVISION := $(shell git rev-parse --short HEAD) SRCS := $(shell find . -type f -name '*.go') diff --git a/cmd/aws_saml.go b/cmd/aws_saml.go index 96664c3..276dac8 100755 --- a/cmd/aws_saml.go +++ b/cmd/aws_saml.go @@ -15,21 +15,22 @@ import ( "github.com/versent/saml2aws" ) -func GetCredentialsWithSAML(samlResponse string) (*AWSCredentials, error) { - role, err := selectAwsRole(samlResponse) +func GetCredentialsWithSAML(samlResponse string, durationSeconds int64, defaultIAMRoleArn string) (*AWSCredentials, error) { + role, err := selectAwsRole(samlResponse, defaultIAMRoleArn) if err != nil { return nil, errors.Wrap(err, "Failed to assume role, please check you are permitted to assume the given role for the AWS service") } Writeln("Selected role: %s", role.RoleARN) + Writeln("Max Session Duration: %d seconds", durationSeconds) - return loginToStsUsingRole(role, samlResponse) + return loginToStsUsingRole(role, samlResponse, durationSeconds) } -func selectAwsRole(samlResponse string) (*saml2aws.AWSRole, error) { +func selectAwsRole(samlResponse, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) { roles, err := saml2aws.ExtractAwsRoles([]byte(samlResponse)) if err != nil { - return nil, errors.Wrap(err, "Failed to extract aws roles") + return nil, errors.Wrap(err, "Failed to extract aws roles from SAML Assertion") } if len(roles) == 0 { @@ -41,10 +42,10 @@ func selectAwsRole(samlResponse string) (*saml2aws.AWSRole, error) { return nil, errors.Wrap(err, "Failed to parse aws roles") } - return resolveRole(awsRoles, samlResponse) + return resolveRole(awsRoles, samlResponse, defaultIAMRoleArn) } -func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws.AWSRole, error) { +func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) { var role = new(saml2aws.AWSRole) if len(awsRoles) == 1 { @@ -57,7 +58,7 @@ func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws. for { var err error - role, err = promptForAWSRoleSelection(awsRoles) + role, err = promptForAWSRoleSelection(awsRoles, defaultIAMRoleArn) if err == nil { break } @@ -67,14 +68,21 @@ func resolveRole(awsRoles []*saml2aws.AWSRole, samlAssertion string) (*saml2aws. return role, nil } -func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole) (*saml2aws.AWSRole, error) { +func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole, defaultIAMRoleArn string) (*saml2aws.AWSRole, error) { roles := map[string]*saml2aws.AWSRole{} var roleOptions []string for _, role := range awsRoles { - name := fmt.Sprintf("%s", role.RoleARN) - roles[name] = role - roleOptions = append(roleOptions, name) + if defaultIAMRoleArn == role.RoleARN { + Writeln("Selected default role: %s", defaultIAMRoleArn) + return role, nil + } + roles[role.RoleARN] = role + roleOptions = append(roleOptions, role.RoleARN) + } + + if defaultIAMRoleArn != "" { + Writeln("Warning: You don't have the default role: %s", defaultIAMRoleArn) } sort.Strings(roleOptions) @@ -104,7 +112,7 @@ func promptForAWSRoleSelection(awsRoles []*saml2aws.AWSRole) (*saml2aws.AWSRole, return roles[roleOptions[i-1]], nil } -func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string) (*AWSCredentials, error) { +func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string, durationSeconds int64) (*AWSCredentials, error) { sess, err := session.NewSession() if err != nil { return nil, errors.Wrap(err, "Failed to create session") @@ -123,7 +131,7 @@ func loginToStsUsingRole(role *saml2aws.AWSRole, samlResponse string) (*AWSCrede PrincipalArn: aws.String(role.PrincipalARN), // Required RoleArn: aws.String(role.RoleARN), // Required SAMLAssertion: aws.String(b), // Required - DurationSeconds: aws.Int64(int64(900)), + DurationSeconds: aws.Int64(durationSeconds), } Writeln("Requesting AWS credentials using SAML assertion") diff --git a/cmd/get_cred.go b/cmd/get_cred.go index 39520a1..877331f 100755 --- a/cmd/get_cred.go +++ b/cmd/get_cred.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "strings" "time" @@ -84,6 +85,12 @@ func getCred(cmd *cobra.Command, args []string) { Traceln("ID token: %s", tokenResponse.IDToken) awsFedType := client.config.GetString(AWS_FEDERATION_TYPE) + maxSessionDurationSecondsString := client.config.GetString(MAX_SESSION_DURATION_SECONDS) + maxSessionDurationSeconds, err := strconv.ParseInt(maxSessionDurationSecondsString, 10, 64) + if err != nil { + maxSessionDurationSeconds = 3600 + } + defaultIAMRoleArn := client.config.GetString(DEFAULT_IAM_ROLE_ARN) var awsCreds *AWSCredentials if awsFedType == AWS_FEDERATION_TYPE_OIDC { @@ -105,7 +112,7 @@ func getCred(cmd *cobra.Command, args []string) { Exit(err) } - awsCreds, err = GetCredentialsWithSAML(samlResponse) + awsCreds, err = GetCredentialsWithSAML(samlResponse, maxSessionDurationSeconds, defaultIAMRoleArn) if err != nil { Writeln("Failed to get aws credentials with SAML2") Exit(err) diff --git a/cmd/root.go b/cmd/root.go index a7401c0..54e43cd 100755 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,6 +32,8 @@ const FAILURE_REDIRECT_URL = "failure_redirect_url" const CLIENT_ID = "client_id" const CLIENT_SECRET = "client_secret" const AWS_FEDERATION_TYPE = "aws_federation_type" +const MAX_SESSION_DURATION_SECONDS = "max_session_duration_seconds" +const DEFAULT_IAM_ROLE_ARN = "default_iam_role_arn" // OIDC config const AWS_FEDERATION_ROLE = "aws_federation_role" diff --git a/cmd/setup.go b/cmd/setup.go index 6c49791..ffdb66c 100755 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -3,6 +3,8 @@ package cmd import ( "fmt" "os" + "strconv" + "strings" input "github.com/natsukagami/go-input" "github.com/pkg/errors" @@ -64,6 +66,35 @@ func runSetup() { return nil }, }) + maxSessionDurationSeconds, _ := ui.Ask("The max session duration, in seconds, of the role session [900-43200] (Default: 3600):", &input.Options{ + Default: "3600", + Required: true, + Loop: true, + ValidateFunc: func(s string) error { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil || i < 900 || i > 43200 { + return errors.New(fmt.Sprintf("Input must be 900-43200")) + } + return nil + }, + }) + defaultIAMRoleArn, _ := ui.Ask("The default IAM Role ARN when you have multiple roles, as arn:aws:iam:::role/ (Default: none):", &input.Options{ + Default: "", + Required: false, + Loop: true, + ValidateFunc: func(s string) error { + if s == "" { + return nil + } + arn := strings.Split(s, ":") + if len(arn) == 6 { + if arn[0] == "arn" && arn[1] == "aws" && arn[2] == "iam" && arn[3] == "" && strings.HasPrefix(arn[5], "role/") { + return nil + } + } + return errors.New(fmt.Sprintf("Input must be IAM Role ARN")) + }, + }) config := map[string]string{} @@ -74,6 +105,8 @@ func runSetup() { config[CLIENT_ID] = clientID config[CLIENT_SECRET] = clientSecret config[AWS_FEDERATION_TYPE] = answerFedType + config[MAX_SESSION_DURATION_SECONDS] = maxSessionDurationSeconds + config[DEFAULT_IAM_ROLE_ARN] = defaultIAMRoleArn if answerFedType == AWS_FEDERATION_TYPE_OIDC { oidcSetup(config)