Skip to content

Commit

Permalink
added IAM credentials request options and default serial number
Browse files Browse the repository at this point in the history
  • Loading branch information
ZeljkoBenovic committed Nov 24, 2022
1 parent edd4d30 commit 0831364
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
23 changes: 22 additions & 1 deletion flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var userConfig = availableUserFlags{
serialNumber: userFlagStr{
value: new(string),
name: "serial",
defaultValue: "",
defaultValue: "virtual",
description: "The identification number of the MFA device",
},
mfaToken: userFlagStr{
Expand Down Expand Up @@ -57,6 +57,18 @@ var userConfig = availableUserFlags{
defaultValue: getHomeDir() + "/.aws/credentials",
description: "AWS credentials file location",
},
accessKeyId: userFlagStr{
value: new(string),
name: "access-key-id",
defaultValue: "",
description: "IAM access key id to authenticate the request",
},
secretAccessKey: userFlagStr{
value: new(string),
name: "secret-access-key",
defaultValue: "",
description: "IAM secret access key to authenticate the request",
},
}

type availableUserFlags struct {
Expand All @@ -68,6 +80,9 @@ type availableUserFlags struct {
sessionDuration userFlagInt
confFile userFlagStr
credFile userFlagStr

accessKeyId userFlagStr
secretAccessKey userFlagStr
}

type userFlagStr struct {
Expand All @@ -92,6 +107,8 @@ func (f *availableUserFlags) Get() *availableUserFlags {
flag.StringVar(f.mode.value, f.mode.name, f.mode.defaultValue, f.mode.description)
flag.StringVar(f.confFile.value, f.confFile.name, f.confFile.defaultValue, f.confFile.description)
flag.StringVar(f.credFile.value, f.credFile.name, f.credFile.defaultValue, f.credFile.description)
flag.StringVar(f.accessKeyId.value, f.accessKeyId.name, f.accessKeyId.defaultValue, f.accessKeyId.description)
flag.StringVar(f.secretAccessKey.value, f.secretAccessKey.name, f.secretAccessKey.defaultValue, f.secretAccessKey.description)
flag.IntVar(f.sessionDuration.value, f.sessionDuration.name, f.sessionDuration.defaultValue, f.sessionDuration.description)

flag.Parse()
Expand All @@ -108,6 +125,10 @@ func (f *availableUserFlags) CheckValidity() {
}
}

func (f *availableUserFlags) AreIAMCredentialsSet() bool {
return *f.accessKeyId.value != "" && *f.secretAccessKey.value != ""
}

func (f *availableUserFlags) SessionInt32() *int32 {
i := int32(*f.sessionDuration.value)

Expand Down
31 changes: 30 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package main
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"log"
"os"
Expand All @@ -15,14 +17,41 @@ func main() {
// get user flags
userConfig.Get().CheckValidity()

// check for IAM credentials to authorize the request
if !userConfig.AreIAMCredentialsSet() {
log.Println("could not find IAM credentials in flags, checking env vars for credentials")

// check for credentials in env vars
if os.Getenv("AWS_ACCESS_KEY_ID") == "" && os.Getenv("AWS_SECRET_ACCESS_KEY") == "" {
log.Fatalln("could not find IAM credentials to authorize request")
}

*userConfig.accessKeyId.value = os.Getenv("AWS_ACCESS_KEY_ID")
*userConfig.secretAccessKey.value = os.Getenv("AWS_SECRET_ACCESS_KEY")
}

// setup aws config
ctx := context.Background()
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRegion(*userConfig.region.value),
config.WithSharedConfigProfile(*userConfig.profile.value))
config.WithCredentialsProvider(aws.CredentialsProvider(credentials.NewStaticCredentialsProvider(
*userConfig.accessKeyId.value,
*userConfig.secretAccessKey.value,
""))),
)
if err != nil {
log.Fatalln("could not create aws config: ", err.Error())
}
// get serial number
if *userConfig.serialNumber.value == "virtual" {
idOut, err := sts.NewFromConfig(cfg).GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
log.Println("could not get caller identity:", err.Error())
}

*userConfig.serialNumber.value = strings.Replace(*idOut.Arn, "user", "mfa", 1)
}

// get session token
crds, err := sts.NewFromConfig(cfg).GetSessionToken(ctx, &sts.GetSessionTokenInput{
SerialNumber: userConfig.serialNumber.value,
Expand Down

0 comments on commit 0831364

Please sign in to comment.