diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index f13e1cf36c836..63a9ffa75ca95 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -24,6 +24,7 @@ import ( "strings" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" @@ -31,18 +32,34 @@ import ( "github.com/gravitational/trace" ) -// ConvertRequestFailureError converts `error` into AWS RequestFailure errors -// to trace errors. If the provided error is not an `RequestFailure` it returns -// the error without modifying it. +// ConvertRequestFailureError converts `err` into AWS errors to trace errors. +// If the provided error is not a [awserr.RequestFailure] it delegates +// error conversion to [ConvertRequestFailureErrorV2] for SDK v2 compatibility. +// Prefer using [ConvertRequestFailureErrorV2] directly for AWS SDK v2 client +// errors. func ConvertRequestFailureError(err error) error { var requestErr awserr.RequestFailure - if !errors.As(err, &requestErr) { - return err + if errors.As(err, &requestErr) { + return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr) } + return ConvertRequestFailureErrorV2(err) +} - return convertRequestFailureErrorFromStatusCode(requestErr.StatusCode(), requestErr) +// ConvertRequestFailureErrorV2 converts AWS SDK v2 errors to trace errors. +// If the provided error is not a [awshttp.ResponseError] it returns the error +// without modifying it. +func ConvertRequestFailureErrorV2(err error) error { + var re *awshttp.ResponseError + if errors.As(err, &re) { + return convertRequestFailureErrorFromStatusCode(re.HTTPStatusCode(), re.Err) + } + return err } +var ( + ecsClusterNotFoundException *ecstypes.ClusterNotFoundException +) + func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) error { switch statusCode { case http.StatusForbidden: @@ -57,6 +74,10 @@ func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) if strings.Contains(requestErr.Error(), redshiftserverless.ErrCodeAccessDeniedException) { return trace.AccessDenied(requestErr.Error()) } + + if strings.Contains(requestErr.Error(), ecsClusterNotFoundException.ErrorCode()) { + return trace.NotFound(requestErr.Error()) + } } return requestErr // Return unmodified. diff --git a/lib/cloud/aws/errors_test.go b/lib/cloud/aws/errors_test.go index 448c2ef9a6e24..7f0c3c26b0307 100644 --- a/lib/cloud/aws/errors_test.go +++ b/lib/cloud/aws/errors_test.go @@ -73,6 +73,30 @@ func TestConvertRequestFailureError(t *testing.T) { inputError: errors.New("not-aws-error"), wantUnmodified: true, }, + { + name: "v2 sdk error", + inputError: &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: &http.Response{ + StatusCode: http.StatusNotFound, + }}, + Err: trace.Errorf(""), + }, + }, + wantIsError: trace.IsNotFound, + }, + { + name: "v2 sdk error for ecs ClusterNotFoundException", + inputError: &awshttp.ResponseError{ + ResponseError: &smithyhttp.ResponseError{ + Response: &smithyhttp.Response{Response: &http.Response{ + StatusCode: http.StatusBadRequest, + }}, + Err: trace.Errorf("ClusterNotFoundException"), + }, + }, + wantIsError: trace.IsNotFound, + }, } for _, test := range tests { diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice.go b/lib/integrations/awsoidc/listdeployeddatabaseservice.go index c2894902f78fe..ad5bb9606faf4 100644 --- a/lib/integrations/awsoidc/listdeployeddatabaseservice.go +++ b/lib/integrations/awsoidc/listdeployeddatabaseservice.go @@ -27,6 +27,7 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/gravitational/trace" + awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" ) @@ -139,6 +140,11 @@ func ListDeployedDatabaseServices(ctx context.Context, clt ListDeployedDatabaseS listServicesOutput, err := clt.ListServices(ctx, listServicesInput) if err != nil { + convertedError := awslib.ConvertRequestFailureErrorV2(err) + if trace.IsNotFound(convertedError) { + return &ListDeployedDatabaseServicesResponse{}, nil + } + return nil, trace.Wrap(err) } diff --git a/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go index 67f332d495c2b..84b163d519465 100644 --- a/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go +++ b/lib/integrations/awsoidc/listdeployeddatabaseservice_test.go @@ -110,11 +110,11 @@ type mockListECSClient struct { } func (m *mockListECSClient) ListServices(ctx context.Context, params *ecs.ListServicesInput, optFns ...func(*ecs.Options)) (*ecs.ListServicesOutput, error) { - ret := &ecs.ListServicesOutput{} - if aws.ToString(params.Cluster) != m.clusterName { - return ret, nil + if aws.ToString(params.Cluster) != m.clusterName || len(m.services) == 0 { + return nil, trace.NotFound("ECS Cluster not found") } + ret := &ecs.ListServicesOutput{} requestedPage := 1 totalEndpoints := len(m.services) @@ -348,6 +348,25 @@ func TestListDeployedDatabaseServices(t *testing.T) { }, errCheck: require.NoError, }, + { + name: "returns empty list when the ECS Cluster does not exist", + req: ListDeployedDatabaseServicesRequest{ + Integration: "my-integration", + TeleportClusterName: "my-cluster", + Region: "us-east-1", + }, + mockClient: func() *mockListECSClient { + ret := &mockListECSClient{ + pageSize: 10, + } + return ret + }, + respCheck: func(t *testing.T, resp *ListDeployedDatabaseServicesResponse) { + require.Empty(t, resp.DeployedDatabaseServices, "expected 0 services") + require.Empty(t, resp.NextToken, "expected an empty NextToken") + }, + errCheck: require.NoError, + }, } { t.Run(tt.name, func(t *testing.T) { resp, err := ListDeployedDatabaseServices(ctx, tt.mockClient(), tt.req)