diff --git a/README.md b/README.md index c33cc11..8c35af9 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ GET /v1/ecr/{account}/repositories/{group}/{name}/images GET /v1/ecr/{account}/repositories/{group}/{name}/users POST /v1/ecr/{account}/repositories/{group}/{name}/users GET /v1/ecr/{account}/repositories/{group}/{name}/users/{user} +PUT /v1/ecr/{account}/repositories/{group}/{name}/users/{user} DELETE /v1/ecr/{account}/repositories/{group}/{name}/users/{user} ``` @@ -417,17 +418,90 @@ GET /v1/ecr/{account}/repositories/{group}/{name}/users/{user} "Key": "application", "Value": "myapps" }, + { + "Key": "ResourceName", + "Value": "spindev-00001-myAwesomeRepository-user1" + }, + { + "Key": "Name", + "Value": "spindev-00001/myAwesomeRepository" + }, + { + "Key": "spinup:org", + "Value": "spindev" + }, + { + "Key": "spinup:spaceid", + "Value": "spindev-00001" + } + ] +} +``` + +#### Update a user + +A user's tags and/or its access key can be updated. The operations are independent, +and can occur in the same request, or individually. If the access key is reset, +a new access key will be returned with the response. + +PUT /v1/ecr/{account}/repositories/{group}/{name}/users/{user} + +| Response Code | Definition | +| ----------------------------- | --------------------------------| +| **200 OK** | updated the user | +| **400 Bad Request** | badly formed request | +| **404 Not Found** | account not found | +| **500 Internal Server Error** | a server error occurred | + +##### Example update user request body + +```json +{ + "resetkey": true, + "tags": [ + { + "key": "application", + "value": "myapp123" + } + ], +} +``` + +##### Example update user response + +```json +{ + "UserName": "user1", + "AccessKey": { + "AccessKeyId": "AAAAABBBBBCCCCCDDDDDEEEEEFFFFF", + "CreateDate": "2021-02-03T22:37:30Z", + "SecretAccessKey": "gxyz1234567890abcdefghijklmnop", + "Status": "Active", + "UserName": "spincool-00001-testrepo1-user1" + }, + "DeletedAccessKeys": [ + "QQQQQRRRRRSSSSSTTTTTUUUUVVVV" + ], + "Tags": [ + { + "Key": "application", + "Value": "myapp123" + }, + { + "Key": "ResourceName", + "Value": "spindev-00001-myAwesomeRepository-user1" + }, { "Key": "Name", - "Value": "spincool-00002/camdenstestrepo01" + "Value": "spindev-00001/myAwesomeRepository" }, { "Key": "spinup:org", - "Value": "spincool" + "Value": "spindev" }, { "Key": "spinup:spaceid", - "Value": "spincool-00002" + "Value": "spindev-00001" } ] } diff --git a/api/handlers_repositories.go b/api/handlers_repositories.go index 8485dc1..cafda68 100644 --- a/api/handlers_repositories.go +++ b/api/handlers_repositories.go @@ -17,11 +17,6 @@ import ( log "github.com/sirupsen/logrus" ) -type ecrOrchestrator struct { - client ecr.ECR - org string -} - // RepositoriesCreateHandler is the http handler for creating a repository func (s *server) RepositoriesCreateHandler(w http.ResponseWriter, r *http.Request) { w = LogWriter{w} @@ -51,12 +46,10 @@ func (s *server) RepositoriesCreateHandler(w http.ResponseWriter, r *http.Reques return } - orch := &ecrOrchestrator{ - client: ecr.New( - ecr.WithSession(session.Session), - ), - org: s.org, - } + orch := newEcrOrchestrator( + ecr.New(ecr.WithSession(session.Session)), + s.org, + ) resp, err := orch.repositoryCreate(r.Context(), account, group, &req) if err != nil { @@ -239,12 +232,10 @@ func (s *server) RepositoriesUpdateHandler(w http.ResponseWriter, r *http.Reques handleError(w, apierror.New(apierror.ErrForbidden, msg, nil)) } - orch := &ecrOrchestrator{ - client: ecr.New( - ecr.WithSession(session.Session), - ), - org: s.org, - } + orch := newEcrOrchestrator( + ecr.New(ecr.WithSession(session.Session)), + s.org, + ) resp, err := orch.repositoryUpdate(r.Context(), account, group, name, &req) if err != nil { @@ -285,11 +276,10 @@ func (s *server) RepositoriesDeleteHandler(w http.ResponseWriter, r *http.Reques handleError(w, apierror.New(apierror.ErrForbidden, msg, nil)) } - orch := &ecrOrchestrator{ - client: ecr.New( - ecr.WithSession(session.Session), - ), - } + orch := newEcrOrchestrator( + ecr.New(ecr.WithSession(session.Session)), + s.org, + ) resp, err := orch.repositoryDelete(r.Context(), account, group, name) if err != nil { diff --git a/api/handlers_users.go b/api/handlers_users.go index f5b6fe4..6e278e1 100644 --- a/api/handlers_users.go +++ b/api/handlers_users.go @@ -8,6 +8,7 @@ import ( "github.com/YaleSpinup/apierror" "github.com/YaleSpinup/ecr-api/iam" "github.com/gorilla/mux" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" ) @@ -52,12 +53,10 @@ func (s *server) UsersCreateHandler(w http.ResponseWriter, r *http.Request) { return } - orch := &iamOrchestrator{ - client: iam.New( - iam.WithSession(session.Session), - ), - org: s.org, - } + orch := newIamOrchestrator( + iam.New(iam.WithSession(session.Session)), + s.org, + ) groupName, err := orch.prepareAccount(r.Context()) if err != nil { @@ -65,7 +64,7 @@ func (s *server) UsersCreateHandler(w http.ResponseWriter, r *http.Request) { return } - out, err := orch.repositoryUserCreate(r.Context(), name, group, groupName, req) + out, err := orch.repositoryUserCreate(r.Context(), name, group, groupName, &req) if err != nil { handleError(w, err) return @@ -107,12 +106,10 @@ func (s *server) UsersListHandler(w http.ResponseWriter, r *http.Request) { return } - orch := &iamOrchestrator{ - client: iam.New( - iam.WithSession(session.Session), - ), - org: s.org, - } + orch := newIamOrchestrator( + iam.New(iam.WithSession(session.Session)), + s.org, + ) output, err := orch.listRepositoryUsers(r.Context(), group, name) if err != nil { @@ -156,12 +153,10 @@ func (s *server) UsersShowHandler(w http.ResponseWriter, r *http.Request) { return } - orch := &iamOrchestrator{ - client: iam.New( - iam.WithSession(session.Session), - ), - org: s.org, - } + orch := newIamOrchestrator( + iam.New(iam.WithSession(session.Session)), + s.org, + ) output, err := orch.getRepositoryUser(r.Context(), group, name, user) if err != nil { @@ -183,8 +178,55 @@ func (s *server) UsersShowHandler(w http.ResponseWriter, r *http.Request) { // UsersUpdateHandler updates a repository user func (s *server) UsersUpdateHandler(w http.ResponseWriter, r *http.Request) { + w = LogWriter{w} + vars := mux.Vars(r) + account := vars["account"] + group := vars["group"] + name := vars["name"] + userName := vars["user"] + + req := RepositoryUserUpdateRequest{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + msg := fmt.Sprintf("cannot decode body into update repository user input: %s", err) + handleError(w, apierror.New(apierror.ErrBadRequest, msg, err)) + return + } + + role := fmt.Sprintf("arn:aws:iam::%s:role/%s", account, s.session.RoleName) + + // IAM doesn't support resource tags, so we can't pass the s.orgPolicy here + session, err := s.assumeRole( + r.Context(), + s.session.ExternalID, + role, + "", + ) + if err != nil { + msg := fmt.Sprintf("failed to assume role in account: %s", account) + handleError(w, apierror.New(apierror.ErrForbidden, msg, nil)) + return + } + + orch := newIamOrchestrator( + iam.New(iam.WithSession(session.Session)), + s.org, + ) + + resp, err := orch.repositoryUserUpdate(r.Context(), name, group, userName, &req) + if err != nil { + handleError(w, errors.Wrap(err, "failed to update repository user")) + return + } + + j, err := json.Marshal(resp) + if err != nil { + handleError(w, errors.Wrap(err, "unable to marshal response")) + return + } + w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotImplemented) + w.WriteHeader(http.StatusOK) + w.Write(j) } // UsersDeleteHandler deletes a repository user @@ -211,12 +253,10 @@ func (s *server) UsersDeleteHandler(w http.ResponseWriter, r *http.Request) { return } - orch := &iamOrchestrator{ - client: iam.New( - iam.WithSession(session.Session), - ), - org: s.org, - } + orch := newIamOrchestrator( + iam.New(iam.WithSession(session.Session)), + s.org, + ) if err := orch.repositoryUserDelete(r.Context(), name, group, userName); err != nil { handleError(w, err) diff --git a/api/orchestration_users.go b/api/orchestration_users.go index cb8f37e..d2a58f5 100644 --- a/api/orchestration_users.go +++ b/api/orchestration_users.go @@ -58,11 +58,6 @@ var EcrAdminPolicy = iam.PolicyDocument{ }, } -type iamOrchestrator struct { - client iam.IAM - org string -} - // listRepositoryUsers lists users in a repository func (o *iamOrchestrator) listRepositoryUsers(ctx context.Context, group, name string) ([]string, error) { path := fmt.Sprintf("/spinup/%s/%s/%s", o.org, group, name) @@ -250,7 +245,7 @@ func (o *iamOrchestrator) userCreateGroupIfMissing(ctx context.Context, name, pa return nil } -func (o *iamOrchestrator) repositoryUserCreate(ctx context.Context, name, group, groupName string, req RepositoryUserCreateRequest) (*RepositoryUserResponse, error) { +func (o *iamOrchestrator) repositoryUserCreate(ctx context.Context, name, group, groupName string, req *RepositoryUserCreateRequest) (*RepositoryUserResponse, error) { log.Infof("creating repository %s user %s in group %s in iam group %s", name, req.UserName, group, groupName) path := fmt.Sprintf("/spinup/%s/%s/%s/", o.org, group, name) @@ -280,3 +275,50 @@ func (o *iamOrchestrator) repositoryUserCreate(ctx context.Context, name, group, return repositoryUserResponseFromIAM(o.org, user, nil, []string{groupName}), nil } + +func (o *iamOrchestrator) repositoryUserUpdate(ctx context.Context, name, group, userName string, req *RepositoryUserUpdateRequest) (*RepositoryUserResponse, error) { + log.Infof("updating repository %s user %s in group %s", name, userName, group) + + uname := fmt.Sprintf("%s-%s-%s", group, name, userName) + repository := fmt.Sprintf("%s/%s", group, name) + + response := &RepositoryUserResponse{ + UserName: userName, + } + + if req.Tags != nil { + req.Tags = normalizeUserTags(o.org, group, repository, uname, req.Tags) + if err := o.client.TagUser(ctx, uname, toIAMTags(req.Tags)); err != nil { + return nil, err + } + response.Tags = req.Tags + } + + if req.ResetKey { + // get a list of users access keys + keys, err := o.client.ListAccessKeys(ctx, uname) + if err != nil { + return nil, err + } + + newKeyOut, err := o.client.CreateAccessKey(ctx, uname) + if err != nil { + return nil, err + } + response.AccessKey = newKeyOut + + deletedKeyIds := make([]string, 0, len(keys)) + // delete the old access keys + for _, k := range keys { + err = o.client.DeleteAccessKey(ctx, uname, aws.StringValue(k.AccessKeyId)) + if err != nil { + return response, err + } + deletedKeyIds = append(deletedKeyIds, aws.StringValue(k.AccessKeyId)) + } + + response.DeletedAccessKeys = deletedKeyIds + } + + return response, nil +} diff --git a/api/orchestrators.go b/api/orchestrators.go new file mode 100644 index 0000000..0762696 --- /dev/null +++ b/api/orchestrators.go @@ -0,0 +1,30 @@ +package api + +import ( + "github.com/YaleSpinup/ecr-api/ecr" + "github.com/YaleSpinup/ecr-api/iam" +) + +type ecrOrchestrator struct { + client ecr.ECR + org string +} + +func newEcrOrchestrator(client ecr.ECR, org string) *ecrOrchestrator { + return &ecrOrchestrator{ + client: client, + org: org, + } +} + +type iamOrchestrator struct { + client iam.IAM + org string +} + +func newIamOrchestrator(client iam.IAM, org string) *iamOrchestrator { + return &iamOrchestrator{ + client: client, + org: org, + } +} diff --git a/api/server_test.go b/api/server_test.go index cf5ea9e..2a88180 100644 --- a/api/server_test.go +++ b/api/server_test.go @@ -80,3 +80,15 @@ func TestRetry(t *testing.T) { t.Errorf("unexpected error for successful retry, got %s", err) } } + +func TestOrgTagAccessPolicy(t *testing.T) { + expected := `{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["*"],"Resource":"*","Condition":{"StringEquals":{"aws:ResourceTag/spinup:org":"testOrg"}}}]}` + out, err := orgTagAccessPolicy("testOrg") + if err != nil { + t.Errorf("expected nil error, got %s", err) + } + + if string(out) != expected { + t.Errorf("expected %s, got %s", expected, out) + } +} diff --git a/api/types.go b/api/types.go index cdc8f08..2599dc0 100644 --- a/api/types.go +++ b/api/types.go @@ -73,10 +73,18 @@ type RepositoryUserCreateRequest struct { // RepositoryUserResponse is the response payload for user operations type RepositoryUserResponse struct { - UserName string - AccessKeys []*iam.AccessKeyMetadata - Groups []string - Tags []*Tag + UserName string + AccessKeys []*iam.AccessKeyMetadata `json:",omitempty"` + AccessKey *iam.AccessKey `json:",omitempty"` + DeletedAccessKeys []string `json:",omitempty"` + Groups []string `json:",omitempty"` + Tags []*Tag +} + +// RepositoryUserUpdateRequest is the request payload for updating a user +type RepositoryUserUpdateRequest struct { + ResetKey bool + Tags []*Tag } // Tag is our AWS compatible tag struct that can be converted to specific tag types diff --git a/common/config.go b/common/config.go index 84aab11..e4b974e 100644 --- a/common/config.go +++ b/common/config.go @@ -46,16 +46,15 @@ type Account struct { // Version carries around the API version information type Version struct { - Version string - VersionPrerelease string - BuildStamp string - GitHash string + Version string + BuildStamp string + GitHash string } // ReadConfig decodes the configuration from an io Reader func ReadConfig(r io.Reader) (Config, error) { var c Config - log.Infoln("Reading configuration") + log.Infoln("decoding configuration...") if err := json.NewDecoder(r).Decode(&c); err != nil { return c, errors.Wrap(err, "unable to decode JSON message") } diff --git a/docker/Dockerfile b/docker/Dockerfile index 71a078b..b8ed4a1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -16,7 +16,7 @@ RUN go mod download COPY . . RUN go version RUN go test ./... -cover -RUN go build -o /app/api.out -ldflags="-X main.Version=$version -X main.VersionPrerelease=$prerelease -X main.githash=$githash -X main.buildstamp=$buildstamp" *.go +RUN go build -o /app/api.out -ldflags="-X main.Version=$version -X main.githash=$githash -X main.buildstamp=$buildstamp" *.go RUN /app/api.out -version # final stage diff --git a/go.mod b/go.mod index e8baa9c..598c2c8 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,20 @@ module github.com/YaleSpinup/ecr-api -go 1.14 +go 1.15 require ( github.com/YaleSpinup/apierror v0.1.0 - github.com/aws/aws-sdk-go v1.35.35 + github.com/aws/aws-sdk-go v1.37.8 github.com/google/go-cmp v0.5.1 // indirect - github.com/google/uuid v1.1.2 - github.com/gorilla/handlers v1.5.0 + github.com/google/uuid v1.2.0 + github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/pkg/errors v0.9.1 - github.com/prometheus/client_golang v1.7.1 - github.com/prometheus/common v0.13.0 // indirect - github.com/sirupsen/logrus v1.6.0 + github.com/prometheus/client_golang v1.9.0 + github.com/prometheus/procfs v0.5.0 // indirect + github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.6.1 // indirect - golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad + golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c // indirect google.golang.org/protobuf v1.25.0 // indirect ) diff --git a/go.sum b/go.sum index eda5873..42f086b 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQ github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.35.35 h1:o/EbgEcIPWga7GWhJhb3tiaxqk4/goTdo5YEMdnVxgE= github.com/aws/aws-sdk-go v1.35.35/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/aws/aws-sdk-go v1.37.8 h1:9kywcbuz6vQuTf+FD+U7FshafrHzmqUCjgAEiLuIJ8U= +github.com/aws/aws-sdk-go v1.37.8/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -91,6 +93,8 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -106,10 +110,14 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4 github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/handlers v1.5.0 h1:4wjo3sf9azi99c8hTmyaxp9y5S+pFszsy3pP0rAw/lw= github.com/gorilla/handlers v1.5.0/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= +github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= +github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= @@ -232,6 +240,8 @@ github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5Fsn github.com/prometheus/client_golang v1.3.0/go.mod h1:hJaj2vgQTGQmVCsAACORcieXFeDPbaTKGT+JTgUa3og= github.com/prometheus/client_golang v1.7.1 h1:NTGy1Ja9pByO+xAeH/qiWnLrKtr3hJPNjaVUwnjpdpA= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.9.0 h1:Rrch9mh17XcxvEu9D9DEpb4isxjGBtcevQjKvxPRQIU= +github.com/prometheus/client_golang v1.9.0/go.mod h1:FqZLKOZnGdFAhOK4nqGHa7D66IdsO+O441Eve7ptJDU= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -246,12 +256,17 @@ github.com/prometheus/common v0.10.0 h1:RyRA7RzGXQZiW+tGMr7sxa85G1z0yOpM1qq5c8lN github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.13.0 h1:vJlpe9wPgDRM1Z+7Wj3zUUjY1nr6/1jNKyl7llliccg= github.com/prometheus/common v0.13.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= +github.com/prometheus/common v0.15.0 h1:4fgOnadei3EZvgRwxJ7RMpG1k1pOZth5Pc13tyspaKM= +github.com/prometheus/common v0.15.0/go.mod h1:U+gB1OBLb1lF3O42bTCL+FK18tX9Oar16Clt/msog/s= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.5.0 h1:ICtgn8CchRgPjUV2P2qwqAAPVDd5CFZsFOpkBRc1vS0= +github.com/prometheus/procfs v0.5.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -264,6 +279,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sirupsen/logrus v1.7.0 h1:ShrD1U9pZB12TX0cVy0DtePoCH97K8EtX+mg7ZARUtM= +github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= @@ -306,6 +323,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad h1:DN0cp81fZ3njFcrLCytUHRSUkqBjfTo4Tx9RJTWs0EY= +golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -356,6 +375,7 @@ golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191220142924-d4481acd189f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -364,6 +384,10 @@ golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c h1:VwygUrnw9jn88c4u8GD3rZQbqrP/tgas88tPUbBxQrk= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= @@ -391,6 +415,7 @@ google.golang.org/api v0.3.1 h1:oJra/lMfmtm13/rgY/8i3MzjFWYXvQIAKjQ3HqofMk8= google.golang.org/api v0.3.1/go.mod h1:6wY9I6uQWHQ8EM57III9mq/AjF+i8G65rmVagqKMtkk= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/iam/iam_test.go b/iam/iam_test.go index 3cd1f00..eff44d0 100644 --- a/iam/iam_test.go +++ b/iam/iam_test.go @@ -1,6 +1,7 @@ package iam import ( + "math/rand" "reflect" "testing" "time" @@ -9,6 +10,7 @@ import ( ) var testTime = time.Now() +var testPastTime = time.Unix(rand.Int63n(time.Now().Unix()), 0) // mockIAMClient is a fake IAM client type mockIAMClient struct { @@ -23,6 +25,7 @@ func newMockIAMClient(t *testing.T, err error) iamiface.IAMAPI { err: err, } } + func TestNewSession(t *testing.T) { client := New() to := reflect.TypeOf(client).String() diff --git a/iam/users.go b/iam/users.go index 7718ffe..ae92b6b 100644 --- a/iam/users.go +++ b/iam/users.go @@ -64,6 +64,45 @@ func (i *IAM) GetUserWithPath(ctx context.Context, path, name string) (*iam.User return out.User, nil } +func (i *IAM) CreateAccessKey(ctx context.Context, name string) (*iam.AccessKey, error) { + if name == "" { + return nil, apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("creating access key for %s", name) + + out, err := i.Service.CreateAccessKeyWithContext(ctx, &iam.CreateAccessKeyInput{ + UserName: aws.String(name), + }) + + if err != nil { + return nil, ErrCode("failed to create access keys", err) + } + + log.Debugf("got output from create access keys: %+v", out) + + return out.AccessKey, nil +} + +func (i *IAM) DeleteAccessKey(ctx context.Context, name, keyId string) error { + if name == "" || keyId == "" { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("deleting access key %s for %s", keyId, name) + + _, err := i.Service.DeleteAccessKeyWithContext(ctx, &iam.DeleteAccessKeyInput{ + AccessKeyId: aws.String(keyId), + UserName: aws.String(name), + }) + + if err != nil { + return ErrCode("failed to delete access keys", err) + } + + return nil +} + func (i *IAM) ListAccessKeys(ctx context.Context, name string) ([]*iam.AccessKeyMetadata, error) { if name == "" { return nil, apierror.New(apierror.ErrBadRequest, "invalid input", nil) @@ -164,3 +203,22 @@ func (i *IAM) ListGroupsForUser(ctx context.Context, name string) ([]string, err return groups, nil } + +func (i *IAM) TagUser(ctx context.Context, name string, tags []*iam.Tag) error { + if name == "" || tags == nil { + return apierror.New(apierror.ErrBadRequest, "invalid input", nil) + } + + log.Infof("tagging user %s with tags %+v", name, tags) + + _, err := i.Service.TagUserWithContext(ctx, &iam.TagUserInput{ + UserName: aws.String(name), + Tags: tags, + }) + + if err != nil { + return ErrCode("failed to tag user", err) + } + + return nil +} diff --git a/iam/users_test.go b/iam/users_test.go index 7d22257..08c5626 100644 --- a/iam/users_test.go +++ b/iam/users_test.go @@ -5,9 +5,290 @@ import ( "reflect" "testing" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/iam" ) +var rootuser1 = &iam.User{ + Arn: aws.String("arn:aws:iam::0123456789:user/rootuser1"), + CreateDate: aws.Time(testTime), + Path: aws.String("/"), + UserId: aws.String("ABCDEFGROOTUSER1"), + UserName: aws.String("rootuser1"), +} + +var user1 = &iam.User{ + Arn: aws.String("arn:aws:iam::0123456789:user/path1/user1"), + CreateDate: aws.Time(testTime), + Path: aws.String("/path1/"), + UserId: aws.String("ABCDEFGUSER1"), + UserName: aws.String("user1"), +} + +var user2 = &iam.User{ + Arn: aws.String("arn:aws:iam::0123456789:user/path1/user2"), + CreateDate: aws.Time(testTime), + Path: aws.String("/path1/"), + UserId: aws.String("ABCDEFGUSER2"), + UserName: aws.String("user2"), +} + +var user3 = &iam.User{ + Arn: aws.String("arn:aws:iam::0123456789:user/path1/user3"), + CreateDate: aws.Time(testTime), + Path: aws.String("/path2/"), + UserId: aws.String("ABCDEFGUSER3"), + UserName: aws.String("user3"), +} + +var testUsers = []*iam.User{ + rootuser1, + user1, + user2, + user3, +} + +var testAccessKeys = map[string][]*iam.AccessKeyMetadata{ + "rootuser1": {}, + "user1": { + { + AccessKeyId: aws.String("USER1XXXXXXXXX01"), + CreateDate: aws.Time(testPastTime), + Status: aws.String("Active"), + UserName: aws.String("user1"), + }, + { + AccessKeyId: aws.String("USER1XXXXXXXXX02"), + CreateDate: aws.Time(testPastTime), + Status: aws.String("Inactive"), + UserName: aws.String("user1"), + }, + }, + "user2": { + { + AccessKeyId: aws.String("USER2XXXXXXXXX01"), + CreateDate: aws.Time(testPastTime), + Status: aws.String("Active"), + UserName: aws.String("user2"), + }, + { + AccessKeyId: aws.String("USER2XXXXXXXXX02"), + CreateDate: aws.Time(testPastTime), + Status: aws.String("Inactive"), + UserName: aws.String("user2"), + }, + }, + "user3": { + { + AccessKeyId: aws.String("USER3XXXXXXXXX01"), + CreateDate: aws.Time(testPastTime), + Status: aws.String("InActive"), + UserName: aws.String("user3"), + }, + }, +} + +var testUserGroups = map[string][]*iam.Group{ + "rootuser1": { + { + Arn: aws.String(""), + CreateDate: aws.Time(testTime), + GroupId: aws.String(""), + GroupName: aws.String("rootGroup1"), + Path: aws.String("/"), + }, + { + Arn: aws.String(""), + CreateDate: aws.Time(testTime), + GroupId: aws.String(""), + GroupName: aws.String("rootGroup2"), + Path: aws.String("/"), + }, + }, + "user1": { + { + Arn: aws.String(""), + CreateDate: aws.Time(testTime), + GroupId: aws.String(""), + GroupName: aws.String("userGroup1"), + Path: aws.String("/path1/"), + }, + }, + "user2": { + { + Arn: aws.String(""), + CreateDate: aws.Time(testTime), + GroupId: aws.String(""), + GroupName: aws.String("userGroup1"), + Path: aws.String("/path1/"), + }, + }, + "user3": { + { + Arn: aws.String(""), + CreateDate: aws.Time(testTime), + GroupId: aws.String(""), + GroupName: aws.String("userGroup3"), + Path: aws.String("/path2/"), + }, + }, +} + +func (m *mockIAMClient) ListUsersWithContext(ctx context.Context, input *iam.ListUsersInput, opts ...request.Option) (*iam.ListUsersOutput, error) { + if m.err != nil { + return nil, m.err + } + + var users []*iam.User + for _, u := range testUsers { + if aws.StringValue(input.PathPrefix) == aws.StringValue(u.Path) { + users = append(users, u) + } + } + + return &iam.ListUsersOutput{Users: users}, nil +} + +func (m *mockIAMClient) GetUserWithContext(ctx context.Context, input *iam.GetUserInput, opts ...request.Option) (*iam.GetUserOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, u := range testUsers { + if aws.StringValue(input.UserName) == aws.StringValue(u.UserName) { + return &iam.GetUserOutput{User: u}, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) ListAccessKeysWithContext(ctx context.Context, input *iam.ListAccessKeysInput, opts ...request.Option) (*iam.ListAccessKeysOutput, error) { + if m.err != nil { + return nil, m.err + } + + for userName, keys := range testAccessKeys { + if aws.StringValue(input.UserName) == userName { + return &iam.ListAccessKeysOutput{AccessKeyMetadata: keys}, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) CreateUserWithContext(ctx context.Context, input *iam.CreateUserInput, opts ...request.Option) (*iam.CreateUserOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, u := range testUsers { + iu := aws.StringValue(input.UserName) + ou := aws.StringValue(u.UserName) + ip := aws.StringValue(input.Path) + op := aws.StringValue(u.Path) + if (iu == ou) && (ip == op) { + return &iam.CreateUserOutput{ + User: &iam.User{ + Arn: u.Arn, + CreateDate: u.CreateDate, + Path: u.Path, + Tags: input.Tags, + UserId: u.UserId, + UserName: u.UserName, + }, + }, nil + } + } + + return &iam.CreateUserOutput{}, nil +} + +func (m *mockIAMClient) DeleteUserWithContext(ctx context.Context, input *iam.DeleteUserInput, opts ...request.Option) (*iam.DeleteUserOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, u := range testUsers { + if aws.StringValue(input.UserName) == aws.StringValue(u.UserName) { + return &iam.DeleteUserOutput{}, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) ListGroupsForUserWithContext(ctx context.Context, input *iam.ListGroupsForUserInput, opts ...request.Option) (*iam.ListGroupsForUserOutput, error) { + if m.err != nil { + return nil, m.err + } + + for userName, groups := range testUserGroups { + if aws.StringValue(input.UserName) == userName { + return &iam.ListGroupsForUserOutput{Groups: groups}, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) DeleteAccessKeyWithContext(ctx context.Context, input *iam.DeleteAccessKeyInput, opts ...request.Option) (*iam.DeleteAccessKeyOutput, error) { + if m.err != nil { + return nil, m.err + } + + for userName, keys := range testAccessKeys { + if aws.StringValue(input.UserName) == userName { + for _, k := range keys { + if aws.StringValue(k.AccessKeyId) == aws.StringValue(input.AccessKeyId) { + if aws.StringValue(k.Status) != "Inactive" { + return nil, awserr.New(iam.ErrCodeDeleteConflictException, "access key must be inactive", nil) + } + return &iam.DeleteAccessKeyOutput{}, nil + } + } + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) CreateAccessKeyWithContext(ctx context.Context, input *iam.CreateAccessKeyInput, opts ...request.Option) (*iam.CreateAccessKeyOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, u := range testUsers { + if aws.StringValue(input.UserName) == aws.StringValue(u.UserName) { + return &iam.CreateAccessKeyOutput{ + AccessKey: &iam.AccessKey{ + CreateDate: aws.Time(testTime), + UserName: u.UserName, + Status: aws.String("Active"), + }, + }, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + +func (m *mockIAMClient) TagUserWithContext(ctx context.Context, input *iam.TagUserInput, opts ...request.Option) (*iam.TagUserOutput, error) { + if m.err != nil { + return nil, m.err + } + + for _, u := range testUsers { + if aws.StringValue(input.UserName) == aws.StringValue(u.UserName) { + return &iam.TagUserOutput{}, nil + } + } + + return nil, awserr.New(iam.ErrCodeNoSuchEntityException, "Not Found", nil) +} + func TestIAM_ListUsers(t *testing.T) { type args struct { ctx context.Context @@ -20,7 +301,47 @@ func TestIAM_ListUsers(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty path", + args: args{ + ctx: context.TODO(), + path: "", + }, + want: []string{"rootuser1"}, + }, + { + name: "root path", + args: args{ + ctx: context.TODO(), + path: "/", + }, + want: []string{"rootuser1"}, + }, + { + name: "path1", + args: args{ + ctx: context.TODO(), + path: "/path1/", + }, + want: []string{"user1", "user2"}, + }, + { + name: "path2", + args: args{ + ctx: context.TODO(), + path: "/path2/", + }, + want: []string{"user3"}, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + path: "/", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -50,7 +371,70 @@ func TestIAM_GetUserWithPath(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty path and name", + args: args{ + ctx: context.TODO(), + path: "", + name: "", + }, + wantErr: true, + }, + { + name: "empty name", + args: args{ + ctx: context.TODO(), + path: "/", + name: "", + }, + wantErr: true, + }, + { + name: "empty path, root user", + args: args{ + ctx: context.TODO(), + path: "", + name: "rootuser1", + }, + want: rootuser1, + }, + { + name: "path1, user1", + args: args{ + ctx: context.TODO(), + path: "/path1/", + name: "user1", + }, + want: user1, + }, + { + name: "path1, rootuser1", + args: args{ + ctx: context.TODO(), + path: "/path1/", + name: "rootuser1", + }, + wantErr: true, + }, + { + name: "path2, user3", + args: args{ + ctx: context.TODO(), + path: "/path2/", + name: "user3", + }, + want: user3, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + path: "/", + name: "rootuser1", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -79,7 +463,63 @@ func TestIAM_ListAccessKeys(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty name", + args: args{ + ctx: context.TODO(), + name: "", + }, + wantErr: true, + }, + { + name: "rootuser", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + want: testAccessKeys["rootuser1"], + }, + { + name: "user1", + args: args{ + ctx: context.TODO(), + name: "user1", + }, + want: testAccessKeys["user1"], + }, + { + name: "user2", + args: args{ + ctx: context.TODO(), + name: "user2", + }, + want: testAccessKeys["user2"], + }, + { + name: "user3", + args: args{ + ctx: context.TODO(), + name: "user3", + }, + want: testAccessKeys["user3"], + }, + { + name: "unknown user", + args: args{ + ctx: context.TODO(), + name: "someotheruser", + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -110,7 +550,123 @@ func TestIAM_CreateUser(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty name, path and tags", + args: args{ + name: "", + path: "", + tags: nil, + }, + wantErr: true, + }, + { + name: "empty name", + args: args{ + name: "", + path: "/path1/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + wantErr: true, + }, + { + name: "empty path", + args: args{ + name: "rootuser1", + path: "", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + want: rootuser1, + }, + { + name: "empty tags", + args: args{ + name: "rootuser1", + path: "/", + tags: nil, + }, + want: rootuser1, + }, + { + name: "rootuser1 in /", + args: args{ + name: "rootuser1", + path: "/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + want: rootuser1, + }, + { + name: "user1 in /path1/", + args: args{ + name: "user1", + path: "/path1/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + want: user1, + }, + { + name: "user2 in /path1/", + args: args{ + name: "user2", + path: "/path1/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + want: user2, + }, + { + name: "user3 in /path2/", + args: args{ + name: "user3", + path: "/path2/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + want: user3, + }, + { + name: "aws error", + args: args{ + name: "rootuser1", + path: "/", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -120,8 +676,19 @@ func TestIAM_CreateUser(t *testing.T) { t.Errorf("IAM.CreateUser() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("IAM.CreateUser() = %v, want %v", got, tt.want) + + // apply the tags passed with the args to the output (struct not pointer) + var want *iam.User + if tt.want != nil { + w := *tt.want + if tt.args.tags != nil { + w.Tags = tt.args.tags + } + want = &w + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("IAM.CreateUser() = %v, want %v", got, want) } }) } @@ -138,7 +705,59 @@ func TestIAM_DeleteUser(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty name", + args: args{ + ctx: context.TODO(), + name: "", + }, + wantErr: true, + }, + { + name: "rootuser1", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + }, + { + name: "user1", + args: args{ + ctx: context.TODO(), + name: "user1", + }, + }, + { + name: "user2", + args: args{ + ctx: context.TODO(), + name: "user2", + }, + }, + { + name: "user3", + args: args{ + ctx: context.TODO(), + name: "user3", + }, + }, + { + name: "unknown user", + args: args{ + ctx: context.TODO(), + name: "otheruser", + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -185,7 +804,63 @@ func TestIAM_ListGroupsForUser(t *testing.T) { err error wantErr bool }{ - // TODO: Add test cases. + { + name: "empty name", + args: args{ + ctx: context.TODO(), + name: "", + }, + wantErr: true, + }, + { + name: "rootuser1", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + want: []string{"rootGroup1", "rootGroup2"}, + }, + { + name: "user1", + args: args{ + ctx: context.TODO(), + name: "user1", + }, + want: []string{"userGroup1"}, + }, + { + name: "user2", + args: args{ + ctx: context.TODO(), + name: "user2", + }, + want: []string{"userGroup1"}, + }, + { + name: "user3", + args: args{ + ctx: context.TODO(), + name: "user3", + }, + want: []string{"userGroup3"}, + }, + { + name: "unkown user", + args: args{ + ctx: context.TODO(), + name: "someotheruser", + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -201,3 +876,361 @@ func TestIAM_ListGroupsForUser(t *testing.T) { }) } } + +func TestIAM_CreateAccessKey(t *testing.T) { + type args struct { + ctx context.Context + name string + } + tests := []struct { + name string + args args + err error + want *iam.AccessKey + wantErr bool + }{ + { + name: "empty name", + args: args{ + ctx: context.TODO(), + name: "", + }, + wantErr: true, + }, + { + name: "rootuser1", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + want: &iam.AccessKey{ + CreateDate: aws.Time(testTime), + UserName: aws.String("rootuser1"), + Status: aws.String("Active"), + }, + }, + { + name: "user1", + args: args{ + ctx: context.TODO(), + name: "user1", + }, + want: &iam.AccessKey{ + CreateDate: aws.Time(testTime), + UserName: aws.String("user1"), + Status: aws.String("Active"), + }, + }, + { + name: "user2", + args: args{ + ctx: context.TODO(), + name: "user2", + }, + want: &iam.AccessKey{ + CreateDate: aws.Time(testTime), + UserName: aws.String("user2"), + Status: aws.String("Active"), + }, + }, + { + name: "user3", + args: args{ + ctx: context.TODO(), + name: "user3", + }, + want: &iam.AccessKey{ + CreateDate: aws.Time(testTime), + UserName: aws.String("user3"), + Status: aws.String("Active"), + }, + }, + { + name: "unknown user", + args: args{ + ctx: context.TODO(), + name: "someotheruser", + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &IAM{Service: newMockIAMClient(t, tt.err)} + got, err := i.CreateAccessKey(tt.args.ctx, tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("IAM.CreateAccessKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("IAM.CreateAccessKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIAM_DeleteAccessKey(t *testing.T) { + type args struct { + ctx context.Context + name string + keyId string + } + tests := []struct { + name string + args args + err error + wantErr bool + }{ + { + name: "empy name and key id", + args: args{ + ctx: context.TODO(), + name: "", + keyId: "", + }, + wantErr: true, + }, + { + name: "empy name", + args: args{ + ctx: context.TODO(), + name: "", + keyId: "USER1XXXXXXXXX01", + }, + wantErr: true, + }, + { + name: "empy key id", + args: args{ + ctx: context.TODO(), + name: "user1", + keyId: "", + }, + wantErr: true, + }, + { + name: "user1 active key USER1XXXXXXXXX01", + args: args{ + ctx: context.TODO(), + name: "user1", + keyId: "USER1XXXXXXXXX01", + }, + wantErr: true, + }, + { + name: "user1 inactive key USER1XXXXXXXXX02", + args: args{ + ctx: context.TODO(), + name: "user1", + keyId: "USER1XXXXXXXXX02", + }, + }, + { + name: "user2 active key USER2XXXXXXXXX01", + args: args{ + ctx: context.TODO(), + name: "user2", + keyId: "USER2XXXXXXXXX01", + }, + wantErr: true, + }, + { + name: "user2 inactive key USER2XXXXXXXXX02", + args: args{ + ctx: context.TODO(), + name: "user2", + keyId: "USER2XXXXXXXXX02", + }, + }, + { + name: "user3 active key USER3XXXXXXXXX01", + args: args{ + ctx: context.TODO(), + name: "user3", + keyId: "USER3XXXXXXXXX01", + }, + wantErr: true, + }, + { + name: "unknown user", + args: args{ + ctx: context.TODO(), + name: "someotheruser", + keyId: "USER1XXXXXXXXX01", + }, + wantErr: true, + }, + { + name: "unknown key", + args: args{ + ctx: context.TODO(), + name: "user1", + keyId: "xxxxxmissingxxxxx", + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "user1", + keyId: "USER1XXXXXXXXX02", + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &IAM{Service: newMockIAMClient(t, tt.err)} + if err := i.DeleteAccessKey(tt.args.ctx, tt.args.name, tt.args.keyId); (err != nil) != tt.wantErr { + t.Errorf("IAM.DeleteAccessKey() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestIAM_TagUser(t *testing.T) { + type args struct { + ctx context.Context + name string + tags []*iam.Tag + } + tests := []struct { + name string + args args + err error + wantErr bool + }{ + { + name: "empty name and tags", + args: args{ + ctx: context.TODO(), + name: "", + tags: nil, + }, + wantErr: true, + }, + { + name: "empty name", + args: args{ + ctx: context.TODO(), + name: "", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + wantErr: true, + }, + { + name: "empty tags", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + tags: nil, + }, + wantErr: true, + }, + { + name: "rootuser1", + args: args{ + ctx: context.TODO(), + name: "rootuser1", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + }, + { + name: "user1", + args: args{ + ctx: context.TODO(), + name: "user1", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + }, + { + name: "user2", + args: args{ + ctx: context.TODO(), + name: "user2", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + }, + { + name: "user3", + args: args{ + ctx: context.TODO(), + name: "user3", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + }, + { + name: "unknown user", + args: args{ + ctx: context.TODO(), + name: "someotheruser", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + wantErr: true, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + name: "user1", + tags: []*iam.Tag{ + { + Key: aws.String("foo"), + Value: aws.String("bar"), + }, + }, + }, + err: awserr.New(iam.ErrCodeLimitExceededException, "limit exceeded", nil), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &IAM{Service: newMockIAMClient(t, tt.err)} + if err := i.TagUser(tt.args.ctx, tt.args.name, tt.args.tags); (err != nil) != tt.wantErr { + t.Errorf("IAM.TagUser() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/k8s/Dockerfile b/k8s/Dockerfile deleted file mode 100644 index 68cdd59..0000000 --- a/k8s/Dockerfile +++ /dev/null @@ -1,26 +0,0 @@ -# build stage -FROM golang:alpine AS build-env - -ARG version="0.0.0" -ARG prerelease="" -ARG githash="" -ARG buildstamp="" - -RUN apk add --no-cache git openssh-client gcc musl-dev -RUN mkdir /app -WORKDIR /app -RUN go version -COPY go.mod . -COPY go.sum . -RUN go mod download -COPY . . -RUN go build -o /app/api.out -ldflags="-X main.Version=$version -X main.VersionPrerelease=$prerelease -X main.githash=$githash -X main.buildstamp=$buildstamp" *.go -RUN /app/api.out -version - -# final stage -FROM alpine -RUN apk add --no-cache ca-certificates -WORKDIR /app -COPY --from=build-env /app/api.out /app/api -EXPOSE 80 -ENTRYPOINT ["./api"] diff --git a/k8s/README.md b/k8s/README.md index 7705c8a..faf681f 100644 --- a/k8s/README.md +++ b/k8s/README.md @@ -1,6 +1,6 @@ # k8s development Readme -The application ships with a basic k8s config in the `k8s/` directory. There, you will find a `Dockerfile` an `api` helm chart and a `values.yaml` to deploy the pod, service and ingress. By default, `skaffold` will reference the configuration in `config/config.json`. +The application ships with a basic k8s config in the `k8s/` directory. There, you will find an `api` helm chart and a `values.yaml` to deploy the pod, service and ingress. By default, `skaffold` will use the [paketo buildpacks](https://paketo.io/) and will reference the configuration in `config/config.json`. ## install docker desktop and enable kubernetes diff --git a/k8s/api/templates/deployment.yaml b/k8s/api/templates/deployment.yaml index dda5117..e8d9114 100644 --- a/k8s/api/templates/deployment.yaml +++ b/k8s/api/templates/deployment.yaml @@ -24,10 +24,12 @@ spec: containers: - name: {{ .Chart.Name }} image: {{ .Values.image }} - volumeMounts: - - name: {{ include "api.fullname" . }}-config - mountPath: "/app/config" - readOnly: true + env: + - name: API_CONFIG + valueFrom: + secretKeyRef: + name: {{ include "api.fullname" . }}-config-json + key: config.json ports: - name: http containerPort: 8080 @@ -42,10 +44,6 @@ spec: port: http resources: {{- toYaml .Values.resources | nindent 12 }} - volumes: - - name: {{ include "api.fullname" . }}-config - secret: - secretName: {{ include "api.fullname" . }}-config-json {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/main.go b/main.go index 68786b5..aa9c065 100644 --- a/main.go +++ b/main.go @@ -17,9 +17,12 @@ along with this program. If not, see . package main import ( - "bufio" + "bytes" + "encoding/base64" "flag" "fmt" + "io" + "io/ioutil" "net/http" "os" @@ -33,9 +36,6 @@ var ( // Version is the main version number Version = "0.0.0" - // VersionPrerelease is a prerelease marker - VersionPrerelease = "" - // Buildstamp is the timestamp the binary was built, it should be set at buildtime with ldflags Buildstamp = "No BuildStamp Provided" @@ -56,24 +56,17 @@ func main() { if err != nil { log.Fatal("unable to get working directory") } - log.Infof("Starting ecr-api version %s%s (%s)", Version, VersionPrerelease, cwd) - - configFile, err := os.Open(*configFileName) - if err != nil { - log.Fatalln("Unable to open config file", err) - } + log.Infof("Starting ecr-api version %s (%s)", Version, cwd) - r := bufio.NewReader(configFile) - config, err := common.ReadConfig(r) + config, err := common.ReadConfig(configReader()) if err != nil { - log.Fatalf("Unable to read configuration from %s. %+v", *configFileName, err) + log.Fatalf("Unable to read configuration from: %+v", err) } config.Version = common.Version{ - Version: Version, - VersionPrerelease: VersionPrerelease, - BuildStamp: Buildstamp, - GitHash: Githash, + Version: Version, + BuildStamp: Buildstamp, + GitHash: Githash, } // Set the loglevel, info if it's unset @@ -92,14 +85,42 @@ func main() { log.Debug("Starting profiler on 127.0.0.1:6080") go http.ListenAndServe("127.0.0.1:6080", nil) } - log.Debugf("Read config: %+v", config) + log.Debugf("loaded configuration: %+v", config) if err := api.NewServer(config); err != nil { log.Fatal(err) } } +func configReader() io.Reader { + if configEnv := os.Getenv("API_CONFIG"); configEnv != "" { + log.Infof("reading configuration from API_CONFIG environment") + + c, err := base64.StdEncoding.DecodeString(configEnv) + if err != nil { + log.Infof("API_CONFIG is not base64 encoded") + c = []byte(configEnv) + } + + return bytes.NewReader(c) + } + + log.Infof("reading configuration from %s", *configFileName) + + configFile, err := os.Open(*configFileName) + if err != nil { + log.Fatalln("unable to open config file", err) + } + + c, err := ioutil.ReadAll(configFile) + if err != nil { + log.Fatalln("unable to read config file", err) + } + + return bytes.NewReader(c) +} + func vers() { - fmt.Printf("ecr-api Version: %s%s\n", Version, VersionPrerelease) + fmt.Printf("ecr-api Version: %s\n", Version) os.Exit(0) } diff --git a/resourcegroupstaggingapi/resourcegroupstaggingapi_test.go b/resourcegroupstaggingapi/resourcegroupstaggingapi_test.go index 80130c5..b2a1a40 100644 --- a/resourcegroupstaggingapi/resourcegroupstaggingapi_test.go +++ b/resourcegroupstaggingapi/resourcegroupstaggingapi_test.go @@ -5,10 +5,13 @@ import ( "reflect" "testing" + "github.com/YaleSpinup/apierror" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi" "github.com/aws/aws-sdk-go/service/resourcegroupstaggingapi/resourcegroupstaggingapiiface" + "github.com/pkg/errors" ) // mockResourceGroupsTaggingAPIClient is a fake resourcegroupstaggingapi client @@ -113,24 +116,19 @@ func (m *mockResourceGroupsTaggingAPIClient) GetResourcesWithContext(ctx context matches := true for _, filter := range input.TagFilters { innerMatch := func() bool { - m.t.Logf("processing tagfilter %+v", filter) for _, rt := range r.tags { if aws.StringValue(filter.Key) == rt.key { - m.t.Logf("tag keys match for %s (%s = %s)", r.arn, rt.key, aws.StringValue(filter.Key)) if len(filter.Values) == 0 { - m.t.Logf("appending %s to the list, keys match (%s = %s) and no value specified", r.arn, rt.key, aws.StringValue(filter.Key)) return true } for _, value := range aws.StringValueSlice(filter.Values) { if value == rt.value { - m.t.Logf("appending %s to the list, keys match (%s = %s) and value matches (%s = %s)", r.arn, rt.key, aws.StringValue(filter.Key), value, rt.value) return true } } } } - m.t.Logf("returning false for %s", r.arn) return false }() @@ -140,7 +138,6 @@ func (m *mockResourceGroupsTaggingAPIClient) GetResourcesWithContext(ctx context } if matches { - m.t.Logf("resource %s matches", r.arn) resourceList = append(resourceList, &resourcegroupstaggingapi.ResourceTagMapping{ ResourceARN: aws.String(r.arn), }) @@ -195,4 +192,43 @@ func TestGetResourcesWithTags(t *testing.T) { if !reflect.DeepEqual(expected, out) { t.Errorf("expected %+v, got %+v", expected, out) } + + if _, err := r.GetResourcesWithTags(context.TODO(), []string{}, nil); err != nil { + if aerr, ok := err.(apierror.Error); ok { + if aerr.Code != apierror.ErrBadRequest { + t.Errorf("expected error code %s, got: %s", apierror.ErrInternalError, aerr.Code) + } + } else { + t.Errorf("expected apierror.Error") + } + } else { + t.Error("expected error for empty filter list, got nil") + } + + r.Service.(*mockResourceGroupsTaggingAPIClient).err = awserr.New(resourcegroupstaggingapi.ErrCodeThrottledException, "throttled", nil) + if _, err := r.GetResourcesWithTags(context.TODO(), []string{}, filters); err != nil { + if aerr, ok := err.(apierror.Error); ok { + if aerr.Code != apierror.ErrConflict { + t.Errorf("expected error code %s, got: %s", apierror.ErrConflict, aerr.Code) + } + } else { + t.Errorf("expected apierror.Error") + } + } else { + t.Error("expected error for empty filter list, got nil") + } + + // test non-aws error + r.Service.(*mockResourceGroupsTaggingAPIClient).err = errors.New("things blowing up!") + if _, err := r.GetResourcesWithTags(context.TODO(), []string{}, filters); err != nil { + if aerr, ok := err.(apierror.Error); ok { + if aerr.Code != apierror.ErrInternalError { + t.Errorf("expected error code %s, got: %s", apierror.ErrInternalError, aerr.Code) + } + } else { + t.Errorf("expected apierror.Error") + } + } else { + t.Error("expected error for empty filter list, got nil") + } } diff --git a/skaffold.yaml b/skaffold.yaml index 402cc5a..4da2c0d 100644 --- a/skaffold.yaml +++ b/skaffold.yaml @@ -3,14 +3,14 @@ kind: Config build: artifacts: - image: hub.docker.com/yaleits/ecr-api - docker: - dockerfile: k8s/Dockerfile - buildArgs: - version: 0.0.0 - prerelease: dev + buildpacks: + builder: paketobuildpacks/builder:tiny local: useBuildkit: true deploy: + kubectl: + manifests: + - k8s/k8s-* helm: releases: - name: ecrapi diff --git a/sts/sts.go b/sts/sts.go index 6722770..6100932 100644 --- a/sts/sts.go +++ b/sts/sts.go @@ -53,7 +53,7 @@ func WithDefaultSessionDuration(t int64) STSOption { // AssumeRole assumes the passed role with the given input // NB: the combined size of the inlinePolicy and the policy within the policyArns passed is 2048 characters. func (s *STS) AssumeRole(ctx context.Context, input *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) { - if input.RoleArn == nil { + if input == nil || aws.StringValue(input.RoleArn) == "" { return nil, apierror.New(apierror.ErrBadRequest, "invalid input", nil) } diff --git a/sts/sts_test.go b/sts/sts_test.go index cb5049a..1f06f4a 100644 --- a/sts/sts_test.go +++ b/sts/sts_test.go @@ -1,12 +1,20 @@ package sts import ( + "context" "reflect" "testing" + "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" ) +var testTime = time.Now() + // mockSTSClient is a fake sts client type mockSTSClient struct { stsiface.STSAPI @@ -21,10 +29,99 @@ func newMockSTSClient(t *testing.T, err error) stsiface.STSAPI { } } -func TestNewSession(t *testing.T) { +var testAssumeRoleOutput = &sts.AssumeRoleOutput{ + AssumedRoleUser: &sts.AssumedRoleUser{ + Arn: aws.String("arn:aws:sts::0123456789:assumed-role/UnitTestXAManagementRole/spinup-unit-ecr-api-000000-11111-2222-3333-444444"), + AssumedRoleId: aws.String("AABBCCDDEEFFGGHHIIJJ12345:spinup-unit-ecr-api-000000-11111-2222-3333-444444"), + }, + Credentials: &sts.Credentials{ + AccessKeyId: aws.String(""), + Expiration: aws.Time(testTime), + SecretAccessKey: aws.String(""), + SessionToken: aws.String(""), + }, +} + +func (m *mockSTSClient) AssumeRoleWithContext(ctx context.Context, input *sts.AssumeRoleInput, opts ...request.Option) (*sts.AssumeRoleOutput, error) { + if m.err != nil { + return nil, m.err + } + + return testAssumeRoleOutput, nil +} + +func TestNew(t *testing.T) { client := New() to := reflect.TypeOf(client).String() if to != "sts.STS" { t.Errorf("expected type to be 'sts.STS', got %s", to) } } + +func TestSTS_AssumeRole(t *testing.T) { + type args struct { + ctx context.Context + input *sts.AssumeRoleInput + } + tests := []struct { + name string + args args + want *sts.AssumeRoleOutput + err error + wantErr bool + }{ + { + name: "nil input", + args: args{ + ctx: context.TODO(), + input: nil, + }, + wantErr: true, + }, + { + name: "empty role arn", + args: args{ + ctx: context.TODO(), + input: &sts.AssumeRoleInput{}, + }, + wantErr: true, + }, + { + name: "valid role arn", + args: args{ + ctx: context.TODO(), + input: &sts.AssumeRoleInput{ + RoleArn: aws.String("arn:aws:iam::516855177326:role/UnitTestXAManagementRole"), + }, + }, + want: testAssumeRoleOutput, + }, + { + name: "aws error", + args: args{ + ctx: context.TODO(), + input: &sts.AssumeRoleInput{ + RoleArn: aws.String("arn:aws:iam::516855177326:role/UnitTestXAManagementRole"), + }, + }, + err: awserr.New(sts.ErrCodeMalformedPolicyDocumentException, "bad policy yo", nil), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &STS{ + Service: newMockSTSClient(t, tt.err), + Org: "unit", + } + got, err := s.AssumeRole(tt.args.ctx, tt.args.input) + if (err != nil) != tt.wantErr { + t.Errorf("STS.AssumeRole() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("STS.AssumeRole() = %v, want %v", got, tt.want) + } + }) + } +}