-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
178 lines (157 loc) · 5.48 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package main
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
"log"
"os"
"strings"
"syscall"
)
// main requests temporary AWS session credentials from STS using an MFA token
// and hands them over according to the selected mode: "env" exports them as
// environment variables and replaces the process with the user's shell, "conf"
// writes them to the aws credentials/config files via storeToConfigFile.
// Any failure terminates the program through log.Fatalln.
func main() {
	// get user flags
	userConfig.Get().CheckValidity()

	// check for IAM credentials to authorize the request
	if !userConfig.AreIAMCredentialsSet() {
		log.Println("could not find IAM credentials in flags, checking env vars for credentials")
		// both values are required: a key id without its secret (or the other
		// way round) cannot sign a request, so fail if either one is missing
		accessKeyID := os.Getenv("AWS_ACCESS_KEY_ID")
		secretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY")
		if accessKeyID == "" || secretAccessKey == "" {
			log.Fatalln("could not find IAM credentials to authorize request")
		}
		*userConfig.accessKeyId.value = accessKeyID
		*userConfig.secretAccessKey.value = secretAccessKey
	}

	// setup aws config
	ctx := context.Background()
	cfg, err := config.LoadDefaultConfig(ctx,
		config.WithRegion(*userConfig.region.value),
		config.WithCredentialsProvider(aws.CredentialsProvider(credentials.NewStaticCredentialsProvider(
			*userConfig.accessKeyId.value,
			*userConfig.secretAccessKey.value,
			""))),
	)
	if err != nil {
		log.Fatalln("could not create aws config: ", err.Error())
	}

	// one client serves both STS calls below
	stsClient := sts.NewFromConfig(cfg)

	// get serial number: "virtual" means derive it from the caller's ARN
	if *userConfig.serialNumber.value == "virtual" {
		idOut, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
		if err != nil {
			// must be fatal: idOut is not usable when the call failed
			log.Fatalln("could not get caller identity:", err.Error())
		}
		if idOut.Arn == nil {
			log.Fatalln("could not get caller identity: response contains no ARN")
		}
		// replace the first "user" in the ARN with "mfa"
		*userConfig.serialNumber.value = strings.Replace(*idOut.Arn, "user", "mfa", 1)
	}

	// get session token
	crds, err := stsClient.GetSessionToken(ctx, &sts.GetSessionTokenInput{
		SerialNumber:    userConfig.serialNumber.value,
		TokenCode:       userConfig.mfaToken.value,
		DurationSeconds: userConfig.SessionInt32(),
	})
	if err != nil {
		log.Fatalln("could not get session token:", err.Error())
	}
	if crds.Credentials == nil {
		log.Fatalln("could not get session token: response contains no credentials")
	}

	switch *userConfig.mode.value {
	case "env":
		envVars := []struct{ name, value string }{
			{"AWS_ACCESS_KEY_ID", *crds.Credentials.AccessKeyId},
			{"AWS_SECRET_ACCESS_KEY", *crds.Credentials.SecretAccessKey},
			{"AWS_SESSION_TOKEN", *crds.Credentials.SessionToken},
			{"AWS_REGION", *userConfig.region.value},
		}
		for _, ev := range envVars {
			if err := os.Setenv(ev.name, ev.value); err != nil {
				log.Fatalln("could not set "+ev.name+" env var:", err.Error())
			}
		}

		shell := os.Getenv("SHELL")
		if shell == "" {
			log.Fatalln("could not set environment vars: SHELL env var is not set")
		}
		// syscall.Exec replaces this process and only returns on failure, so
		// the success message has to be logged before the call
		log.Println("credentials successfully stored into environment variables")
		// set env vars to shell
		if err := syscall.Exec(shell, []string{shell}, syscall.Environ()); err != nil {
			log.Fatalln("could not set environment vars:", err.Error())
		}
	case "conf":
		if err := storeToConfigFile(crds, userConfig); err != nil {
			log.Fatalln("could not store credentials to local files:", err.Error())
		}
		log.Println("credentials successfully stored to aws credentials/config files")
	default:
		log.Fatalln("Mode not supported. Supported modes are: [env, conf]")
	}
}
// storeToConfigFile rewrites the credentials file and the config file named in
// userConf so that they contain the session credentials from crds and the
// configured region for the configured profile.
//
// Both files are read first; an existing entry for the profile is blanked out,
// all empty lines are dropped (see purgeEmpty) and the new entry is appended.
// The credentials file is written before the config file. It returns a wrapped
// error if either file cannot be read or written.
func storeToConfigFile(crds *sts.GetSessionTokenOutput, userConf availableUserFlags) error {
	profile := *userConf.profile.value

	// config file content to write
	configContent := fmt.Sprintf(`
[profile %s]
region=%s
`, profile, *userConf.region.value)

	// credentials file content to write
	credContent := fmt.Sprintf(`
[%s]
aws_access_key_id=%s
aws_secret_access_key=%s
aws_session_token=%s
`, profile, *crds.Credentials.AccessKeyId,
		*crds.Credentials.SecretAccessKey, *crds.Credentials.SessionToken)

	// read config file
	currentConf, err := os.ReadFile(*userConf.confFile.value)
	if err != nil {
		return fmt.Errorf("could not read config file err=%w", err)
	}
	// read credentials file
	currentCred, err := os.ReadFile(*userConf.credFile.value)
	if err != nil {
		return fmt.Errorf("could not read credentials file err=%w", err)
	}

	// remove old information from config
	tempConf := strings.Split(string(currentConf), "\n")
	clearProfileEntry(tempConf, fmt.Sprintf("[profile %s]", profile), "region")

	// remove old information from credentials
	tempCred := strings.Split(string(currentCred), "\n")
	clearProfileEntry(tempCred, fmt.Sprintf("[%s]", profile),
		"aws_access_key_id", "aws_secret_access_key", "aws_session_token")

	// append new entries to the remaining data
	newCred := strings.Join(purgeEmpty(tempCred), "\n") + credContent
	newConf := strings.Join(purgeEmpty(tempConf), "\n") + configContent

	// write credentials file
	if err = os.WriteFile(*userConf.credFile.value, []byte(newCred), 0600); err != nil {
		return fmt.Errorf("could not write credentials file err=%w", err)
	}
	// write config file
	if err = os.WriteFile(*userConf.confFile.value, []byte(newConf), 0600); err != nil {
		return fmt.Errorf("could not write config file err=%w", err)
	}
	return nil
}

// clearProfileEntry blanks every line of lines that equals header once
// surrounding whitespace is trimmed. For each such header, the k-th line after
// it is blanked as well when it contains keys[k]. Lines past the end of the
// slice are skipped, so a header near the end of the file cannot cause an
// out-of-range access. lines is modified in place.
func clearProfileEntry(lines []string, header string, keys ...string) {
	for i, line := range lines {
		if strings.TrimSpace(line) != header {
			continue
		}
		lines[i] = ""
		for offset, key := range keys {
			next := i + 1 + offset
			if next >= len(lines) {
				break
			}
			if strings.Contains(lines[next], key) {
				lines[next] = ""
			}
		}
	}
}
// purgeEmpty returns the elements of s that are not the empty string, in
// their original order. Strings consisting only of whitespace are kept.
// The result is nil when s holds no non-empty element.
func purgeEmpty(s []string) []string {
	var kept []string
	for _, item := range s {
		if item == "" {
			continue
		}
		kept = append(kept, item)
	}
	return kept
}