diff --git a/go.mod b/go.mod index 9b7f9a944bc36..6c4acb9d9a68b 100644 --- a/go.mod +++ b/go.mod @@ -42,12 +42,13 @@ require ( github.com/aquasecurity/libbpfgo v0.5.1-libbpf-1.2 github.com/armon/go-radix v1.0.0 github.com/aws/aws-sdk-go v1.55.5 - github.com/aws/aws-sdk-go-v2 v1.32.6 + github.com/aws/aws-sdk-go-v2 v1.32.7 github.com/aws/aws-sdk-go-v2/config v1.28.6 github.com/aws/aws-sdk-go-v2/credentials v1.17.47 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.20 github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.55 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.1 github.com/aws/aws-sdk-go-v2/service/athena v1.49.0 diff --git a/go.sum b/go.sum index 1795f10be5367..74ebf6bad1a27 100644 --- a/go.sum +++ b/go.sum @@ -849,8 +849,8 @@ github.com/aws/aws-sdk-go v1.49.12/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3Tj github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= -github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config v1.18.25/go.mod h1:dZnYpD5wTW/dQF0rRNLVypB396zWCcPiBIvdvSWHEg4= @@ -866,6 +866,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.55/go. github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3/go.mod h1:4Q0UFP0YJf0NrsEuEYHpM9fTSEVnD16Z3uyEF7J9JGM= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 h1:iLdpkYZ4cXIQMO7ud+cqMWR1xK5ESbt1rvN77tRi1BY= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43/go.mod h1:OgbsKPAswXDd5kxnR4vZov69p3oYjbvUyIRBAAV0y9o= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index 85aa9286df663..24e4a018523e6 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -62,7 +62,7 @@ require ( github.com/armon/go-radix v1.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/aws/aws-sdk-go v1.55.5 // indirect - github.com/aws/aws-sdk-go-v2 v1.32.6 // indirect + github.com/aws/aws-sdk-go-v2 v1.32.7 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect github.com/aws/aws-sdk-go-v2/config v1.28.6 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 55b2c7576dae4..639828a2ce2a4 100644 --- a/integrations/event-handler/go.sum +++ b/integrations/event-handler/go.sum @@ -727,8 +727,8 @@ github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 h1:DklsrG3d github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= -github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index ecc2f0d21e193..91407a3c15a90 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -75,7 +75,7 @@ require ( github.com/armon/go-radix v1.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/aws/aws-sdk-go v1.55.5 // indirect - github.com/aws/aws-sdk-go-v2 v1.32.6 // indirect + github.com/aws/aws-sdk-go-v2 v1.32.7 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect github.com/aws/aws-sdk-go-v2/config v1.28.6 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.47 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index ba2f9d7df83f7..22673f7d74d8f 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -784,8 +784,8 @@ github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3A github.com/aws/aws-sdk-go v1.25.3/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.32.6 h1:7BokKRgRPuGmKkFMhEg/jSul+tB9VvXhcViILtfG8b4= -github.com/aws/aws-sdk-go-v2 v1.32.6/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= +github.com/aws/aws-sdk-go-v2 v1.32.7 h1:ky5o35oENWi0JYWUZkB7WYvVPP+bcRF5/Iq7JWSb5Rw= +github.com/aws/aws-sdk-go-v2 v1.32.7/go.mod h1:P5WJBrYqqbWVaOxgH0X/FYYD47/nooaPOZPlQdmiN2U= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 h1:lL7IfaFzngfx0ZwUGOZdsFFnQ5uLvR0hWqqhyE7Q9M8= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7/go.mod h1:QraP0UcVlQJsmHfioCrveWOC1nbiWUl3ej08h4mXWoc= github.com/aws/aws-sdk-go-v2/config v1.28.6 h1:D89IKtGrs/I3QXOLNTH93NJYtDhm8SYa9Q5CsPShmyo= @@ -798,6 +798,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.55 h1: github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.55/go.mod h1:mJ7tAfWUIVja+y4kVGXr/SucTEErhYx5nAZ39WV7W6o= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21 h1:AmoU1pziydclFT/xRV+xXE/Vb8fttJCLRPv8oAkprc0= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.21/go.mod h1:AjUdLYe4Tgs6kpH4Bv7uMZo7pottoyHMn4eTcIcneaY= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43 h1:iLdpkYZ4cXIQMO7ud+cqMWR1xK5ESbt1rvN77tRi1BY= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.43/go.mod h1:OgbsKPAswXDd5kxnR4vZov69p3oYjbvUyIRBAAV0y9o= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.25 h1:s/fF4+yDQDoElYhfIVvSNyeCydfbuTKzhxSXDXCPasU= diff --git a/lib/cloud/aws/aws.go b/lib/cloud/aws/aws.go index 27ea56321b7df..fded866456e7f 100644 --- a/lib/cloud/aws/aws.go +++ b/lib/cloud/aws/aws.go @@ -22,12 +22,12 @@ import ( "slices" "strings" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/coreos/go-semver/semver" "github.com/gravitational/teleport/lib/services" @@ -74,18 +74,51 @@ func IsOpenSearchDomainAvailable(domain *opensearchservice.DomainStatus) bool { } // IsRDSProxyAvailable checks if the RDS Proxy is available. -func IsRDSProxyAvailable(dbProxy *rds.DBProxy) bool { - return IsResourceAvailable(dbProxy, dbProxy.Status) +func IsRDSProxyAvailable(dbProxy *rdstypes.DBProxy) bool { + switch dbProxy.Status { + case + rdstypes.DBProxyStatusAvailable, + rdstypes.DBProxyStatusModifying, + rdstypes.DBProxyStatusReactivating: + return true + case + rdstypes.DBProxyStatusCreating, + rdstypes.DBProxyStatusDeleting, + rdstypes.DBProxyStatusIncompatibleNetwork, + rdstypes.DBProxyStatusInsufficientResourceLimits, + rdstypes.DBProxyStatusSuspended, + rdstypes.DBProxyStatusSuspending: + return false + } + slog.WarnContext(context.Background(), "Assuming RDS Proxy with unknown status is available", + "status", dbProxy.Status, + ) + return true } // IsRDSProxyCustomEndpointAvailable checks if the RDS Proxy custom endpoint is available. -func IsRDSProxyCustomEndpointAvailable(customEndpoint *rds.DBProxyEndpoint) bool { - return IsResourceAvailable(customEndpoint, customEndpoint.Status) +func IsRDSProxyCustomEndpointAvailable(customEndpoint *rdstypes.DBProxyEndpoint) bool { + switch customEndpoint.Status { + case + rdstypes.DBProxyEndpointStatusAvailable, + rdstypes.DBProxyEndpointStatusModifying: + return true + case + rdstypes.DBProxyEndpointStatusCreating, + rdstypes.DBProxyEndpointStatusDeleting, + rdstypes.DBProxyEndpointStatusIncompatibleNetwork, + rdstypes.DBProxyEndpointStatusInsufficientResourceLimits: + return true + } + slog.WarnContext(context.Background(), "Assuming RDS Proxy custom endpoint with unknown status is available", + "status", customEndpoint.Status, + ) + return true } // IsRDSInstanceSupported returns true if database supports IAM authentication. // Currently, only MariaDB is being checked. -func IsRDSInstanceSupported(instance *rds.DBInstance) bool { +func IsRDSInstanceSupported(instance *rdstypes.DBInstance) bool { // TODO(jakule): Check other engines. if aws.StringValue(instance.Engine) != services.RDSEngineMariaDB { return true @@ -105,7 +138,7 @@ func IsRDSInstanceSupported(instance *rds.DBInstance) bool { } // IsRDSClusterSupported checks whether the Aurora cluster is supported. -func IsRDSClusterSupported(cluster *rds.DBCluster) bool { +func IsRDSClusterSupported(cluster *rdstypes.DBCluster) bool { switch aws.StringValue(cluster.EngineMode) { // Aurora Serverless v1 does NOT support IAM authentication. // https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-serverless.html#aurora-serverless.limitations @@ -129,7 +162,7 @@ func IsRDSClusterSupported(cluster *rds.DBCluster) bool { } // AuroraMySQLVersion extracts aurora mysql version from engine version -func AuroraMySQLVersion(cluster *rds.DBCluster) string { +func AuroraMySQLVersion(cluster *rdstypes.DBCluster) string { // version guide: https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Updates.Versions.html // a list of all the available versions: https://docs.aws.amazon.com/cli/latest/reference/rds/describe-db-engine-versions.html // @@ -154,7 +187,7 @@ func AuroraMySQLVersion(cluster *rds.DBCluster) string { // for this DocumentDB cluster. // // https://docs.aws.amazon.com/documentdb/latest/developerguide/iam-identity-auth.html -func IsDocumentDBClusterSupported(cluster *rds.DBCluster) bool { +func IsDocumentDBClusterSupported(cluster *rdstypes.DBCluster) bool { ver, err := semver.NewVersion(aws.StringValue(cluster.EngineVersion)) if err != nil { slog.ErrorContext(context.Background(), "Failed to parse DocumentDB engine version", "version", aws.StringValue(cluster.EngineVersion)) diff --git a/lib/cloud/aws/tags_helpers.go b/lib/cloud/aws/tags_helpers.go index 3e61bd6fc1a42..43f6ba48f61ca 100644 --- a/lib/cloud/aws/tags_helpers.go +++ b/lib/cloud/aws/tags_helpers.go @@ -24,14 +24,13 @@ import ( "slices" ec2TypesV2 "github.com/aws/aws-sdk-go-v2/service/ec2/types" - rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/secretsmanager" "golang.org/x/exp/maps" @@ -43,11 +42,10 @@ import ( type ResourceTag interface { // TODO Go generic does not allow access common fields yet. List all types // here and use a type switch for now. - rdsTypesV2.Tag | + rdstypes.Tag | ec2TypesV2.Tag | redshifttypes.Tag | *ec2.Tag | - *rds.Tag | *elasticache.Tag | *memorydb.Tag | *redshiftserverless.Tag | @@ -76,8 +74,6 @@ func TagsToLabels[Tag ResourceTag](tags []Tag) map[string]string { func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { switch v := any(tag).(type) { - case *rds.Tag: - return aws.StringValue(v.Key), aws.StringValue(v.Value) case *ec2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case *elasticache.Tag: @@ -86,7 +82,7 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { return aws.StringValue(v.Key), aws.StringValue(v.Value) case *redshiftserverless.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) - case rdsTypesV2.Tag: + case rdstypes.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case ec2TypesV2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) @@ -123,22 +119,3 @@ func LabelsToTags[T any, PT SettableTag[T]](labels map[string]string) (tags []*T } return } - -// LabelsToRDSV2Tags converts labels into [rdsTypesV2.Tag] list. -func LabelsToRDSV2Tags(labels map[string]string) []rdsTypesV2.Tag { - keys := maps.Keys(labels) - slices.Sort(keys) - - ret := make([]rdsTypesV2.Tag, 0, len(keys)) - for _, key := range keys { - key := key - value := labels[key] - - ret = append(ret, rdsTypesV2.Tag{ - Key: &key, - Value: &value, - }) - } - - return ret -} diff --git a/lib/cloud/aws/tags_helpers_test.go b/lib/cloud/aws/tags_helpers_test.go index 228c477a316cb..d014b7dd0999d 100644 --- a/lib/cloud/aws/tags_helpers_test.go +++ b/lib/cloud/aws/tags_helpers_test.go @@ -22,18 +22,20 @@ import ( "testing" rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/aws/aws-sdk-go/service/rds" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" ) func TestTagsToLabels(t *testing.T) { t.Parallel() t.Run("rds", func(t *testing.T) { - inputTags := []*rds.Tag{ + inputTags := []rdstypes.Tag{ { Key: aws.String("Env"), Value: aws.String("dev"), @@ -153,7 +155,7 @@ func TestLabelsToTags(t *testing.T) { }, } - actualTags := LabelsToRDSV2Tags(inputLabels) + actualTags := awstesthelpers.LabelsToRDSTags(inputLabels) require.EqualValues(t, expectTags, actualTags) }) } diff --git a/lib/cloud/awstesthelpers/tags.go b/lib/cloud/awstesthelpers/tags.go index 5e1f4aa0e0738..28bed6b973f0b 100644 --- a/lib/cloud/awstesthelpers/tags.go +++ b/lib/cloud/awstesthelpers/tags.go @@ -22,6 +22,7 @@ import ( "maps" "slices" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" ) @@ -43,3 +44,22 @@ func LabelsToRedshiftTags(labels map[string]string) []redshifttypes.Tag { return ret } + +// LabelsToRDSTags converts labels into a [rdstypes.Tag] list. +func LabelsToRDSTags(labels map[string]string) []rdstypes.Tag { + keys := slices.Collect(maps.Keys(labels)) + slices.Sort(keys) + + ret := make([]rdstypes.Tag, 0, len(keys)) + for _, key := range keys { + key := key + value := labels[key] + + ret = append(ret, rdstypes.Tag{ + Key: &key, + Value: &value, + }) + } + + return ret +} diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 99c2deb4001f0..80c4f9bc06ee0 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -51,8 +51,6 @@ import ( "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/aws/aws-sdk-go/service/s3" @@ -111,8 +109,6 @@ type GCPClients interface { type AWSClients interface { // GetAWSSession returns AWS session for the specified region and any role(s). GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error) - // GetAWSRDSClient returns AWS RDS client for the specified region. - GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. @@ -504,15 +500,6 @@ func (c *cloudClients) GetAWSSession(ctx context.Context, region string, opts .. return c.getAWSSessionForRole(ctx, region, options) } -// GetAWSRDSClient returns AWS RDS client for the specified region. -func (c *cloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return rds.New(session), nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. func (c *cloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1018,8 +1005,6 @@ var _ Clients = (*TestCloudClients)(nil) // TestCloudClients are used in tests. type TestCloudClients struct { - RDS rdsiface.RDSAPI - RDSPerRegion map[string]rdsiface.RDSAPI RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI ElastiCache elasticacheiface.ElastiCacheAPI OpenSearch opensearchserviceiface.OpenSearchServiceAPI @@ -1089,18 +1074,6 @@ func (c *TestCloudClients) getAWSSessionForRegion(region string) (*awssession.Se }) } -// GetAWSRDSClient returns AWS RDS client for the specified region. -func (c *TestCloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - if len(c.RDSPerRegion) != 0 { - return c.RDSPerRegion[region], nil - } - return c.RDS, nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. func (c *TestCloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go index b52dfbd36d74a..6d46be65d0191 100644 --- a/lib/cloud/mocks/aws_config.go +++ b/lib/cloud/mocks/aws_config.go @@ -29,11 +29,16 @@ import ( ) type AWSConfigProvider struct { + Err error STSClient *STSClient OIDCIntegrationClient awsconfig.OIDCIntegrationClient } func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { + if f.Err != nil { + return aws.Config{}, f.Err + } + stsClt := f.STSClient if stsClt == nil { stsClt = &STSClient{} diff --git a/lib/cloud/mocks/aws_rds.go b/lib/cloud/mocks/aws_rds.go index 50130d668f5c0..c7fc7331d17e5 100644 --- a/lib/cloud/mocks/aws_rds.go +++ b/lib/cloud/mocks/aws_rds.go @@ -19,159 +19,120 @@ package mocks import ( + "context" "fmt" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdsv2 "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/google/uuid" "github.com/gravitational/trace" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" ) -// RDSMock mocks AWS RDS API. -type RDSMock struct { - rdsiface.RDSAPI - DBInstances []*rds.DBInstance - DBClusters []*rds.DBCluster - DBProxies []*rds.DBProxy - DBProxyEndpoints []*rds.DBProxyEndpoint - DBEngineVersions []*rds.DBEngineVersion -} +type RDSClient struct { + Unauth bool -func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return nil, trace.Wrap(err) - } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) - if err != nil { - return nil, trace.Wrap(err) - } - if aws.StringValue(input.DBInstanceIdentifier) == "" { - return &rds.DescribeDBInstancesOutput{ - DBInstances: instances, - }, nil - } - for _, instance := range instances { - if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { - return &rds.DescribeDBInstancesOutput{ - DBInstances: []*rds.DBInstance{instance}, - }, nil - } - } - return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) + DBInstances []rdstypes.DBInstance + DBClusters []rdstypes.DBCluster + DBProxies []rdstypes.DBProxy + DBProxyEndpoints []rdstypes.DBProxyEndpoint + DBEngineVersions []rdstypes.DBEngineVersion } -func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) +func (c *RDSClient) DescribeDBClusters(_ context.Context, input *rdsv2.DescribeDBClustersInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBClustersOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) - if err != nil { - return trace.Wrap(err) - } - fn(&rds.DescribeDBInstancesOutput{ - DBInstances: instances, - }, true) - return nil -} -func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil { return nil, trace.Wrap(err) } - clusters, err := applyClusterFilters(m.DBClusters, input.Filters) + clusters, err := applyClusterFilters(c.DBClusters, input.Filters) if err != nil { return nil, trace.Wrap(err) } if aws.StringValue(input.DBClusterIdentifier) == "" { - return &rds.DescribeDBClustersOutput{ + return &rdsv2.DescribeDBClustersOutput{ DBClusters: clusters, }, nil } for _, cluster := range clusters { if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { - return &rds.DescribeDBClustersOutput{ - DBClusters: []*rds.DBCluster{cluster}, + return &rdsv2.DescribeDBClustersOutput{ + DBClusters: []rdstypes.DBCluster{cluster}, }, nil } } return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) } -func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) +func (c *RDSClient) DescribeDBInstances(_ context.Context, input *rdsv2.DescribeDBInstancesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBInstancesOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") } - clusters, err := applyClusterFilters(m.DBClusters, input.Filters) + + if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil { + return nil, trace.Wrap(err) + } + instances, err := applyInstanceFilters(c.DBInstances, input.Filters) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - fn(&rds.DescribeDBClustersOutput{ - DBClusters: clusters, - }, true) - return nil -} - -func (m *RDSMock) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - for i, instance := range m.DBInstances { + if aws.StringValue(input.DBInstanceIdentifier) == "" { + return &rdsv2.DescribeDBInstancesOutput{ + DBInstances: instances, + }, nil + } + for _, instance := range instances { if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { - if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) - } - return &rds.ModifyDBInstanceOutput{ - DBInstance: m.DBInstances[i], + return &rdsv2.DescribeDBInstancesOutput{ + DBInstances: []rdstypes.DBInstance{instance}, }, nil } } return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) } -func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - for i, cluster := range m.DBClusters { - if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { - if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) - } - return &rds.ModifyDBClusterOutput{ - DBCluster: m.DBClusters[i], - }, nil - } +func (c *RDSClient) DescribeDBProxies(_ context.Context, input *rdsv2.DescribeDBProxiesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxiesOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") } - return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) -} -func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { if aws.StringValue(input.DBProxyName) == "" { - return &rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, + return &rdsv2.DescribeDBProxiesOutput{ + DBProxies: c.DBProxies, }, nil } - for _, dbProxy := range m.DBProxies { + for _, dbProxy := range c.DBProxies { if aws.StringValue(dbProxy.DBProxyName) == aws.StringValue(input.DBProxyName) { - return &rds.DescribeDBProxiesOutput{ - DBProxies: []*rds.DBProxy{dbProxy}, + return &rdsv2.DescribeDBProxiesOutput{ + DBProxies: []rdstypes.DBProxy{dbProxy}, }, nil } } return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName)) } -func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { +func (c *RDSClient) DescribeDBProxyEndpoints(_ context.Context, input *rdsv2.DescribeDBProxyEndpointsInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxyEndpointsOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + inputProxyName := aws.StringValue(input.DBProxyName) inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName) if inputProxyName == "" && inputProxyEndpointName == "" { - return &rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, + return &rdsv2.DescribeDBProxyEndpointsOutput{ + DBProxyEndpoints: c.DBProxyEndpoints, }, nil } - var endpoints []*rds.DBProxyEndpoint - for _, dbProxyEndpoiont := range m.DBProxyEndpoints { + var endpoints []rdstypes.DBProxyEndpoint + for _, dbProxyEndpoiont := range c.DBProxyEndpoints { if inputProxyEndpointName != "" && inputProxyEndpointName != aws.StringValue(dbProxyEndpoiont.DBProxyEndpointName) { continue @@ -187,114 +148,51 @@ func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rd if len(endpoints) == 0 { return nil, trace.NotFound("proxy endpoint %v not found", aws.StringValue(input.DBProxyEndpointName)) } - return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil -} - -func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, - }, true) - return nil -} - -func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, - }, true) - return nil -} - -func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) { - return &rds.ListTagsForResourceOutput{}, nil -} - -// RDSMockUnauth is a mock RDS client that returns access denied to each call. -type RDSMockUnauth struct { - rdsiface.RDSAPI -} - -func (m *RDSMockUnauth) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") + return &rdsv2.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil } -func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -// RDSMockByDBType is a mock RDS client that mocks API calls by DB type -type RDSMockByDBType struct { - rdsiface.RDSAPI - DBInstances rdsiface.RDSAPI - DBClusters rdsiface.RDSAPI - DBProxies rdsiface.RDSAPI -} - -func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...) -} - -func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...) -} +func (c *RDSClient) ModifyDBCluster(ctx context.Context, input *rdsv2.ModifyDBClusterInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBClusterOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } -func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...) + for i, cluster := range c.DBClusters { + if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { + if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { + c.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) + } + return &rdsv2.ModifyDBClusterOutput{ + DBCluster: &c.DBClusters[i], + }, nil + } + } + return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) } -func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...) -} +func (c *RDSClient) ModifyDBInstance(ctx context.Context, input *rdsv2.ModifyDBInstanceInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBInstanceOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } -func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...) + for i, instance := range c.DBInstances { + if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { + if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { + c.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) + } + return &rdsv2.ModifyDBInstanceOutput{ + DBInstance: &c.DBInstances[i], + }, nil + } + } + return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) } -func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...) +func (c *RDSClient) ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) { + return &rds.ListTagsForResourceOutput{}, nil } // checkEngineFilters checks RDS filters to detect unrecognized engine filters. -func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error { +func checkEngineFilters(filters []rdstypes.Filter, engineVersions []rdstypes.DBEngineVersion) error { if len(filters) == 0 { return nil } @@ -307,8 +205,8 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer continue } for _, v := range f.Values { - if _, ok := recognizedEngines[aws.StringValue(v)]; !ok { - return trace.Errorf("unrecognized engine name %q", aws.StringValue(v)) + if _, ok := recognizedEngines[v]; !ok { + return trace.Errorf("unrecognized engine name %q", v) } } } @@ -316,11 +214,11 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer } // applyInstanceFilters filters RDS DBInstances using the provided RDS filters. -func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) { +func applyInstanceFilters(in []rdstypes.DBInstance, filters []rdstypes.Filter) ([]rdstypes.DBInstance, error) { if len(filters) == 0 { return in, nil } - var out []*rds.DBInstance + var out []rdstypes.DBInstance efs := engineFilterSet(filters) clusterIDs := clusterIdentifierFilterSet(filters) for _, instance := range in { @@ -336,11 +234,11 @@ func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.D } // applyClusterFilters filters RDS DBClusters using the provided RDS filters. -func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) { +func applyClusterFilters(in []rdstypes.DBCluster, filters []rdstypes.Filter) ([]rdstypes.DBCluster, error) { if len(filters) == 0 { return in, nil } - var out []*rds.DBCluster + var out []rdstypes.DBCluster efs := engineFilterSet(filters) for _, cluster := range in { if clusterEngineMatches(cluster, efs) { @@ -351,59 +249,59 @@ func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBC } // engineFilterSet builds a string set of engine names from a list of RDS filters. -func engineFilterSet(filters []*rds.Filter) map[string]struct{} { +func engineFilterSet(filters []rdstypes.Filter) map[string]struct{} { return filterValues(filters, "engine") } // clusterIdentifierFilterSet builds a string set of ClusterIDs from a list of RDS filters. -func clusterIdentifierFilterSet(filters []*rds.Filter) map[string]struct{} { +func clusterIdentifierFilterSet(filters []rdstypes.Filter) map[string]struct{} { return filterValues(filters, "db-cluster-id") } -func filterValues(filters []*rds.Filter, filterKey string) map[string]struct{} { +func filterValues(filters []rdstypes.Filter, filterKey string) map[string]struct{} { out := make(map[string]struct{}) for _, f := range filters { if aws.StringValue(f.Name) != filterKey { continue } for _, v := range f.Values { - out[aws.StringValue(v)] = struct{}{} + out[v] = struct{}{} } } return out } // instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set. -func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { +func instanceEngineMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(instance.Engine)] return ok } // instanceClusterIDMatches returns whether an RDS DBInstance ClusterID matches any ClusterID in a filter set. -func instanceClusterIDMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { +func instanceClusterIDMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(instance.DBClusterIdentifier)] return ok } // clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set. -func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool { +func clusterEngineMatches(cluster rdstypes.DBCluster, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(cluster.Engine)] return ok } -// RDSInstance returns a sample rds.DBInstance. -func RDSInstance(name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) *rds.DBInstance { - instance := &rds.DBInstance{ +// RDSInstance returns a sample rdstypes.DBInstance. +func RDSInstance(name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) *rdstypes.DBInstance { + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String("postgres"), DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.rds.amazonaws.com", name, region)), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), + TagList: awstesthelpers.LabelsToRDSTags(labels), } for _, opt := range opts { opt(instance) @@ -411,9 +309,9 @@ func RDSInstance(name, region string, labels map[string]string, opts ...func(*rd return instance } -// RDSCluster returns a sample rds.DBCluster. -func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster { - cluster := &rds.DBCluster{ +// RDSCluster returns a sample rdstypes.DBCluster. +func RDSCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster { + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), DBClusterIdentifier: aws.String(name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -422,9 +320,9 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds Status: aws.String("available"), Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.rds.amazonaws.com", name, region)), ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.rds.amazonaws.com", name, region)), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ + Port: aws.Int32(3306), + TagList: awstesthelpers.LabelsToRDSTags(labels), + DBClusterMembers: []rdstypes.DBClusterMember{{ IsClusterWriter: aws.Bool(true), // One writer by default. }}, } @@ -434,49 +332,49 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds return cluster } -func WithRDSClusterReader(cluster *rds.DBCluster) { - cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{ +func WithRDSClusterReader(cluster *rdstypes.DBCluster) { + cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{ IsClusterWriter: aws.Bool(false), // Add reader. }) } -func WithRDSClusterCustomEndpoint(name string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func WithRDSClusterCustomEndpoint(name string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { parsed, _ := arn.Parse(aws.StringValue(cluster.DBClusterArn)) - cluster.CustomEndpoints = append(cluster.CustomEndpoints, aws.String( + cluster.CustomEndpoints = append(cluster.CustomEndpoints, fmt.Sprintf("%v.cluster-custom-aabbccdd.%v.rds.amazonaws.com", name, parsed.Region), - )) + ) } } -// RDSProxy returns a sample rds.DBProxy. -func RDSProxy(name, region, vpcID string) *rds.DBProxy { - return &rds.DBProxy{ +// RDSProxy returns a sample rdstypes.DBProxy. +func RDSProxy(name, region, vpcID string) *rdstypes.DBProxy { + return &rdstypes.DBProxy{ DBProxyArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:123456789012:db-proxy:prx-%s", region, name)), DBProxyName: aws.String(name), - EngineFamily: aws.String(rds.EngineFamilyMysql), + EngineFamily: aws.String(string(rdstypes.EngineFamilyMysql)), Endpoint: aws.String(fmt.Sprintf("%s.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), VpcId: aws.String(vpcID), RequireTLS: aws.Bool(true), - Status: aws.String("available"), + Status: "available", } } -// RDSProxyCustomEndpoint returns a sample rds.DBProxyEndpoint. -func RDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, name, region string) *rds.DBProxyEndpoint { - return &rds.DBProxyEndpoint{ +// RDSProxyCustomEndpoint returns a sample rdstypes.DBProxyEndpoint. +func RDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, name, region string) *rdstypes.DBProxyEndpoint { + return &rdstypes.DBProxyEndpoint{ Endpoint: aws.String(fmt.Sprintf("%s.endpoint.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), DBProxyEndpointName: aws.String(name), DBProxyName: rdsProxy.DBProxyName, DBProxyEndpointArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db-proxy-endpoint:prx-endpoint-%v", region, name)), - TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), - Status: aws.String("available"), + TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly, + Status: "available", } } -// DocumentDBCluster returns a sample rds.DBCluster for DocumentDB. -func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster { - cluster := &rds.DBCluster{ +// DocumentDBCluster returns a sample rdstypes.DBCluster for DocumentDB. +func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster { + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), DBClusterIdentifier: aws.String(name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -485,9 +383,9 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu Status: aws.String("available"), Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.docdb.amazonaws.com", name, region)), ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.docdb.amazonaws.com", name, region)), - Port: aws.Int64(27017), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ + Port: aws.Int32(27017), + TagList: awstesthelpers.LabelsToRDSTags(labels), + DBClusterMembers: []rdstypes.DBClusterMember{{ IsClusterWriter: aws.Bool(true), // One writer by default. }}, } @@ -497,6 +395,6 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu return cluster } -func WithDocumentDBClusterReader(cluster *rds.DBCluster) { +func WithDocumentDBClusterReader(cluster *rdstypes.DBCluster) { WithRDSClusterReader(cluster) } diff --git a/lib/cloud/mocks/aws_sts.go b/lib/cloud/mocks/aws_sts.go index 178a1259669a4..4c55712bc82fb 100644 --- a/lib/cloud/mocks/aws_sts.go +++ b/lib/cloud/mocks/aws_sts.go @@ -67,7 +67,7 @@ func (m *STSClient) AssumeRoleWithWebIdentity(ctx context.Context, in *sts.Assum }, nil } -func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { +func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { // Retrieve credentials if we have a credential provider, so that all // assume-role providers in a role chain are triggered to call AssumeRole. if m.credentialProvider != nil { @@ -89,6 +89,12 @@ func (m *STSClient) AssumeRole(ctx context.Context, in *sts.AssumeRoleInput, _ . }, nil } +func (m *STSClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return &sts.GetCallerIdentityOutput{ + Arn: aws.String(m.ARN), + }, nil +} + // record is a helper function that records the role ARN and external ID for an // assumed role. // It delegates to the configured recordFn, if it has one, so that all assumed diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 46c6ca1a19f53..906acfd06c7cd 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -2491,7 +2491,6 @@ func (p *agentParams) setDefaults(c *testContext) { if p.CloudClients == nil { p.CloudClients = &clients.TestCloudClients{ STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{}, RedshiftServerless: &mocks.RedshiftServerlessMock{}, ElastiCache: p.ElastiCache, MemoryDB: p.MemoryDB, @@ -2501,7 +2500,7 @@ func (p *agentParams) setDefaults(c *testContext) { } } if p.AWSConfigProvider == nil { - p.AWSConfigProvider = &mocks.AWSConfigProvider{} + p.AWSConfigProvider = &mocks.AWSConfigProvider{Err: trace.AccessDenied("AWS SDK clients are disabled for tests by default")} } if p.DiscoveryResourceChecker == nil { diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go index 8222599c318a7..cd19dafee7496 100644 --- a/lib/srv/db/cloud/aws.go +++ b/lib/srv/db/cloud/aws.go @@ -23,21 +23,24 @@ import ( "encoding/json" "log/slog" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/rds" "github.com/gravitational/trace" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" ) // awsConfig is the config for the client that configures IAM for AWS databases. type awsConfig struct { + // awsConfigProvider provides [aws.Config] for AWS SDK service clients. + awsConfigProvider awsconfig.Provider // clients is an interface for creating AWS clients. clients cloud.Clients // identity is AWS identity this database agent is running as. @@ -46,6 +49,9 @@ type awsConfig struct { database types.Database // policyName is the name of the inline policy for the identity. policyName string + // rdsClientProviderFn is an internal-only [rdsClient] provider + // func that is only set in tests. + rdsClientProviderFn rdsClientProviderFunc } // Check validates the config. @@ -62,6 +68,12 @@ func (c *awsConfig) Check() error { if c.policyName == "" { return trace.BadParameter("missing parameter policy name") } + if c.awsConfigProvider == nil { + return trace.BadParameter("missing parameter awsConfigProvider") + } + if c.rdsClientProviderFn == nil { + return trace.BadParameter("missing parameter rdsClientProviderFn") + } return nil } @@ -75,7 +87,7 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) { teleport.ComponentKey, "aws", "db", config.database.GetName(), ) - dbConfigurator, err := getDBConfigurator(logger, config.clients, config.database) + dbConfigurator, err := getDBConfigurator(logger, config) if err != nil { return nil, trace.Wrap(err) } @@ -102,10 +114,14 @@ type dbIAMAuthConfigurator interface { } // getDBConfigurator returns a database IAM Auth configurator. -func getDBConfigurator(logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) { - if db.IsRDS() { +func getDBConfigurator(logger *slog.Logger, cfg awsConfig) (dbIAMAuthConfigurator, error) { + if cfg.database.IsRDS() { // Only setting for RDS instances and Aurora clusters. - return &rdsDBConfigurator{clients: clients, logger: logger}, nil + return &rdsDBConfigurator{ + awsConfigProvider: cfg.awsConfigProvider, + logger: logger, + rdsClientProviderFn: cfg.rdsClientProviderFn, + }, nil } // IAM Auth for Redshift, ElastiCache, and RDS Proxy is always enabled. return &nopDBConfigurator{}, nil @@ -303,8 +319,9 @@ func (r *awsClient) detachIAMPolicy(ctx context.Context) error { } type rdsDBConfigurator struct { - clients cloud.Clients - logger *slog.Logger + awsConfigProvider awsconfig.Provider + logger *slog.Logger + rdsClientProviderFn rdsClientProviderFunc } // ensureIAMAuth enables RDS instance IAM auth if it isn't already enabled. @@ -323,15 +340,16 @@ func (r *rdsDBConfigurator) ensureIAMAuth(ctx context.Context, db types.Database func (r *rdsDBConfigurator) enableIAMAuth(ctx context.Context, db types.Database) error { r.logger.DebugContext(ctx, "Enabling IAM auth for RDS") meta := db.GetAWS() - rdsClt, err := r.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := r.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := r.rdsClientProviderFn(awsCfg) if meta.RDS.ClusterID != "" { - _, err = rdsClt.ModifyDBClusterWithContext(ctx, &rds.ModifyDBClusterInput{ + _, err = clt.ModifyDBCluster(ctx, &rds.ModifyDBClusterInput{ DBClusterIdentifier: aws.String(meta.RDS.ClusterID), EnableIAMDatabaseAuthentication: aws.Bool(true), ApplyImmediately: aws.Bool(true), @@ -339,7 +357,7 @@ func (r *rdsDBConfigurator) enableIAMAuth(ctx context.Context, db types.Database return awslib.ConvertIAMError(err) } if meta.RDS.InstanceID != "" { - _, err = rdsClt.ModifyDBInstanceWithContext(ctx, &rds.ModifyDBInstanceInput{ + _, err = clt.ModifyDBInstance(ctx, &rds.ModifyDBInstanceInput{ DBInstanceIdentifier: aws.String(meta.RDS.InstanceID), EnableIAMDatabaseAuthentication: aws.Bool(true), ApplyImmediately: aws.Bool(true), diff --git a/lib/srv/db/cloud/iam.go b/lib/srv/db/cloud/iam.go index aa1629157d78f..1165e3f589058 100644 --- a/lib/srv/db/cloud/iam.go +++ b/lib/srv/db/cloud/iam.go @@ -25,6 +25,8 @@ import ( "sync" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -35,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common/iam" ) @@ -45,6 +48,8 @@ type IAMConfig struct { Clock clockwork.Clock // AccessPoint is a caching client connected to the Auth Server. AccessPoint authclient.DatabaseAccessPoint + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // Clients is an interface for retrieving cloud clients. Clients cloud.Clients // HostID is the host identified where this agent is running. @@ -52,6 +57,9 @@ type IAMConfig struct { HostID string // onProcessedTask is called after a task is processed. onProcessedTask func(processedTask iamTask, processError error) + // rdsClientProviderFn is an internal-only [rdsClient] provider + // func that is only set in tests. + rdsClientProviderFn rdsClientProviderFunc } // Check validates the IAM configurator config. @@ -62,6 +70,9 @@ func (c *IAMConfig) Check() error { if c.AccessPoint == nil { return trace.BadParameter("missing AccessPoint") } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } if c.Clients == nil { cloudClients, err := cloud.NewClients() if err != nil { @@ -72,6 +83,11 @@ func (c *IAMConfig) Check() error { if c.HostID == "" { return trace.BadParameter("missing HostID") } + if c.rdsClientProviderFn == nil { + c.rdsClientProviderFn = func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) + } + } return nil } @@ -233,10 +249,12 @@ func (c *IAM) getAWSConfigurator(ctx context.Context, database types.Database) ( return nil, trace.Wrap(err) } return newAWS(ctx, awsConfig{ - clients: c.cfg.Clients, - policyName: policyName, - identity: identity, - database: database, + awsConfigProvider: c.cfg.AWSConfigProvider, + clients: c.cfg.Clients, + database: database, + identity: identity, + policyName: policyName, + rdsClientProviderFn: c.cfg.rdsClientProviderFn, }) } diff --git a/lib/srv/db/cloud/iam_test.go b/lib/srv/db/cloud/iam_test.go index d13d1fc74b86c..6ad6b4dd71f09 100644 --- a/lib/srv/db/cloud/iam_test.go +++ b/lib/srv/db/cloud/iam_test.go @@ -24,10 +24,10 @@ import ( "testing" "time" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/rds" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -46,26 +46,28 @@ func TestAWSIAM(t *testing.T) { t.Cleanup(cancel) // Setup AWS database objects. - rdsInstance := &rds.DBInstance{ + rdsInstance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"), DBInstanceIdentifier: aws.String("postgres-rds"), DbiResourceId: aws.String("db-xyz"), } - auroraCluster := &rds.DBCluster{ + auroraCluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"), DBClusterIdentifier: aws.String("postgres-aurora"), DbClusterResourceId: aws.String("cluster-xyz"), } // Configure mocks. - stsClient := &mocks.STSClientV1{ - ARN: "arn:aws:iam::123456789012:role/test-role", + stsClient := &mocks.STSClient{ + STSClientV1: mocks.STSClientV1{ + ARN: "arn:aws:iam::123456789012:role/test-role", + }, } - rdsClient := &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance}, - DBClusters: []*rds.DBCluster{auroraCluster}, + clt := &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster}, } iamClient := &mocks.IAMMock{} @@ -152,15 +154,18 @@ func TestAWSIAM(t *testing.T) { } configurator, err := NewIAM(ctx, IAMConfig{ AccessPoint: &mockAccessPoint{}, + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: stsClient, + }, Clients: &clients.TestCloudClients{ - RDS: rdsClient, - STS: stsClient, + STS: &stsClient.STSClientV1, IAM: iamClient, }, HostID: "host-id", onProcessedTask: func(iamTask, error) { taskChan <- struct{}{} }, + rdsClientProviderFn: newFakeRDSClientProvider(clt), }) require.NoError(t, err) require.NoError(t, configurator.Start(ctx)) @@ -177,6 +182,7 @@ func TestAWSIAM(t *testing.T) { database: rdsDatabase, wantPolicyContains: rdsDatabase.GetAWS().RDS.ResourceID, getIAMAuthEnabled: func() bool { + rdsInstance := &clt.DBInstances[0] out := aws.BoolValue(rdsInstance.IAMDatabaseAuthenticationEnabled) // reset it rdsInstance.IAMDatabaseAuthenticationEnabled = aws.Bool(false) @@ -187,6 +193,7 @@ func TestAWSIAM(t *testing.T) { database: auroraDatabase, wantPolicyContains: auroraDatabase.GetAWS().RDS.ResourceID, getIAMAuthEnabled: func() bool { + auroraCluster := &clt.DBClusters[0] out := aws.BoolValue(auroraCluster.IAMDatabaseAuthenticationEnabled) // reset it auroraCluster.IAMDatabaseAuthenticationEnabled = aws.Bool(false) @@ -283,14 +290,20 @@ func TestAWSIAMNoPermissions(t *testing.T) { t.Cleanup(cancel) // Create unauthorized mocks for AWS services. - stsClient := &mocks.STSClientV1{ - ARN: "arn:aws:iam::123456789012:role/test-role", + stsClient := &mocks.STSClient{ + STSClientV1: mocks.STSClientV1{ + ARN: "arn:aws:iam::123456789012:role/test-role", + }, } // Make configurator. configurator, err := NewIAM(ctx, IAMConfig{ AccessPoint: &mockAccessPoint{}, Clients: &clients.TestCloudClients{}, // placeholder, HostID: "host-id", + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: stsClient, + }, + rdsClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{Unauth: true}), }) require.NoError(t, err) @@ -303,33 +316,30 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "RDS database", meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{InstanceID: "postgres-rds", ResourceID: "postgres-rds-resource-id"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, { name: "Aurora cluster", meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{ClusterID: "postgres-aurora", ResourceID: "postgres-aurora-resource-id"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, { name: "RDS database missing metadata", meta: types.AWS{Region: "localhost", RDS: types.RDS{ClusterID: "postgres-aurora"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, { @@ -339,7 +349,7 @@ func TestAWSIAMNoPermissions(t *testing.T) { IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, { @@ -352,7 +362,7 @@ func TestAWSIAMNoPermissions(t *testing.T) { IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, { @@ -362,7 +372,7 @@ func TestAWSIAMNoPermissions(t *testing.T) { IAM: &mocks.IAMErrorMock{ Error: awserr.New(iam.ErrCodeUnmodifiableEntityException, "unauthorized", fmt.Errorf("unauthorized")), }, - STS: stsClient, + STS: &stsClient.STSClientV1, }, }, } diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 031f9fb9dae4c..1eb163d889c99 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -24,14 +24,14 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" @@ -53,6 +53,19 @@ type redshiftClient interface { // redshiftClientProviderFunc provides a [redshiftClient]. type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient +// rdsClient defines a subset of the AWS RDS client API. +type rdsClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient + rds.DescribeDBProxiesAPIClient + rds.DescribeDBProxyEndpointsAPIClient + ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) + ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) +} + +// rdsClientProviderFunc provides a [rdsClient]. +type rdsClientProviderFunc func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient + // MetadataConfig is the cloud metadata service config. type MetadataConfig struct { // Clients is an interface for retrieving cloud clients. @@ -63,6 +76,9 @@ type MetadataConfig struct { // redshiftClientProviderFn is an internal-only [redshiftClient] provider // func that is only set in tests. redshiftClientProviderFn redshiftClientProviderFunc + // rdsClientProviderFn is an internal-only [rdsClient] provider + // func that is only set in tests. + rdsClientProviderFn rdsClientProviderFunc } // Check validates the metadata service config. @@ -83,6 +99,11 @@ func (c *MetadataConfig) Check() error { return redshift.NewFromConfig(cfg, optFns...) } } + if c.rdsClientProviderFn == nil { + c.rdsClientProviderFn = func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) + } + } return nil } @@ -147,20 +168,21 @@ func (m *Metadata) updateAWS(ctx context.Context, database types.Database, fetch // fetchRDSMetadata fetches metadata for the provided RDS or Aurora database. func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + clt := m.cfg.rdsClientProviderFn(awsCfg) if meta.RDS.ClusterID != "" { - return fetchRDSClusterMetadata(ctx, rds, meta.RDS.ClusterID) + return fetchRDSClusterMetadata(ctx, clt, meta.RDS.ClusterID) } // Try to fetch the RDS instance fetchedMeta. - fetchedMeta, err := fetchRDSInstanceMetadata(ctx, rds, meta.RDS.InstanceID) + fetchedMeta, err := fetchRDSInstanceMetadata(ctx, clt, meta.RDS.InstanceID) if err != nil && !trace.IsNotFound(err) && !trace.IsAccessDenied(err) { return nil, trace.Wrap(err) } @@ -172,11 +194,11 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database if clusterID == "" { clusterID = meta.RDS.InstanceID } - return fetchRDSClusterMetadata(ctx, rds, clusterID) + return fetchRDSClusterMetadata(ctx, clt, clusterID) } // If instance was found, it may be a part of an Aurora cluster. if fetchedMeta.RDS.ClusterID != "" { - return fetchRDSClusterMetadata(ctx, rds, fetchedMeta.RDS.ClusterID) + return fetchRDSClusterMetadata(ctx, clt, fetchedMeta.RDS.ClusterID) } return fetchedMeta, nil } @@ -184,18 +206,19 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database // fetchRDSProxyMetadata fetches metadata for the provided RDS Proxy database. func (m *Metadata) fetchRDSProxyMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + clt := m.cfg.rdsClientProviderFn(awsCfg) if meta.RDSProxy.CustomEndpointName != "" { - return fetchRDSProxyCustomEndpointMetadata(ctx, rds, meta.RDSProxy.CustomEndpointName, database.GetURI()) + return fetchRDSProxyCustomEndpointMetadata(ctx, clt, meta.RDSProxy.CustomEndpointName, database.GetURI()) } - return fetchRDSProxyMetadata(ctx, rds, meta.RDSProxy.Name) + return fetchRDSProxyMetadata(ctx, clt, meta.RDSProxy.Name) } // fetchRedshiftMetadata fetches metadata for the provided Redshift database. @@ -275,8 +298,8 @@ func (m *Metadata) fetchMemoryDBMetadata(ctx context.Context, database types.Dat } // fetchRDSInstanceMetadata fetches metadata about specified RDS instance. -func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*types.AWS, error) { - rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID) +func fetchRDSInstanceMetadata(ctx context.Context, clt rdsClient, instanceID string) (*types.AWS, error) { + rdsInstance, err := describeRDSInstance(ctx, clt, instanceID) if err != nil { return nil, trace.Wrap(err) } @@ -284,22 +307,22 @@ func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, in } // describeRDSInstance returns AWS RDS instance for the specified ID. -func describeRDSInstance(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*rds.DBInstance, error) { - out, err := rdsClient.DescribeDBInstancesWithContext(ctx, &rds.DescribeDBInstancesInput{ +func describeRDSInstance(ctx context.Context, clt rdsClient, instanceID string) (*rdstypes.DBInstance, error) { + out, err := clt.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: aws.String(instanceID), }) if err != nil { return nil, common.ConvertError(err) } if len(out.DBInstances) != 1 { - return nil, trace.BadParameter("expected 1 RDS instance for %v, got %+v", instanceID, out.DBInstances) + return nil, trace.BadParameter("expected 1 RDS instance for %v, got %d", instanceID, len(out.DBInstances)) } - return out.DBInstances[0], nil + return &out.DBInstances[0], nil } // fetchRDSClusterMetadata fetches metadata about specified Aurora cluster. -func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*types.AWS, error) { - rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID) +func fetchRDSClusterMetadata(ctx context.Context, clt rdsClient, clusterID string) (*types.AWS, error) { + rdsCluster, err := describeRDSCluster(ctx, clt, clusterID) if err != nil { return nil, trace.Wrap(err) } @@ -307,8 +330,8 @@ func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clu } // describeRDSCluster returns AWS Aurora cluster for the specified ID. -func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*rds.DBCluster, error) { - out, err := rdsClient.DescribeDBClustersWithContext(ctx, &rds.DescribeDBClustersInput{ +func describeRDSCluster(ctx context.Context, clt rdsClient, clusterID string) (*rdstypes.DBCluster, error) { + out, err := clt.DescribeDBClusters(ctx, &rds.DescribeDBClustersInput{ DBClusterIdentifier: aws.String(clusterID), }) if err != nil { @@ -317,7 +340,7 @@ func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterI if len(out.DBClusters) != 1 { return nil, trace.BadParameter("expected 1 RDS cluster for %v, got %+v", clusterID, out.DBClusters) } - return out.DBClusters[0], nil + return &out.DBClusters[0], nil } // describeRedshiftCluster returns AWS Redshift cluster for the specified ID. @@ -364,8 +387,8 @@ func describeMemoryDBCluster(ctx context.Context, client memorydbiface.MemoryDBA } // fetchRDSProxyMetadata fetches metadata about specified RDS Proxy name. -func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*types.AWS, error) { - rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName) +func fetchRDSProxyMetadata(ctx context.Context, clt rdsClient, proxyName string) (*types.AWS, error) { + rdsProxy, err := describeRDSProxy(ctx, clt, proxyName) if err != nil { return nil, trace.Wrap(err) } @@ -373,28 +396,28 @@ func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxy } // describeRDSProxy returns AWS RDS Proxy for the specified RDS Proxy name. -func describeRDSProxy(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*rds.DBProxy, error) { - out, err := rdsClient.DescribeDBProxiesWithContext(ctx, &rds.DescribeDBProxiesInput{ +func describeRDSProxy(ctx context.Context, clt rdsClient, proxyName string) (*rdstypes.DBProxy, error) { + out, err := clt.DescribeDBProxies(ctx, &rds.DescribeDBProxiesInput{ DBProxyName: aws.String(proxyName), }) if err != nil { return nil, common.ConvertError(err) } if len(out.DBProxies) != 1 { - return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %s", proxyName, out.DBProxies) + return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %d", proxyName, len(out.DBProxies)) } - return out.DBProxies[0], nil + return &out.DBProxies[0], nil } // fetchRDSProxyCustomEndpointMetadata fetches metadata about specified RDS // proxy custom endpoint. -func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*types.AWS, error) { - rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, uri) +func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*types.AWS, error) { + rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, uri) if err != nil { return nil, trace.Wrap(err) } - rdsProxy, err := describeRDSProxy(ctx, rdsClient, aws.ToString(rdsProxyEndpoint.DBProxyName)) + rdsProxy, err := describeRDSProxy(ctx, clt, aws.ToString(rdsProxyEndpoint.DBProxyName)) if err != nil { return nil, trace.Wrap(err) } @@ -404,21 +427,27 @@ func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface // describeRDSProxyCustomEndpointAndFindURI returns AWS RDS Proxy endpoint for // the specified RDS Proxy custom endpoint. -func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*rds.DBProxyEndpoint, error) { - out, err := rdsClient.DescribeDBProxyEndpointsWithContext(ctx, &rds.DescribeDBProxyEndpointsInput{ +func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*rdstypes.DBProxyEndpoint, error) { + out, err := clt.DescribeDBProxyEndpoints(ctx, &rds.DescribeDBProxyEndpointsInput{ DBProxyEndpointName: aws.String(proxyEndpointName), }) if err != nil { return nil, common.ConvertError(err) } - for _, customEndpoint := range out.DBProxyEndpoints { + var endpoints []string + for _, e := range out.DBProxyEndpoints { + endpoint := aws.ToString(e.Endpoint) + if endpoint == "" { + continue + } // Double check if it has the same URI in case multiple custom // endpoints have the same name. - if strings.Contains(uri, aws.ToString(customEndpoint.Endpoint)) { - return customEndpoint, nil + if strings.Contains(uri, endpoint) { + return &e, nil } + endpoints = append(endpoints, endpoint) } - return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, out.DBProxyEndpoints) + return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, endpoints) } func fetchRedshiftServerlessWorkgroupMetadata(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) (*types.AWS, error) { diff --git a/lib/srv/db/cloud/meta_test.go b/lib/srv/db/cloud/meta_test.go index 9e66a416a2ebb..71d3cbedea783 100644 --- a/lib/srv/db/cloud/meta_test.go +++ b/lib/srv/db/cloud/meta_test.go @@ -23,11 +23,12 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" @@ -40,8 +41,8 @@ import ( // TestAWSMetadata tests fetching AWS metadata for RDS and Redshift databases. func TestAWSMetadata(t *testing.T) { // Configure RDS API mock. - rds := &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{ + rdsClt := &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{ // Standalone RDS instance. { DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"), @@ -56,7 +57,7 @@ func TestAWSMetadata(t *testing.T) { DBClusterIdentifier: aws.String("postgres-aurora"), }, }, - DBClusters: []*rds.DBCluster{ + DBClusters: []rdstypes.DBCluster{ // Aurora cluster. { DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"), @@ -64,16 +65,17 @@ func TestAWSMetadata(t *testing.T) { DbClusterResourceId: aws.String("cluster-xyz"), }, }, - DBProxies: []*rds.DBProxy{ + DBProxies: []rdstypes.DBProxy{ { DBProxyArn: aws.String("arn:aws:rds:us-east-1:123456789012:db-proxy:prx-resource-id"), DBProxyName: aws.String("rds-proxy"), }, }, - DBProxyEndpoints: []*rds.DBProxyEndpoint{ + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{ { DBProxyEndpointName: aws.String("rds-proxy-endpoint"), DBProxyName: aws.String("rds-proxy"), + Endpoint: aws.String("localhost"), }, }, } @@ -130,7 +132,6 @@ func TestAWSMetadata(t *testing.T) { // Create metadata fetcher. metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - RDS: rds, ElastiCache: elasticache, MemoryDB: memorydb, RedshiftServerless: redshiftServerless, @@ -140,6 +141,7 @@ func TestAWSMetadata(t *testing.T) { STSClient: fakeSTS, }, redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), + rdsClientProviderFn: newFakeRDSClientProvider(rdsClt), }) require.NoError(t, err) @@ -407,7 +409,7 @@ func TestAWSMetadata(t *testing.T) { // cause an error. func TestAWSMetadataNoPermissions(t *testing.T) { // Create unauthorized mocks. - rds := &mocks.RDSMockUnauth{} + rdsClt := &mocks.RDSClient{Unauth: true} redshiftClt := &mocks.RedshiftClient{Unauth: true} fakeSTS := &mocks.STSClient{} @@ -415,13 +417,13 @@ func TestAWSMetadataNoPermissions(t *testing.T) { // Create metadata fetcher. metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - RDS: rds, STS: &fakeSTS.STSClientV1, }, AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, }, redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), + rdsClientProviderFn: newFakeRDSClientProvider(rdsClt), }) require.NoError(t, err) @@ -499,3 +501,9 @@ func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc return c } } + +func newFakeRDSClientProvider(c rdsClient) rdsClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return c + } +} diff --git a/lib/srv/db/cloud/resource_checker_url.go b/lib/srv/db/cloud/resource_checker_url.go index fdc4efdb65fe9..c5485c23ed724 100644 --- a/lib/srv/db/cloud/resource_checker_url.go +++ b/lib/srv/db/cloud/resource_checker_url.go @@ -28,6 +28,7 @@ import ( "sync" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" @@ -45,6 +46,9 @@ type urlChecker struct { // redshiftClientProviderFn is an internal-only [redshiftClient] provider // func that is only set in tests. redshiftClientProviderFn redshiftClientProviderFunc + // rdsClientProviderFn is an internal-only [rdsClient] provider + // func that is only set in tests. + rdsClientProviderFn rdsClientProviderFunc clients cloud.Clients logger *slog.Logger @@ -64,6 +68,9 @@ func newURLChecker(cfg DiscoveryResourceCheckerConfig) *urlChecker { redshiftClientProviderFn: func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { return redshift.NewFromConfig(cfg, optFns...) }, + rdsClientProviderFn: func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) + }, clients: cfg.Clients, logger: cfg.Logger, warnOnError: getWarnOnError(), diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go index 336ee197815fb..cd3dff6488b70 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws.go +++ b/lib/srv/db/cloud/resource_checker_url_aws.go @@ -21,10 +21,9 @@ package cloud import ( "context" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" @@ -82,22 +81,23 @@ func (c *urlChecker) logAWSAccessDeniedError(ctx context.Context, database types func (c *urlChecker) checkRDS(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.rdsClientProviderFn(awsCfg) if meta.RDS.ClusterID != "" { - return trace.Wrap(c.checkRDSCluster(ctx, database, rdsClient, meta.RDS.ClusterID)) + return trace.Wrap(c.checkRDSCluster(ctx, database, clt, meta.RDS.ClusterID)) } - return trace.Wrap(c.checkRDSInstance(ctx, database, rdsClient, meta.RDS.InstanceID)) + return trace.Wrap(c.checkRDSInstance(ctx, database, clt, meta.RDS.InstanceID)) } -func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, instanceID string) error { - rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID) +func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, clt rdsClient, instanceID string) error { + rdsInstance, err := describeRDSInstance(ctx, clt, instanceID) if err != nil { return trace.Wrap(err) } @@ -107,12 +107,12 @@ func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Databa return trace.Wrap(requireDatabaseAddressPort(database, rdsInstance.Endpoint.Address, rdsInstance.Endpoint.Port)) } -func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, clusterID string) error { - rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID) +func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, clt rdsClient, clusterID string) error { + rdsCluster, err := describeRDSCluster(ctx, clt, clusterID) if err != nil { return trace.Wrap(err) } - databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{}) + databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{}) if err != nil { c.logger.WarnContext(ctx, "Could not convert RDS cluster to database resources", "cluster", aws.StringValue(rdsCluster.DBClusterIdentifier), @@ -130,21 +130,22 @@ func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Databas func (c *urlChecker) checkRDSProxy(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.rdsClientProviderFn(awsCfg) if meta.RDSProxy.CustomEndpointName != "" { - return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, rdsClient, meta.RDSProxy.CustomEndpointName)) + return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, clt, meta.RDSProxy.CustomEndpointName)) } - return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, rdsClient, meta.RDSProxy.Name)) + return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, clt, meta.RDSProxy.Name)) } -func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyName string) error { - rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName) +func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyName string) error { + rdsProxy, err := describeRDSProxy(ctx, clt, proxyName) if err != nil { return trace.Wrap(err) } @@ -153,8 +154,8 @@ func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database return requireDatabaseHost(database, aws.StringValue(rdsProxy.Endpoint)) } -func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyEndpointName string) error { - _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, database.GetURI()) +func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyEndpointName string) error { + _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, database.GetURI()) return trace.Wrap(err) } @@ -290,15 +291,16 @@ func (c *urlChecker) checkOpenSearchEndpoint(ctx context.Context, database types func (c *urlChecker) checkDocumentDB(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.rdsClientProviderFn(awsCfg) - cluster, err := describeRDSCluster(ctx, rdsClient, meta.DocumentDB.ClusterID) + cluster, err := describeRDSCluster(ctx, clt, meta.DocumentDB.ClusterID) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go index e8ba24f624c16..b7c1d373ee19e 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws_test.go +++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go @@ -22,11 +22,11 @@ import ( "context" "testing" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" @@ -54,7 +54,7 @@ func TestURLChecker_AWS(t *testing.T) { mocks.WithRDSClusterReader, mocks.WithRDSClusterCustomEndpoint("my-custom"), ) - rdsClusterDBs, err := common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{}) + rdsClusterDBs, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{}) require.NoError(t, err) require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom. testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...) @@ -121,12 +121,6 @@ func TestURLChecker_AWS(t *testing.T) { // Mock cloud clients. mockClients := &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance}, - DBClusters: []*rds.DBCluster{rdsCluster, docdbCluster}, - DBProxies: []*rds.DBProxy{rdsProxy}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyCustomEndpoint}, - }, RedshiftServerless: &mocks.RedshiftServerlessMock{ Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint}, @@ -143,7 +137,6 @@ func TestURLChecker_AWS(t *testing.T) { STS: &mocks.STSClientV1{}, } mockClientsUnauth := &cloud.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true}, ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, MemoryDB: &mocks.MemoryDBMock{Unauth: true}, @@ -158,12 +151,19 @@ func TestURLChecker_AWS(t *testing.T) { name string clients cloud.Clients awsConfigProvider awsconfig.Provider + rdsClient rdsClient redshiftClient redshiftClient }{ { name: "API check", clients: mockClients, awsConfigProvider: &mocks.AWSConfigProvider{}, + rdsClient: &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance}, + DBClusters: []rdstypes.DBCluster{*rdsCluster, *docdbCluster}, + DBProxies: []rdstypes.DBProxy{*rdsProxy}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyCustomEndpoint}, + }, redshiftClient: &mocks.RedshiftClient{ Clusters: []redshifttypes.Cluster{redshiftCluster}, }, @@ -172,6 +172,7 @@ func TestURLChecker_AWS(t *testing.T) { name: "basic endpoint check", clients: mockClientsUnauth, awsConfigProvider: &mocks.AWSConfigProvider{}, + rdsClient: &mocks.RDSClient{Unauth: true}, redshiftClient: &mocks.RedshiftClient{Unauth: true}, }, } @@ -184,6 +185,7 @@ func TestURLChecker_AWS(t *testing.T) { Logger: utils.NewSlogLoggerForTests(), }) c.redshiftClientProviderFn = newFakeRedshiftClientProvider(method.redshiftClient) + c.rdsClientProviderFn = newFakeRDSClientProvider(method.rdsClient) for _, database := range testCases { t.Run(database.GetName(), func(t *testing.T) { diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index e567d82d402e0..b94a10c5c32fb 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -35,12 +35,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/aws/aws-sdk-go-v2/aws" + rdsauth "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds/rdsutils" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -243,9 +243,9 @@ func (a *dbAuth) WithLogger(getUpdatedLogger func(*slog.Logger) *slog.Logger) Au // when connecting to RDS and Aurora databases. func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { meta := database.GetAWS() - awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return "", trace.Wrap(err) @@ -254,11 +254,13 @@ func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, d "database", database, "database_user", databaseUser, ) - token, err := rdsutils.BuildAuthToken( + token, err := rdsauth.BuildAuthToken( + ctx, database.GetURI(), meta.Region, databaseUser, - awsSession.Config.Credentials) + awsCfg.Credentials, + ) if err != nil { policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index ae136b4d53c46..4982aa660d603 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -609,7 +609,6 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: &fakeSTS.STSClientV1, - RDS: &mocks.RDSMock{}, RedshiftServerless: &mocks.RedshiftServerlessMock{ GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), }, diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 28fcc486bf4db..dfb1a4b164192 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -259,9 +259,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } if c.CloudIAM == nil { c.CloudIAM, err = cloud.NewIAM(ctx, cloud.IAMConfig{ - AccessPoint: c.AccessPoint, - Clients: c.CloudClients, - HostID: c.HostID, + AccessPoint: c.AccessPoint, + AWSConfigProvider: c.AWSConfigProvider, + Clients: c.CloudClients, + HostID: c.HostID, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go index 8a7750a26a07a..6020547ea9590 100644 --- a/lib/srv/db/watcher_test.go +++ b/lib/srv/db/watcher_test.go @@ -320,7 +320,6 @@ func TestWatcherCloudFetchers(t *testing.T) { reconcileCh <- d }, CloudClients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, // Access denied error should not affect other fetchers. RedshiftServerless: &mocks.RedshiftServerlessMock{ Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, }, @@ -358,7 +357,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"), )) case <-time.After(time.Second): - t.Fatal("Didn't receive reconcile event after 1s.") + require.FailNow(t, "Didn't receive reconcile event after 1s.") } } diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go index 4bc207b21df01..6b6e84504453e 100644 --- a/lib/srv/discovery/access_graph.go +++ b/lib/srv/discovery/access_graph.go @@ -501,6 +501,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M fetcher, err := aws_sync.NewAWSFetcher( ctx, aws_sync.Config{ + AWSConfigProvider: s.AWSConfigProvider, CloudClients: s.CloudClients, GetEC2Client: s.GetEC2Client, AssumeRole: assumeRole, diff --git a/lib/srv/discovery/common/database.go b/lib/srv/discovery/common/database.go index 8afe335f87fcb..dcff7a2c0f614 100644 --- a/lib/srv/discovery/common/database.go +++ b/lib/srv/discovery/common/database.go @@ -35,7 +35,6 @@ import ( "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" @@ -286,7 +285,7 @@ func NewDatabaseFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers } // NewDatabaseFromRDSInstance creates a database resource from an RDS instance. -func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSInstance(instance *rdstypes.DBInstance) (types.Database, error) { endpoint := instance.Endpoint if endpoint == nil { return nil, trace.BadParameter("empty endpoint") @@ -307,7 +306,7 @@ func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error }, aws.ToString(instance.DBInstanceIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt32(endpoint.Port)), AWS: *metadata, }) } @@ -492,7 +491,7 @@ func labelsFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, end } // NewDatabaseFromRDSCluster creates a database resource from an RDS cluster (Aurora). -func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -508,13 +507,13 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromRDSClusterReaderEndpoint creates a database resource from an RDS cluster reader endpoint (Aurora). -func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -530,13 +529,13 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabasesFromRDSClusterCustomEndpoints creates database resources from RDS cluster custom endpoints (Aurora). -func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) { +func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -551,7 +550,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns for _, endpoint := range cluster.CustomEndpoints { // RDS custom endpoint format: // .cluster-custom-. - endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.ToString(endpoint)) + endpointDetails, err := apiawsutils.ParseRDSEndpoint(endpoint) if err != nil { errors = append(errors, trace.Wrap(err)) continue @@ -568,7 +567,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", endpoint, aws.ToInt32(cluster.Port)), AWS: *metadata, // Aurora instances update their certificates upon restart, and thus custom endpoint SAN may not be available right @@ -588,14 +587,12 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns return databases, trace.NewAggregate(errors...) } -func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReaderInstance bool) { +func checkRDSClusterMembers(cluster *rdstypes.DBCluster) (hasWriterInstance, hasReaderInstance bool) { for _, clusterMember := range cluster.DBClusterMembers { - if clusterMember != nil { - if aws.ToBool(clusterMember.IsClusterWriter) { - hasWriterInstance = true - } else { - hasReaderInstance = true - } + if aws.ToBool(clusterMember.IsClusterWriter) { + hasWriterInstance = true + } else { + hasReaderInstance = true } } return @@ -603,7 +600,7 @@ func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReade // NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora // cluster. -func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) { +func NewDatabasesFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) { var errors []error var databases types.Databases @@ -648,7 +645,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.D // NewDatabasesFromDocumentDBCluster creates all database resources from a // DocumentDB cluster. -func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases, error) { +func NewDatabasesFromDocumentDBCluster(cluster *rdstypes.DBCluster) (types.Databases, error) { var errors []error var databases types.Databases @@ -682,7 +679,7 @@ func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases, // NewDatabaseFromDocumentDBClusterEndpoint creates database resource from // DocumentDB cluster endpoint. -func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Database, error) { +func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) { endpointType := apiawsutils.DocumentDBClusterEndpoint metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType) if err != nil { @@ -695,14 +692,14 @@ func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Dat }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromDocumentDBReaderEndpoint creates database resource from // DocumentDB reader endpoint. -func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Database, error) { +func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) { endpointType := apiawsutils.DocumentDBClusterReaderEndpoint metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType) if err != nil { @@ -715,13 +712,13 @@ func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Data }, aws.ToString(cluster.DBClusterIdentifier), endpointType), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromRDSProxy creates database resource from RDS Proxy. -func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Database, error) { +func NewDatabaseFromRDSProxy(dbProxy *rdstypes.DBProxy, tags []rdstypes.Tag) (types.Database, error) { metadata, err := MetadataFromRDSProxy(dbProxy) if err != nil { return nil, trace.Wrap(err) @@ -744,7 +741,7 @@ func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Datab // NewDatabaseFromRDSProxyCustomEndpoint creates database resource from RDS // Proxy custom endpoint. -func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, tags []*rds.Tag) (types.Database, error) { +func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, tags []rdstypes.Tag) (types.Database, error) { metadata, err := MetadataFromRDSProxyCustomEndpoint(dbProxy, customEndpoint) if err != nil { return nil, trace.Wrap(err) @@ -1045,7 +1042,7 @@ func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.E } // MetadataFromRDSInstance creates AWS metadata from the provided RDS instance. -func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { +func MetadataFromRDSInstance(rdsInstance *rdstypes.DBInstance) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsInstance.DBInstanceArn)) if err != nil { return nil, trace.Wrap(err) @@ -1063,7 +1060,7 @@ func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { } // MetadataFromRDSCluster creates AWS metadata from the provided RDS cluster. -func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { +func MetadataFromRDSCluster(rdsCluster *rdstypes.DBCluster) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsCluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) @@ -1081,7 +1078,7 @@ func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { // MetadataFromDocumentDBCluster creates AWS metadata from the provided // DocumentDB cluster. -func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) (*types.AWS, error) { +func MetadataFromDocumentDBCluster(cluster *rdstypes.DBCluster, endpointType string) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(cluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) @@ -1097,13 +1094,13 @@ func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) } // MetadataFromRDSProxy creates AWS metadata from the provided RDS Proxy. -func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { +func MetadataFromRDSProxy(rdsProxy *rdstypes.DBProxy) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsProxy.DBProxyArn)) if err != nil { return nil, trace.Wrap(err) } - // rds.DBProxy has no resource ID attribute. The resource ID can be found + // rdstypes.DBProxy has no resource ID attribute. The resource ID can be found // in the ARN, e.g.: // // arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-xxxyyyzzz @@ -1127,7 +1124,7 @@ func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { // MetadataFromRDSProxyCustomEndpoint creates AWS metadata from the provided // RDS Proxy custom endpoint. -func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint) (*types.AWS, error) { +func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint) (*types.AWS, error) { // Using resource ID from the default proxy for IAM policies to gain the // RDS connection access. metadata, err := MetadataFromRDSProxy(rdsProxy) @@ -1323,12 +1320,12 @@ func rdsEngineToProtocol(engine string) (string, error) { // rdsEngineFamilyToProtocolAndPort converts RDS engine family to the database protocol and port. func rdsEngineFamilyToProtocolAndPort(engineFamily string) (string, int, error) { - switch engineFamily { - case rds.EngineFamilyMysql: + switch rdstypes.EngineFamily(engineFamily) { + case rdstypes.EngineFamilyMysql: return defaults.ProtocolMySQL, services.RDSProxyMySQLPort, nil - case rds.EngineFamilyPostgresql: + case rdstypes.EngineFamilyPostgresql: return defaults.ProtocolPostgres, services.RDSProxyPostgresPort, nil - case rds.EngineFamilySqlserver: + case rdstypes.EngineFamilySqlserver: return defaults.ProtocolSQLServer, services.RDSProxySQLServerPort, nil } return "", 0, trace.BadParameter("unknown RDS engine family type %q", engineFamily) @@ -1421,7 +1418,7 @@ func labelsFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers.Serv } // labelsFromRDSInstance creates database labels for the provided RDS instance. -func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[string]string { +func labelsFromRDSInstance(rdsInstance *rdstypes.DBInstance, meta *types.AWS) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsInstance.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsInstance.EngineVersion) @@ -1433,7 +1430,7 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str } // labelsFromRDSCluster creates database labels for the provided RDS cluster. -func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string { +func labelsFromRDSCluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, endpointType string, memberInstances []rdstypes.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsCluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsCluster.EngineVersion) @@ -1444,7 +1441,7 @@ func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointTy return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList)) } -func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpointType string) map[string]string { +func labelsFromDocumentDBCluster(cluster *rdstypes.DBCluster, meta *types.AWS, endpointType string) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(cluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(cluster.EngineVersion) @@ -1453,8 +1450,8 @@ func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpoi } // labelsFromRDSProxy creates database labels for the provided RDS Proxy. -func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) map[string]string { - // rds.DBProxy has no TagList. +func labelsFromRDSProxy(rdsProxy *rdstypes.DBProxy, meta *types.AWS, tags []rdstypes.Tag) map[string]string { + // rdstypes.DBProxy has no TagList. labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsProxy.VpcId) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsProxy.EngineFamily) @@ -1463,9 +1460,9 @@ func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) // labelsFromRDSProxyCustomEndpoint creates database labels for the provided // RDS Proxy custom endpoint. -func labelsFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, meta *types.AWS, tags []*rds.Tag) map[string]string { +func labelsFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, meta *types.AWS, tags []rdstypes.Tag) map[string]string { labels := labelsFromRDSProxy(rdsProxy, meta, tags) - labels[types.DiscoveryLabelEndpointType] = aws.ToString(customEndpoint.TargetRole) + labels[types.DiscoveryLabelEndpointType] = string(customEndpoint.TargetRole) return labels } diff --git a/lib/srv/discovery/common/database_test.go b/lib/srv/discovery/common/database_test.go index ab2b45fff24bc..891c31a18bc13 100644 --- a/lib/srv/discovery/common/database_test.go +++ b/lib/srv/discovery/common/database_test.go @@ -28,11 +28,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql" "github.com/aws/aws-sdk-go-v2/aws" - rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -217,7 +216,7 @@ func TestDatabaseFromAzureRedisEnterprise(t *testing.T) { // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. func TestDatabaseFromRDSInstance(t *testing.T) { - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -225,11 +224,11 @@ func TestDatabaseFromRDSInstance(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: []*rds.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -268,7 +267,7 @@ func TestDatabaseFromRDSInstance(t *testing.T) { // TestDatabaseFromRDSV2Instance tests converting an RDS instance (from aws sdk v2/rds) to a database resource. func TestDatabaseFromRDSV2Instance(t *testing.T) { - instance := &rdsTypesV2.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -277,16 +276,16 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rdsTypesV2.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), Port: aws.Int32(5432), }, - TagList: []rdsTypesV2.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, - DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{ - Subnets: []rdsTypesV2.Subnet{ + DBSubnetGroup: &rdstypes.DBSubnetGroup{ + Subnets: []rdstypes.Subnet{ {SubnetIdentifier: aws.String("")}, {SubnetIdentifier: aws.String("subnet-1234567890abcdef0")}, {SubnetIdentifier: aws.String("subnet-1234567890abcdef1")}, @@ -294,7 +293,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { }, VpcId: aws.String("vpc-asd"), }, - VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{ + VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{ {VpcSecurityGroupId: aws.String("")}, {VpcSecurityGroupId: aws.String("sg-1")}, {VpcSecurityGroupId: aws.String("sg-2")}, @@ -348,7 +347,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { newName := "override-1" instance := instance instance.TagList = append(instance.TagList, - rdsTypesV2.Tag{ + rdstypes.Tag{ Key: aws.String(overrideLabel), Value: aws.String(newName), }, @@ -365,7 +364,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -373,11 +372,11 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ {Key: aws.String("key"), Value: aws.String("val")}, {Key: aws.String(overrideLabel), Value: aws.String("override-1")}, }, @@ -421,8 +420,8 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { // TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource. func TestDatabaseFromRDSCluster(t *testing.T) { vpcid := uuid.NewString() - dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String(vpcid)}}} - cluster := &rds.DBCluster{ + dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String(vpcid)}}} + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -431,12 +430,12 @@ func TestDatabaseFromRDSCluster(t *testing.T) { EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), - Port: aws.Int64(3306), - CustomEndpoints: []*string{ - aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"), - aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"), + Port: aws.Int32(3306), + CustomEndpoints: []string{ + "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", + "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []*rds.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -549,9 +548,9 @@ func TestDatabaseFromRDSCluster(t *testing.T) { t.Run("bad custom endpoints ", func(t *testing.T) { badCluster := *cluster - badCluster.CustomEndpoints = []*string{ - aws.String("badendpoint1"), - aws.String("badendpoint2"), + badCluster.CustomEndpoints = []string{ + "badendpoint1", + "badendpoint2", } _, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers) require.Error(t, err) @@ -561,7 +560,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) { // TestDatabaseFromRDSV2Cluster tests converting an RDS cluster to a database resource. // It uses the V2 of the aws sdk. func TestDatabaseFromRDSV2Cluster(t *testing.T) { - cluster := &rdsTypesV2.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -572,7 +571,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), Port: aws.Int32(3306), - VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{ + VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{ {VpcSecurityGroupId: aws.String("")}, {VpcSecurityGroupId: aws.String("sg-1")}, {VpcSecurityGroupId: aws.String("sg-2")}, @@ -581,7 +580,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []rdsTypesV2.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -630,7 +629,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { newName := "override-1" cluster.TagList = append(cluster.TagList, - rdsTypesV2.Tag{ + rdstypes.Tag{ Key: aws.String(overrideLabel), Value: aws.String(newName), }, @@ -645,10 +644,10 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { }) t.Run("DB Cluster uses network information from DB Instance when available", func(t *testing.T) { - instance := &rdsTypesV2.DBInstance{ - DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{ + instance := &rdstypes.DBInstance{ + DBSubnetGroup: &rdstypes.DBSubnetGroup{ VpcId: aws.String("vpc-123"), - Subnets: []rdsTypesV2.Subnet{ + Subnets: []rdstypes.Subnet{ {SubnetIdentifier: aws.String("subnet-123")}, {SubnetIdentifier: aws.String("subnet-456")}, }, @@ -699,9 +698,9 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { // TestDatabaseFromRDSClusterNameOverride tests converting an RDS cluster to a database resource with overridden name. func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { - dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String("vpc-123")}}} + dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String("vpc-123")}}} for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -710,12 +709,12 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), - Port: aws.Int64(3306), - CustomEndpoints: []*string{ - aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"), - aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"), + Port: aws.Int32(3306), + CustomEndpoints: []string{ + "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", + "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ {Key: aws.String("key"), Value: aws.String("val")}, {Key: aws.String(overrideLabel), Value: aws.String("mycluster-2")}, }, @@ -831,9 +830,9 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { t.Run("bad custom endpoints ", func(t *testing.T) { badCluster := *cluster - badCluster.CustomEndpoints = []*string{ - aws.String("badendpoint1"), - aws.String("badendpoint2"), + badCluster.CustomEndpoints = []string{ + "badendpoint1", + "badendpoint2", } _, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers) require.Error(t, err) @@ -896,7 +895,7 @@ func TestNewDatabasesFromDocumentDBCluster(t *testing.T) { tests := []struct { name string - inputCluster *rds.DBCluster + inputCluster *rdstypes.DBCluster wantDatabases types.Databases }{ { @@ -929,26 +928,26 @@ func TestDatabaseFromRDSProxy(t *testing.T) { }{ { desc: "mysql", - engineFamily: rds.EngineFamilyMysql, + engineFamily: string(rdstypes.EngineFamilyMysql), wantProtocol: "mysql", wantPort: 3306, }, { desc: "postgres", - engineFamily: rds.EngineFamilyPostgresql, + engineFamily: string(rdstypes.EngineFamilyPostgresql), wantProtocol: "postgres", wantPort: 5432, }, { desc: "sqlserver", - engineFamily: rds.EngineFamilySqlserver, + engineFamily: string(rdstypes.EngineFamilySqlserver), wantProtocol: "sqlserver", wantPort: 1433, }, } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - dbProxy := &rds.DBProxy{ + dbProxy := &rdstypes.DBProxy{ DBProxyArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-abcdef"), DBProxyName: aws.String("testproxy"), EngineFamily: aws.String(test.engineFamily), @@ -956,15 +955,15 @@ func TestDatabaseFromRDSProxy(t *testing.T) { VpcId: aws.String("test-vpc-id"), } - dbProxyEndpoint := &rds.DBProxyEndpoint{ + dbProxyEndpoint := &rdstypes.DBProxyEndpoint{ Endpoint: aws.String("custom.proxy.rds.test"), DBProxyEndpointName: aws.String("custom"), DBProxyName: aws.String("testproxy"), DBProxyEndpointArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy-endpoint:prx-endpoint-abcdef"), - TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), + TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly, } - tags := []*rds.Tag{{ + tags := []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }} @@ -1059,7 +1058,7 @@ func TestAuroraMySQLVersion(t *testing.T) { } for _, test := range tests { t.Run(test.engineVersion, func(t *testing.T) { - require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rds.DBCluster{EngineVersion: aws.String(test.engineVersion)})) + require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rdstypes.DBCluster{EngineVersion: aws.String(test.engineVersion)})) }) } } @@ -1099,7 +1098,7 @@ func TestIsRDSClusterSupported(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:test"), DBClusterIdentifier: aws.String(test.name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -1149,7 +1148,7 @@ func TestIsRDSInstanceSupported(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - cluster := &rds.DBInstance{ + cluster := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-east-1:123456789012:instance:test"), DBClusterIdentifier: aws.String(test.name), DbiResourceId: aws.String(uuid.New().String()), diff --git a/lib/srv/discovery/common/kubernetes_test.go b/lib/srv/discovery/common/kubernetes_test.go index b121c624a1e76..db27927634772 100644 --- a/lib/srv/discovery/common/kubernetes_test.go +++ b/lib/srv/discovery/common/kubernetes_test.go @@ -98,9 +98,8 @@ func TestNewKubeClusterFromAWSEKS(t *testing.T) { require.NoError(t, err) cluster := &eks.Cluster{ - Name: aws.String("cluster1"), - Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"), - Status: aws.String(eks.ClusterStatusActive), + Name: aws.String("cluster1"), + Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"), Tags: map[string]*string{ overrideLabel: aws.String("override-1"), "env": aws.String("prod"), diff --git a/lib/srv/discovery/common/renaming_test.go b/lib/srv/discovery/common/renaming_test.go index b01825725f672..12979392da959 100644 --- a/lib/srv/discovery/common/renaming_test.go +++ b/lib/srv/discovery/common/renaming_test.go @@ -27,15 +27,15 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/mysql/armmysqlflexibleservers" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/eks" - "github.com/aws/aws-sdk-go/service/rds" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" azureutils "github.com/gravitational/teleport/api/utils/azure" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/services" @@ -365,7 +365,7 @@ func requireOverrideLabelSkipsRenaming(t *testing.T, r types.ResourceWithLabels, func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { t.Helper() - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:cluster:%v", region, accountID, name)), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -373,29 +373,29 @@ func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel st Engine: aws.String("aurora-mysql"), EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + Port: aws.Int32(3306), + TagList: awstesthelpers.LabelsToRDSTags(map[string]string{ overrideLabel: name, }), } - database, err := NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{}) + database, err := NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{}) require.NoError(t, err) return database } func makeRDSInstanceDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { t.Helper() - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:db:%v", region, accountID, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String(services.RDSEnginePostgres), DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + TagList: awstesthelpers.LabelsToRDSTags(map[string]string{ overrideLabel: name, }), } @@ -499,9 +499,8 @@ func labelsToAzureTags(labels map[string]string) map[string]*string { func makeEKSKubeCluster(t *testing.T, name, region, accountID, overrideLabel string) types.KubeCluster { t.Helper() eksCluster := &eks.Cluster{ - Name: aws.String(name), - Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), - Status: aws.String(eks.ClusterStatusActive), + Name: aws.String(name), + Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), Tags: map[string]*string{ overrideLabel: aws.String(name), }, diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 865517ba4c33c..b140cd62710c4 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -39,6 +39,8 @@ import ( awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go-v2/service/ssm" @@ -47,7 +49,6 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/eks" "github.com/aws/aws-sdk-go/service/eks/eksiface" - "github.com/aws/aws-sdk-go/service/rds" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" @@ -2012,13 +2013,7 @@ func TestDiscoveryDatabase(t *testing.T) { } testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{awsRDSInstance}, - DBEngineVersions: []*rds.DBEngineVersion{ - {Engine: aws.String(services.RDSEnginePostgres)}, - }, - }, + STS: &mocks.STSClientV1{}, MemoryDB: &mocks.MemoryDBMock{}, AzureRedis: azure.NewRedisClientByAPI(&azure.ARMRedisMock{ Servers: []*armredis.ResourceInfo{azRedisResource}, @@ -2363,6 +2358,12 @@ func TestDiscoveryDatabase(t *testing.T) { RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, }), + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*awsRDSInstance}, + DBEngineVersions: []rdstypes.DBEngineVersion{ + {Engine: aws.String(services.RDSEnginePostgres)}, + }, + }), }) require.NoError(t, err) @@ -2452,15 +2453,23 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1", rewriteDiscoveryLabelsParams{discoveryConfigName: dc2Name, discoveryGroup: mainDiscoveryGroup}) + fakeConfigProvider := &mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{}, + } testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{awsRDSInstance}, - DBEngineVersions: []*rds.DBEngineVersion{ + STS: &fakeConfigProvider.STSClient.STSClientV1, + } + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: fakeConfigProvider, + CloudClients: testCloudClients, + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*awsRDSInstance}, + DBEngineVersions: []rdstypes.DBEngineVersion{ {Engine: aws.String(services.RDSEnginePostgres)}, }, - }, - } + }), + }) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -2488,14 +2497,16 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { srv, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: Matchers{}, - Emitter: authClient, - DiscoveryGroup: mainDiscoveryGroup, - clock: clock, + AWSConfigProvider: fakeConfigProvider, + AWSDatabaseFetcherFactory: dbFetcherFactory, + CloudClients: testCloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: Matchers{}, + Emitter: authClient, + DiscoveryGroup: mainDiscoveryGroup, + clock: clock, }) require.NoError(t, err) @@ -2618,16 +2629,16 @@ func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDi return eksAWSCluster, actual } -func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rds.DBInstance, types.Database) { - instance := &rds.DBInstance{ +func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rdstypes.DBInstance, types.Database) { + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String(services.RDSEnginePostgres), DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, } database, err := common.NewDatabaseFromRDSInstance(instance) @@ -3701,3 +3712,9 @@ func newFakeRedshiftClientProvider(c redshift.DescribeClustersAPIClient) db.Reds return c } } + +func newFakeRDSClientProvider(c db.RDSClient) db.RDSClientProviderFunc { + return func(cfg awsv2.Config, optFns ...func(*rds.Options)) db.RDSClient { + return c + } +} diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go index 2a7e928370091..da5d9dea9d523 100644 --- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go +++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go @@ -24,9 +24,9 @@ import ( "sync" "time" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/retry" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/service/sts" "github.com/gravitational/trace" @@ -45,6 +45,8 @@ const pageSize int64 = 500 // Config is the configuration for the AWS fetcher. type Config struct { + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // CloudClients is the cloud clients to use when fetching AWS resources. CloudClients cloud.Clients // GetEC2Client gets an AWS EC2 client for the given region. @@ -59,6 +61,23 @@ type Config struct { Integration string // DiscoveryConfigName if set, will be used to report the Discovery Config Status to the Auth Server. DiscoveryConfigName string + + // rdsClientProviderFn is an internal-only [rdsClient] provider + // func that is only set in tests. + rdsClientProviderFn rdsClientProviderFunc +} + +func (c *Config) CheckAndSetDefaults() error { + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } + + if c.rdsClientProviderFn == nil { + c.rdsClientProviderFn = func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) + } + } + return nil } // AssumeRole is the configuration for assuming an AWS role. @@ -182,6 +201,9 @@ func (r *Resources) UsageReport(numberAccounts int) *usageeventsv1.AccessGraphAW // NewAWSFetcher creates a new AWS fetcher. func NewAWSFetcher(ctx context.Context, cfg Config) (AWSSync, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } a := &awsFetcher{ Config: cfg, lastResult: &Resources{}, @@ -335,7 +357,7 @@ func (a *awsFetcher) getAWSV2Options() []awsconfig.OptionsFn { opts = append(opts, awsconfig.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID)) } const maxRetries = 10 - opts = append(opts, awsconfig.WithRetryer(func() awsv2.Retryer { + opts = append(opts, awsconfig.WithRetryer(func() aws.Retryer { return retry.NewStandard(func(so *retry.StandardOptions) { so.MaxAttempts = maxRetries so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second) @@ -361,7 +383,7 @@ func (a *awsFetcher) getAccountId(ctx context.Context) (string, error) { return "", trace.Wrap(err) } - return aws.StringValue(req.Account), nil + return aws.ToString(req.Account), nil } func (a *awsFetcher) DiscoveryConfigName() string { diff --git a/lib/srv/discovery/fetchers/aws-sync/rds.go b/lib/srv/discovery/fetchers/aws-sync/rds.go index 08195e2132e82..7441251ef7890 100644 --- a/lib/srv/discovery/fetchers/aws-sync/rds.go +++ b/lib/srv/discovery/fetchers/aws-sync/rds.go @@ -22,8 +22,9 @@ import ( "context" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/timestamppb" @@ -31,12 +32,21 @@ import ( accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" ) +// rdsClient defines a subset of the AWS RDS client API. +type rdsClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient +} + +// rdsClientProviderFunc provides a [rdsClient]. +type rdsClientProviderFunc func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient + // pollAWSRDSDatabases is a function that returns a function that fetches // RDS instances and clusters. func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, collectErr func(error)) func() error { return func() error { var err error - result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx, a.lastResult) + result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx) if err != nil { collectErr(trace.Wrap(err, "failed to fetch databases")) } @@ -45,7 +55,7 @@ func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, } // fetchAWSRDSDatabases fetches RDS databases from all regions. -func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resources) ( +func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context) ( []*accessgraphv1alpha.AWSRDSDatabaseV1, error, ) { @@ -59,14 +69,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc // This is a temporary solution until we have a better way to limit the // number of concurrent requests. eG.SetLimit(5) - collectDBs := func(db *accessgraphv1alpha.AWSRDSDatabaseV1, err error) { + collectDBs := func(db []*accessgraphv1alpha.AWSRDSDatabaseV1, err error) { hostsMu.Lock() defer hostsMu.Unlock() if err != nil { errs = append(errs, err) } if db != nil { - dbs = append(dbs, db) + dbs = append(dbs, db...) } } @@ -74,42 +84,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc for _, region := range a.Regions { region := region eG.Go(func() error { - rdsClient, err := a.CloudClients.GetAWSRDSClient(ctx, region, a.getAWSOptions()...) + awsCfg, err := a.AWSConfigProvider.GetConfig(ctx, region, a.getAWSV2Options()...) if err != nil { collectDBs(nil, trace.Wrap(err)) return nil } - err = rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{}, - func(output *rds.DescribeDBInstancesOutput, lastPage bool) bool { - for _, db := range output.DBInstances { - // if instance belongs to a cluster, skip it as we want to represent the cluster itself - // and we pull it using DescribeDBClustersPagesWithContext instead. - if aws.StringValue(db.DBClusterIdentifier) != "" { - continue - } - protoRDS := awsRDSInstanceToRDS(db, region, a.AccountID) - collectDBs(protoRDS, nil) - } - return !lastPage - }, - ) - if err != nil { - collectDBs(nil, trace.Wrap(err)) - } - - err = rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{}, - func(output *rds.DescribeDBClustersOutput, lastPage bool) bool { - for _, db := range output.DBClusters { - protoRDS := awsRDSClusterToRDS(db, region, a.AccountID) - collectDBs(protoRDS, nil) - } - return !lastPage - }, - ) - if err != nil { - collectDBs(nil, trace.Wrap(err)) - } - + clt := a.rdsClientProviderFn(awsCfg) + a.collectDBInstances(ctx, clt, region, collectDBs) + a.collectDBClusters(ctx, clt, region, collectDBs) return nil }) } @@ -118,60 +100,123 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc return dbs, trace.NewAggregate(append(errs, err)...) } -// awsRDSInstanceToRDS converts an rds.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1 +// awsRDSInstanceToRDS converts an rdstypes.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1 // representation. -func awsRDSInstanceToRDS(instance *rds.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { +func awsRDSInstanceToRDS(instance *rdstypes.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { var tags []*accessgraphv1alpha.AWSTag for _, v := range instance.TagList { tags = append(tags, &accessgraphv1alpha.AWSTag{ - Key: aws.StringValue(v.Key), + Key: aws.ToString(v.Key), Value: strPtrToWrapper(v.Value), }) } return &accessgraphv1alpha.AWSRDSDatabaseV1{ - Name: aws.StringValue(instance.DBInstanceIdentifier), - Arn: aws.StringValue(instance.DBInstanceArn), + Name: aws.ToString(instance.DBInstanceIdentifier), + Arn: aws.ToString(instance.DBInstanceArn), CreatedAt: awsTimeToProtoTime(instance.InstanceCreateTime), - Status: aws.StringValue(instance.DBInstanceStatus), + Status: aws.ToString(instance.DBInstanceStatus), Region: region, AccountId: accountID, Tags: tags, EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: aws.StringValue(instance.Engine), - Version: aws.StringValue(instance.EngineVersion), + Engine: aws.ToString(instance.Engine), + Version: aws.ToString(instance.EngineVersion), }, IsCluster: false, - ResourceId: aws.StringValue(instance.DbiResourceId), + ResourceId: aws.ToString(instance.DbiResourceId), LastSyncTime: timestamppb.Now(), } } -// awsRDSInstanceToRDS converts an rds.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1 +// awsRDSInstanceToRDS converts an rdstypes.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1 // representation. -func awsRDSClusterToRDS(instance *rds.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { +func awsRDSClusterToRDS(instance *rdstypes.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { var tags []*accessgraphv1alpha.AWSTag for _, v := range instance.TagList { tags = append(tags, &accessgraphv1alpha.AWSTag{ - Key: aws.StringValue(v.Key), + Key: aws.ToString(v.Key), Value: strPtrToWrapper(v.Value), }) } return &accessgraphv1alpha.AWSRDSDatabaseV1{ - Name: aws.StringValue(instance.DBClusterIdentifier), - Arn: aws.StringValue(instance.DBClusterArn), + Name: aws.ToString(instance.DBClusterIdentifier), + Arn: aws.ToString(instance.DBClusterArn), CreatedAt: awsTimeToProtoTime(instance.ClusterCreateTime), - Status: aws.StringValue(instance.Status), + Status: aws.ToString(instance.Status), Region: region, AccountId: accountID, Tags: tags, EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: aws.StringValue(instance.Engine), - Version: aws.StringValue(instance.EngineVersion), + Engine: aws.ToString(instance.Engine), + Version: aws.ToString(instance.EngineVersion), }, IsCluster: true, - ResourceId: aws.StringValue(instance.DbClusterResourceId), + ResourceId: aws.ToString(instance.DbClusterResourceId), LastSyncTime: timestamppb.Now(), } } + +func (a *awsFetcher) collectDBInstances(ctx context.Context, + clt rdsClient, + region string, + collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error), +) { + pager := rds.NewDescribeDBInstancesPaginator(clt, + &rds.DescribeDBInstancesInput{}, + func(ddpo *rds.DescribeDBInstancesPaginatorOptions) { + ddpo.StopOnDuplicateToken = true + }, + ) + var instances []*accessgraphv1alpha.AWSRDSDatabaseV1 + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool { + return !db.IsCluster && db.Region == region && db.AccountId == a.AccountID + }) + collectDBs(old, trace.Wrap(err)) + return + } + for _, db := range page.DBInstances { + // if instance belongs to a cluster, skip it as we want to represent the cluster itself + // and we pull it using DescribeDBClustersPaginator instead. + if aws.ToString(db.DBClusterIdentifier) != "" { + continue + } + protoRDS := awsRDSInstanceToRDS(&db, region, a.AccountID) + instances = append(instances, protoRDS) + } + } + collectDBs(instances, nil) +} + +func (a *awsFetcher) collectDBClusters( + ctx context.Context, + clt rdsClient, + region string, + collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error), +) { + pager := rds.NewDescribeDBClustersPaginator(clt, &rds.DescribeDBClustersInput{}, + func(ddpo *rds.DescribeDBClustersPaginatorOptions) { + ddpo.StopOnDuplicateToken = true + }, + ) + var clusters []*accessgraphv1alpha.AWSRDSDatabaseV1 + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool { + return db.IsCluster && db.Region == region && db.AccountId == a.AccountID + }) + collectDBs(old, trace.Wrap(err)) + return + } + for _, db := range page.DBClusters { + protoRDS := awsRDSClusterToRDS(&db, region, a.AccountID) + clusters = append(clusters, protoRDS) + } + } + collectDBs(clusters, nil) +} diff --git a/lib/srv/discovery/fetchers/aws-sync/rds_test.go b/lib/srv/discovery/fetchers/aws-sync/rds_test.go index bed0811d88e1d..b5ae125abc4f4 100644 --- a/lib/srv/discovery/fetchers/aws-sync/rds_test.go +++ b/lib/srv/discovery/fetchers/aws-sync/rds_test.go @@ -20,19 +20,19 @@ package aws_sync import ( "context" - "sync" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/gravitational/teleport/api/types" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" ) @@ -44,86 +44,113 @@ func TestPollAWSRDS(t *testing.T) { regions = []string{"eu-west-1"} ) - tests := []struct { - name string - want *Resources - }{ - { - name: "poll rds databases", - want: &Resources{ - RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{ + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "integration-test"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:sts::123456789012:role/TestRole", + }, + ) + require.NoError(t, err) + + resourcesFixture := Resources{ + RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{ + { + Arn: "arn:us-west1:rds:instance1", + Status: string(rdstypes.DBProxyStatusAvailable), + Name: "db1", + EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ + Engine: string(rdstypes.EngineFamilyMysql), + Version: "v1.1", + }, + CreatedAt: timestamppb.New(date), + Tags: []*accessgraphv1alpha.AWSTag{ { - Arn: "arn:us-west1:rds:instance1", - Status: rds.DBProxyStatusAvailable, - Name: "db1", - EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: rds.EngineFamilyMysql, - Version: "v1.1", - }, - CreatedAt: timestamppb.New(date), - Tags: []*accessgraphv1alpha.AWSTag{ - { - Key: "tag", - Value: wrapperspb.String("val"), - }, - }, - Region: "eu-west-1", - IsCluster: false, - AccountId: "12345678", - ResourceId: "db1", + Key: "tag", + Value: wrapperspb.String("val"), }, + }, + Region: "eu-west-1", + IsCluster: false, + AccountId: "12345678", + ResourceId: "db1", + }, + { + Arn: "arn:us-west1:rds:cluster1", + Status: string(rdstypes.DBProxyStatusAvailable), + Name: "cluster1", + EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ + Engine: string(rdstypes.EngineFamilyMysql), + Version: "v1.1", + }, + CreatedAt: timestamppb.New(date), + Tags: []*accessgraphv1alpha.AWSTag{ { - Arn: "arn:us-west1:rds:cluster1", - Status: rds.DBProxyStatusAvailable, - Name: "cluster1", - EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: rds.EngineFamilyMysql, - Version: "v1.1", - }, - CreatedAt: timestamppb.New(date), - Tags: []*accessgraphv1alpha.AWSTag{ - { - Key: "tag", - Value: wrapperspb.String("val"), - }, - }, - Region: "eu-west-1", - IsCluster: true, - AccountId: "12345678", - ResourceId: "cluster1", + Key: "tag", + Value: wrapperspb.String("val"), }, }, + Region: "eu-west-1", + IsCluster: true, + AccountId: "12345678", + ResourceId: "cluster1", }, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockedClients := &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ + + tests := []struct { + name string + fetcherConfigOpt func(*awsFetcher) + want *Resources + checkError func(*testing.T, error) + }{ + { + name: "poll rds databases", + want: &resourcesFixture, + fetcherConfigOpt: func(a *awsFetcher) { + a.rdsClientProviderFn = newFakeRDSClientProvider(&mocks.RDSClient{ DBInstances: dbInstances(), DBClusters: dbClusters(), - }, - } - - var ( - errs []error - mu sync.Mutex - ) - - collectErr := func(err error) { - mu.Lock() - defer mu.Unlock() - errs = append(errs, err) - } + }) + }, + checkError: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "reuse last synced databases on failure", + want: &resourcesFixture, + fetcherConfigOpt: func(a *awsFetcher) { + a.rdsClientProviderFn = newFakeRDSClientProvider(&mocks.RDSClient{Unauth: true}) + a.lastResult = &resourcesFixture + }, + checkError: func(t *testing.T, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "failed to fetch databases") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { a := &awsFetcher{ Config: Config{ - AccountID: accountID, - CloudClients: mockedClients, - Regions: regions, - Integration: accountID, + AccountID: accountID, + AWSConfigProvider: &mocks.AWSConfigProvider{ + OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{ + Integration: awsOIDCIntegration, + Token: "fake-oidc-token", + }, + }, + Regions: regions, + Integration: awsOIDCIntegration.GetName(), }, } + if tt.fetcherConfigOpt != nil { + tt.fetcherConfigOpt(a) + } result := &Resources{} + collectErr := func(err error) { + tt.checkError(t, err) + } execFunc := a.pollAWSRDSDatabases(context.Background(), result, collectErr) require.NoError(t, execFunc()) require.Empty(t, cmp.Diff( @@ -144,16 +171,16 @@ func TestPollAWSRDS(t *testing.T) { } } -func dbInstances() []*rds.DBInstance { - return []*rds.DBInstance{ +func dbInstances() []rdstypes.DBInstance { + return []rdstypes.DBInstance{ { DBInstanceIdentifier: aws.String("db1"), DBInstanceArn: aws.String("arn:us-west1:rds:instance1"), InstanceCreateTime: aws.Time(date), - Engine: aws.String(rds.EngineFamilyMysql), - DBInstanceStatus: aws.String(rds.DBProxyStatusAvailable), + Engine: aws.String(string(rdstypes.EngineFamilyMysql)), + DBInstanceStatus: aws.String(string(rdstypes.DBProxyStatusAvailable)), EngineVersion: aws.String("v1.1"), - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ { Key: aws.String("tag"), Value: aws.String("val"), @@ -164,16 +191,16 @@ func dbInstances() []*rds.DBInstance { } } -func dbClusters() []*rds.DBCluster { - return []*rds.DBCluster{ +func dbClusters() []rdstypes.DBCluster { + return []rdstypes.DBCluster{ { DBClusterIdentifier: aws.String("cluster1"), DBClusterArn: aws.String("arn:us-west1:rds:cluster1"), ClusterCreateTime: aws.Time(date), - Engine: aws.String(rds.EngineFamilyMysql), - Status: aws.String(rds.DBProxyStatusAvailable), + Engine: aws.String(string(rdstypes.EngineFamilyMysql)), + Status: aws.String(string(rdstypes.DBProxyStatusAvailable)), EngineVersion: aws.String("v1.1"), - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ { Key: aws.String("tag"), Value: aws.String("val"), @@ -183,3 +210,9 @@ func dbClusters() []*rds.DBCluster { }, } } + +func newFakeRDSClientProvider(c rdsClient) rdsClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return c + } +} diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index d6d70912d7092..18b1775cf0241 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -24,6 +24,7 @@ import ( "log/slog" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" @@ -76,6 +77,8 @@ type awsFetcherConfig struct { // redshiftClientProviderFn provides an AWS Redshift client. redshiftClientProviderFn RedshiftClientProviderFunc + // rdsClientProviderFn provides an AWS RDS client. + rdsClientProviderFn RDSClientProviderFunc } // CheckAndSetDefaults validates the config and sets defaults. @@ -114,6 +117,11 @@ func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { return redshift.NewFromConfig(cfg, optFns...) } } + if cfg.rdsClientProviderFn == nil { + cfg.rdsClientProviderFn = func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return rds.NewFromConfig(cfg, optFns...) + } + } return nil } diff --git a/lib/srv/discovery/fetchers/db/aws_docdb.go b/lib/srv/discovery/fetchers/db/aws_docdb.go index a6a604be340eb..32fb1c104c6e3 100644 --- a/lib/srv/discovery/fetchers/db/aws_docdb.go +++ b/lib/srv/discovery/fetchers/db/aws_docdb.go @@ -21,14 +21,14 @@ package db import ( "context" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -39,13 +39,6 @@ func newDocumentDBFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { } // rdsDocumentDBFetcher retrieves DocumentDB clusters. -// -// Note that AWS DocumentDB internally uses the RDS APIs: -// https://github.com/aws/aws-sdk-go/blob/3248e69e16aa601ffa929be53a52439425257e5e/service/docdb/service.go#L33 -// The interfaces/structs in "services/docdb" are usually a subset of those in -// "services/rds". -// -// TODO(greedy52) switch to aws-sdk-go-v2/services/docdb. type rdsDocumentDBFetcher struct{} func (f *rdsDocumentDBFetcher) ComponentShortName() string { @@ -54,21 +47,22 @@ func (f *rdsDocumentDBFetcher) ComponentShortName() string { // GetDatabases returns a list of database resources representing DocumentDB endpoints. func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - clusters, err := f.getAllDBClusters(ctx, rdsClient) + clt := cfg.rdsClientProviderFn(awsCfg) + clusters, err := f.getAllDBClusters(ctx, clt) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0) for _, cluster := range clusters { - if !libcloudaws.IsDocumentDBClusterSupported(cluster) { + if !libcloudaws.IsDocumentDBClusterSupported(&cluster) { cfg.Logger.DebugContext(ctx, "DocumentDB cluster doesn't support IAM authentication. Skipping.", "cluster", aws.StringValue(cluster.DBClusterIdentifier), "engine_version", aws.StringValue(cluster.EngineVersion)) @@ -82,7 +76,7 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher continue } - dbs, err := common.NewDatabasesFromDocumentDBCluster(cluster) + dbs, err := common.NewDatabasesFromDocumentDBCluster(&cluster) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert DocumentDB cluster to database resources.", "cluster", aws.StringValue(cluster.DBClusterIdentifier), @@ -93,15 +87,23 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher return databases, nil } -func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI) ([]*rds.DBCluster, error) { - var pageNum int - var clusters []*rds.DBCluster - err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{ - Filters: rdsEngineFilter([]string{"docdb"}), - }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool { - pageNum++ - clusters = append(clusters, ddo.DBClusters...) - return pageNum <= maxAWSPages - }) - return clusters, trace.Wrap(err) +func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, clt RDSClient) ([]rdstypes.DBCluster, error) { + pager := rds.NewDescribeDBClustersPaginator(clt, + &rds.DescribeDBClustersInput{ + Filters: rdsEngineFilter([]string{"docdb"}), + }, + func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) { + pagerOpts.StopOnDuplicateToken = true + }, + ) + + var clusters []rdstypes.DBCluster + for i := 0; i < maxAWSPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + clusters = append(clusters, page.DBClusters...) + } + return clusters, nil } diff --git a/lib/srv/discovery/fetchers/db/aws_docdb_test.go b/lib/srv/discovery/fetchers/db/aws_docdb_test.go index 4ae7cfee582f0..151b7b7ccea37 100644 --- a/lib/srv/discovery/fetchers/db/aws_docdb_test.go +++ b/lib/srv/discovery/fetchers/db/aws_docdb_test.go @@ -21,12 +21,11 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -34,16 +33,16 @@ import ( func TestDocumentDBFetcher(t *testing.T) { t.Parallel() - docdbEngine := &rds.DBEngineVersion{ + docdbEngine := &rdstypes.DBEngineVersion{ Engine: aws.String("docdb"), } clusterProd := mocks.DocumentDBCluster("cluster1", "us-east-1", envProdLabels, mocks.WithDocumentDBClusterReader) clusterDev := mocks.DocumentDBCluster("cluster2", "us-east-1", envDevLabels) - clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) { + clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) { cluster.Status = aws.String("creating") }) - clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) { + clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) { cluster.EngineVersion = aws.String("4.0.0") }) @@ -53,11 +52,11 @@ func TestDocumentDBFetcher(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterDev}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -70,11 +69,11 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "filter by labels", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterDev}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -87,11 +86,11 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "skip unsupported databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterNotSupported}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotSupported}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -104,11 +103,11 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "skip unavailable databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterNotAvailable}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotAvailable}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -123,7 +122,7 @@ func TestDocumentDBFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func mustMakeDocumentDBDatabases(t *testing.T, cluster *rds.DBCluster) types.Databases { +func mustMakeDocumentDBDatabases(t *testing.T, cluster *rdstypes.DBCluster) types.Databases { t.Helper() databases, err := common.NewDatabasesFromDocumentDBCluster(cluster) diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index 639835f2b75a2..2b950534ed552 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -23,18 +23,30 @@ import ( "log/slog" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// RDSClientProviderFunc provides a [RDSClient]. +type RDSClientProviderFunc func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient + +// RDSClient is a subset of the AWS RDS API. +type RDSClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient + rds.DescribeDBProxiesAPIClient + rds.DescribeDBProxyEndpointsAPIClient + ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) +} + // newRDSDBInstancesFetcher returns a new AWS fetcher for RDS databases. func newRDSDBInstancesFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { return newAWSFetcher(cfg, &rdsDBInstancesPlugin{}) @@ -49,40 +61,41 @@ func (f *rdsDBInstancesPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS instances. func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - instances, err := getAllDBInstances(ctx, rdsClient, maxAWSPages, cfg.Logger) + clt := cfg.rdsClientProviderFn(awsCfg) + instances, err := getAllDBInstances(ctx, clt, maxAWSPages, cfg.Logger) if err != nil { - return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) } databases := make(types.Databases, 0, len(instances)) for _, instance := range instances { - if !libcloudaws.IsRDSInstanceSupported(instance) { + if !libcloudaws.IsRDSInstanceSupported(&instance) { cfg.Logger.DebugContext(ctx, "Skipping RDS instance that does not support IAM authentication", - "instance", aws.StringValue(instance.DBInstanceIdentifier), - "engine_mode", aws.StringValue(instance.Engine), - "engine_version", aws.StringValue(instance.EngineVersion), + "instance", aws.ToString(instance.DBInstanceIdentifier), + "engine_mode", aws.ToString(instance.Engine), + "engine_version", aws.ToString(instance.EngineVersion), ) continue } if !libcloudaws.IsRDSInstanceAvailable(instance.DBInstanceStatus, instance.DBInstanceIdentifier) { cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS instance", - "instance", aws.StringValue(instance.DBInstanceIdentifier), - "status", aws.StringValue(instance.DBInstanceStatus), + "instance", aws.ToString(instance.DBInstanceIdentifier), + "status", aws.ToString(instance.DBInstanceStatus), ) continue } - database, err := common.NewDatabaseFromRDSInstance(instance) + database, err := common.NewDatabaseFromRDSInstance(&instance) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert RDS instance to database resource", - "instance", aws.StringValue(instance.DBInstanceIdentifier), + "instance", aws.ToString(instance.DBInstanceIdentifier), "error", err, ) } else { @@ -94,36 +107,40 @@ func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcher // getAllDBInstances fetches all RDS instances using the provided client, up // to the specified max number of pages. -func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBInstance, error) { - return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger) +func getAllDBInstances(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + return getAllDBInstancesWithFilters(ctx, clt, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger) } // findDBInstancesForDBCluster returns the DBInstances associated with a given DB Cluster Identifier -func findDBInstancesForDBCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]*rds.DBInstance, error) { - return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger) +func findDBInstancesForDBCluster(ctx context.Context, clt RDSClient, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + return getAllDBInstancesWithFilters(ctx, clt, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger) } // getAllDBInstancesWithFilters fetches all RDS instances matching the filters using the provided client, up // to the specified max number of pages. -func getAllDBInstancesWithFilters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, engines []string, baseFilters []*rds.Filter, logger *slog.Logger) ([]*rds.DBInstance, error) { - var instances []*rds.DBInstance - err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []*rds.Filter) error { - var pageNum int - var out []*rds.DBInstance - err := rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{ - Filters: append(engineFilters, baseFilters...), - }, func(ddo *rds.DescribeDBInstancesOutput, lastPage bool) bool { - pageNum++ - instances = append(instances, ddo.DBInstances...) - return pageNum <= maxPages - }) - if err == nil { - // only append to instances on nil error, just in case we have to retry. - instances = append(instances, out...) +func getAllDBInstancesWithFilters(ctx context.Context, clt RDSClient, maxPages int, engines []string, baseFilters []rdstypes.Filter, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + var out []rdstypes.DBInstance + err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []rdstypes.Filter) error { + pager := rds.NewDescribeDBInstancesPaginator(clt, + &rds.DescribeDBInstancesInput{ + Filters: append(engineFilters, baseFilters...), + }, + func(dcpo *rds.DescribeDBInstancesPaginatorOptions) { + dcpo.StopOnDuplicateToken = true + }, + ) + var instances []rdstypes.DBInstance + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return trace.Wrap(err) + } + instances = append(instances, page.DBInstances...) } - return trace.Wrap(err) + out = instances + return nil }) - return instances, trace.Wrap(err) + return out, trace.Wrap(err) } // newRDSAuroraClustersFetcher returns a new AWS fetcher for RDS Aurora @@ -141,48 +158,49 @@ func (f *rdsAuroraClustersPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS clusters. func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - clusters, err := getAllDBClusters(ctx, rdsClient, maxAWSPages, cfg.Logger) + clt := cfg.rdsClientProviderFn(awsCfg) + clusters, err := getAllDBClusters(ctx, clt, maxAWSPages, cfg.Logger) if err != nil { return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) } databases := make(types.Databases, 0, len(clusters)) for _, cluster := range clusters { - if !libcloudaws.IsRDSClusterSupported(cluster) { + if !libcloudaws.IsRDSClusterSupported(&cluster) { cfg.Logger.DebugContext(ctx, "Skipping Aurora cluster that does not support IAM authentication", - "cluster", aws.StringValue(cluster.DBClusterIdentifier), - "engine_mode", aws.StringValue(cluster.EngineMode), - "engine_version", aws.StringValue(cluster.EngineVersion), + "cluster", aws.ToString(cluster.DBClusterIdentifier), + "engine_mode", aws.ToString(cluster.EngineMode), + "engine_version", aws.ToString(cluster.EngineVersion), ) continue } if !libcloudaws.IsDBClusterAvailable(cluster.Status, cluster.DBClusterIdentifier) { cfg.Logger.DebugContext(ctx, "Skipping unavailable Aurora cluster", - "instance", aws.StringValue(cluster.DBClusterIdentifier), - "status", aws.StringValue(cluster.Status), + "instance", aws.ToString(cluster.DBClusterIdentifier), + "status", aws.ToString(cluster.Status), ) continue } - rdsDBInstances, err := findDBInstancesForDBCluster(ctx, rdsClient, maxAWSPages, aws.StringValue(cluster.DBClusterIdentifier), cfg.Logger) + rdsDBInstances, err := findDBInstancesForDBCluster(ctx, clt, maxAWSPages, aws.ToString(cluster.DBClusterIdentifier), cfg.Logger) if err != nil || len(rdsDBInstances) == 0 { cfg.Logger.WarnContext(ctx, "Could not fetch Member Instance for DB Cluster", - "instance", aws.StringValue(cluster.DBClusterIdentifier), + "instance", aws.ToString(cluster.DBClusterIdentifier), "error", err, ) } - dbs, err := common.NewDatabasesFromRDSCluster(cluster, rdsDBInstances) + dbs, err := common.NewDatabasesFromRDSCluster(&cluster, rdsDBInstances) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert RDS cluster to database resources", - "identifier", aws.StringValue(cluster.DBClusterIdentifier), + "identifier", aws.ToString(cluster.DBClusterIdentifier), "error", err, ) } @@ -193,25 +211,30 @@ func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetc // getAllDBClusters fetches all RDS clusters using the provided client, up to // the specified max number of pages. -func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBCluster, error) { - var clusters []*rds.DBCluster - err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []*rds.Filter) error { - var pageNum int - var out []*rds.DBCluster - err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{ - Filters: filters, - }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool { - pageNum++ - out = append(out, ddo.DBClusters...) - return pageNum <= maxPages - }) - if err == nil { - // only append to clusters on nil error, just in case we have to retry. - clusters = append(clusters, out...) +func getAllDBClusters(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBCluster, error) { + var out []rdstypes.DBCluster + err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []rdstypes.Filter) error { + pager := rds.NewDescribeDBClustersPaginator(clt, + &rds.DescribeDBClustersInput{ + Filters: filters, + }, + func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) { + pagerOpts.StopOnDuplicateToken = true + }, + ) + + var clusters []rdstypes.DBCluster + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return trace.Wrap(err) + } + clusters = append(clusters, page.DBClusters...) } - return trace.Wrap(err) + out = clusters + return nil }) - return clusters, trace.Wrap(err) + return out, trace.Wrap(err) } // rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns @@ -234,28 +257,28 @@ func auroraEngines() []string { } // rdsEngineFilter is a helper func to construct an RDS filter for engine names. -func rdsEngineFilter(engines []string) []*rds.Filter { - return []*rds.Filter{{ +func rdsEngineFilter(engines []string) []rdstypes.Filter { + return []rdstypes.Filter{{ Name: aws.String("engine"), - Values: aws.StringSlice(engines), + Values: engines, }} } // rdsClusterIDFilter is a helper func to construct an RDS DB Instances for returning Instances of a specific DB Cluster. -func rdsClusterIDFilter(clusterIdentifier string) []*rds.Filter { - return []*rds.Filter{{ +func rdsClusterIDFilter(clusterIdentifier string) []rdstypes.Filter { + return []rdstypes.Filter{{ Name: aws.String("db-cluster-id"), - Values: aws.StringSlice([]string{clusterIdentifier}), + Values: []string{clusterIdentifier}, }} } // rdsEmptyFilter is a helper func to construct an empty RDS filter. -func rdsEmptyFilter() []*rds.Filter { - return []*rds.Filter{} +func rdsEmptyFilter() []rdstypes.Filter { + return []rdstypes.Filter{} } // rdsFilterFn is a function that takes RDS filters and performs some operation with them, returning any error encountered. -type rdsFilterFn func([]*rds.Filter) error +type rdsFilterFn func([]rdstypes.Filter) error // retryWithIndividualEngineFilters is a helper error handling function for AWS RDS unrecognized engine name filter errors, // that will call the provided RDS querying function with filters, check the returned error, diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go index dde1a1a189940..4df89ac46b2c0 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go @@ -21,14 +21,14 @@ package db import ( "context" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -47,56 +47,57 @@ func (f *rdsDBProxyPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS // Proxies and custom endpoints. func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } + clt := cfg.rdsClientProviderFn(awsCfg) // Get a list of all RDS Proxies. Each RDS Proxy has one "default" // endpoint. - rdsProxies, err := getRDSProxies(ctx, rdsClient, maxAWSPages) + rdsProxies, err := getRDSProxies(ctx, clt, maxAWSPages) if err != nil { return nil, trace.Wrap(err) } // Get all RDS Proxy custom endpoints sorted by the name of the RDS Proxy // that owns the custom endpoints. - customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, rdsClient, maxAWSPages) + customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, clt, maxAWSPages) if err != nil { cfg.Logger.DebugContext(ctx, "Failed to get RDS Proxy endpoints", "error", err) } var databases types.Databases for _, dbProxy := range rdsProxies { - if !aws.BoolValue(dbProxy.RequireTLS) { - cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.StringValue(dbProxy.DBProxyName)) + if !aws.ToBool(dbProxy.RequireTLS) { + cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.ToString(dbProxy.DBProxyName)) continue } - if !libcloudaws.IsRDSProxyAvailable(dbProxy) { + if !libcloudaws.IsRDSProxyAvailable(&dbProxy) { cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS Proxy", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), - "status", aws.StringValue(dbProxy.Status)) + "rds_proxy", aws.ToString(dbProxy.DBProxyName), + "status", dbProxy.Status) continue } - // rds.DBProxy has no tags information. An extra SDK call is made to + // rdstypes.DBProxy has no tags information. An extra SDK call is made to // fetch the tags. If failed, keep going without the tags. - tags, err := listRDSResourceTags(ctx, rdsClient, dbProxy.DBProxyArn) + tags, err := listRDSResourceTags(ctx, clt, dbProxy.DBProxyArn) if err != nil { cfg.Logger.DebugContext(ctx, "Failed to get tags for RDS Proxy", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), + "rds_proxy", aws.ToString(dbProxy.DBProxyName), "error", err, ) } // Add a database from RDS Proxy (default endpoint). - database, err := common.NewDatabaseFromRDSProxy(dbProxy, tags) + database, err := common.NewDatabaseFromRDSProxy(&dbProxy, tags) if err != nil { cfg.Logger.DebugContext(ctx, "Could not convert RDS Proxy to database resource", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), + "rds_proxy", aws.ToString(dbProxy.DBProxyName), "error", err, ) } else { @@ -104,21 +105,21 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf } // Add custom endpoints. - for _, customEndpoint := range customEndpointsByProxyName[aws.StringValue(dbProxy.DBProxyName)] { - if !libcloudaws.IsRDSProxyCustomEndpointAvailable(customEndpoint) { + for _, customEndpoint := range customEndpointsByProxyName[aws.ToString(dbProxy.DBProxyName)] { + if !libcloudaws.IsRDSProxyCustomEndpointAvailable(&customEndpoint) { cfg.Logger.DebugContext(ctx, "Skipping unavailable custom endpoint of RDS Proxy", - "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName), - "rds_proxy", aws.StringValue(customEndpoint.DBProxyName), - "status", aws.StringValue(customEndpoint.Status), + "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName), + "rds_proxy", aws.ToString(customEndpoint.DBProxyName), + "status", customEndpoint.Status, ) continue } - database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, tags) + database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(&dbProxy, &customEndpoint, tags) if err != nil { cfg.Logger.DebugContext(ctx, "Could not convert custom endpoint for RDS Proxy to database resource", - "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName), - "rds_proxy", aws.StringValue(customEndpoint.DBProxyName), + "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName), + "rds_proxy", aws.ToString(customEndpoint.DBProxyName), "error", err, ) continue @@ -132,42 +133,50 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf // getRDSProxies fetches all RDS Proxies using the provided client, up to the // specified max number of pages. -func getRDSProxies(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (rdsProxies []*rds.DBProxy, err error) { - var pageNum int - err = rdsClient.DescribeDBProxiesPagesWithContext( - ctx, +func getRDSProxies(ctx context.Context, clt RDSClient, maxPages int) ([]rdstypes.DBProxy, error) { + pager := rds.NewDescribeDBProxiesPaginator(clt, &rds.DescribeDBProxiesInput{}, - func(ddo *rds.DescribeDBProxiesOutput, lastPage bool) bool { - pageNum++ - rdsProxies = append(rdsProxies, ddo.DBProxies...) - return pageNum <= maxPages + func(dcpo *rds.DescribeDBProxiesPaginatorOptions) { + dcpo.StopOnDuplicateToken = true }, ) - return rdsProxies, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + + var rdsProxies []rdstypes.DBProxy + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) + } + rdsProxies = append(rdsProxies, page.DBProxies...) + } + return rdsProxies, nil } // getRDSProxyCustomEndpoints fetches all RDS Proxy custom endpoints using the // provided client. -func getRDSProxyCustomEndpoints(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (map[string][]*rds.DBProxyEndpoint, error) { - customEndpointsByProxyName := make(map[string][]*rds.DBProxyEndpoint) - var pageNum int - err := rdsClient.DescribeDBProxyEndpointsPagesWithContext( - ctx, +func getRDSProxyCustomEndpoints(ctx context.Context, clt RDSClient, maxPages int) (map[string][]rdstypes.DBProxyEndpoint, error) { + customEndpointsByProxyName := make(map[string][]rdstypes.DBProxyEndpoint) + pager := rds.NewDescribeDBProxyEndpointsPaginator(clt, &rds.DescribeDBProxyEndpointsInput{}, - func(ddo *rds.DescribeDBProxyEndpointsOutput, lastPage bool) bool { - pageNum++ - for _, customEndpoint := range ddo.DBProxyEndpoints { - customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)], customEndpoint) - } - return pageNum <= maxPages + func(ddepo *rds.DescribeDBProxyEndpointsPaginatorOptions) { + ddepo.StopOnDuplicateToken = true }, ) - return customEndpointsByProxyName, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) + } + for _, customEndpoint := range page.DBProxyEndpoints { + customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)], customEndpoint) + } + } + return customEndpointsByProxyName, nil } // listRDSResourceTags returns tags for provided RDS resource. -func listRDSResourceTags(ctx context.Context, rdsClient rdsiface.RDSAPI, resourceName *string) ([]*rds.Tag, error) { - output, err := rdsClient.ListTagsForResourceWithContext(ctx, &rds.ListTagsForResourceInput{ +func listRDSResourceTags(ctx context.Context, clt RDSClient, resourceName *string) ([]rdstypes.Tag, error) { + output, err := clt.ListTagsForResource(ctx, &rds.ListTagsForResourceInput{ ResourceName: resourceName, }) if err != nil { diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go index b92ff2a439eda..b78a6469b902f 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go @@ -21,11 +21,10 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -41,22 +40,22 @@ func TestRDSDBProxyFetcher(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2}, + }), }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", wildcardLabels), wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyDatabaseVpc2, rdsProxyEndpointDatabaseVpc1, rdsProxyEndpointDatabaseVpc2}, }, { name: "fetch vpc1", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2}, + }), }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", map[string]string{"vpc-id": "vpc1"}), wantDatabases: types.Databases{rdsProxyDatabaseVpc1, rdsProxyEndpointDatabaseVpc1}, @@ -65,7 +64,7 @@ func TestRDSDBProxyFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) { +func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rdstypes.DBProxy, types.Database) { rdsProxy := mocks.RDSProxy(name, region, vpcID) rdsProxyDatabase, err := common.NewDatabaseFromRDSProxy(rdsProxy, nil) require.NoError(t, err) @@ -73,7 +72,7 @@ func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types return rdsProxy, rdsProxyDatabase } -func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rds.DBProxy, name, region string) (*rds.DBProxyEndpoint, types.Database) { +func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rdstypes.DBProxy, name, region string) (*rdstypes.DBProxyEndpoint, types.Database) { rdsProxyEndpoint := mocks.RDSProxyCustomEndpoint(rdsProxy, name, region) rdsProxyEndpointDatabase, err := common.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint, nil) require.NoError(t, err) diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index 9dfc658268eeb..f72d79d6f51e1 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -21,13 +21,12 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" @@ -38,8 +37,8 @@ import ( func TestRDSFetchers(t *testing.T) { t.Parallel() - auroraMySQLEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)} - postgresEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)} + auroraMySQLEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)} + postgresEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)} rdsInstance1, rdsDatabase1 := makeRDSInstance(t, "instance-1", "us-east-1", envProdLabels) rdsInstance2, rdsDatabase2 := makeRDSInstance(t, "instance-2", "us-east-2", envProdLabels) @@ -58,19 +57,19 @@ func TestRDSFetchers(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -91,19 +90,19 @@ func TestRDSFetchers(t *testing.T) { }, { name: "fetch different labels for different regions", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -124,19 +123,19 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unrecognized engines", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -154,14 +153,14 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unsupported databases", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnsupported}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnsupported}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{{ Types: []string{types.AWSMatcherRDS}, @@ -172,12 +171,12 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unavailable databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstanceUnavailable, rdsInstanceUnknownStatus, auroraCluster1MemberInstance, auroraClusterUnknownStatusMemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnavailable, auroraClusterUnknownStatus}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstanceUnavailable, *rdsInstanceUnknownStatus, *auroraCluster1MemberInstance, *auroraClusterUnknownStatusMemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnavailable, *auroraClusterUnknownStatus}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, + }), }, inputMatchers: []types.AWSMatcher{{ Types: []string{types.AWSMatcherRDS}, @@ -188,12 +187,12 @@ func TestRDSFetchers(t *testing.T) { }, { name: "Aurora cluster without writer", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{auroraClusterNoWriter}, - DBInstances: []*rds.DBInstance{auroraClusterMemberNoWriter}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine}, - }, + fetcherCfg: AWSFetcherFactoryConfig{ + RDSClientProviderFn: newFakeRDSClientProvider(&mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*auroraClusterNoWriter}, + DBInstances: []rdstypes.DBInstance{*auroraClusterMemberNoWriter}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, + }), }, inputMatchers: []types.AWSMatcher{{ Types: []string{types.AWSMatcherRDS}, @@ -206,7 +205,7 @@ func TestRDSFetchers(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) { +func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) (*rdstypes.DBInstance, types.Database) { instance := mocks.RDSInstance(name, region, labels, opts...) database, err := common.NewDatabaseFromRDSInstance(instance) require.NoError(t, err) @@ -214,21 +213,21 @@ func makeRDSInstance(t *testing.T, name, region string, labels map[string]string return instance, database } -func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) (*rds.DBCluster, *rds.DBInstance, types.Database) { +func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Database) { cluster := mocks.RDSCluster(name, region, labels, opts...) dbInstanceMember := makeRDSMemberForCluster(t, name, region, "vpc-123", *cluster.Engine, labels) - database, err := common.NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{dbInstanceMember}) + database, err := common.NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{*dbInstanceMember}) require.NoError(t, err) common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRDS) return cluster, dbInstanceMember, database } -func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rds.DBInstance { - instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rds.DBInstance) { +func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rdstypes.DBInstance { + instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rdstypes.DBInstance) { if d.DBSubnetGroup == nil { - d.DBSubnetGroup = &rds.DBSubnetGroup{} + d.DBSubnetGroup = &rdstypes.DBSubnetGroup{} } - d.DBSubnetGroup.SetVpcId(vpcid) + d.DBSubnetGroup.VpcId = aws.String(vpcid) d.DBClusterIdentifier = aws.String(name) d.Engine = aws.String(engine) }) @@ -236,9 +235,9 @@ func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, l return instanceRDSMember } -func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rds.DBCluster, *rds.DBInstance, types.Databases) { +func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Databases) { cluster := mocks.RDSCluster(name, region, labels, - func(cluster *rds.DBCluster) { + func(cluster *rdstypes.DBCluster) { // Disable writer by default. If hasWriter, writer endpoint will be added below. cluster.DBClusterMembers = nil }, @@ -249,11 +248,11 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels var databases types.Databases - instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.StringValue(cluster.Engine), labels) - dbInstanceMembers := []*rds.DBInstance{instanceRDSMember} + instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.ToString(cluster.Engine), labels) + dbInstanceMembers := []rdstypes.DBInstance{*instanceRDSMember} if hasWriter { - cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{ + cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{ IsClusterWriter: aws.Bool(true), // Add writer. }) @@ -277,22 +276,35 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels } // withRDSInstanceStatus returns an option function for makeRDSInstance to overwrite status. -func withRDSInstanceStatus(status string) func(*rds.DBInstance) { - return func(instance *rds.DBInstance) { +func withRDSInstanceStatus(status string) func(*rdstypes.DBInstance) { + return func(instance *rdstypes.DBInstance) { instance.DBInstanceStatus = aws.String(status) } } // withRDSClusterEngineMode returns an option function for makeRDSCluster to overwrite engine mode. -func withRDSClusterEngineMode(mode string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func withRDSClusterEngineMode(mode string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { cluster.EngineMode = aws.String(mode) } } // withRDSClusterStatus returns an option function for makeRDSCluster to overwrite status. -func withRDSClusterStatus(status string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func withRDSClusterStatus(status string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { cluster.Status = aws.String(status) } } + +func newFakeRDSClientProvider(c RDSClient) RDSClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return c + } +} + +// provides a client specific to each region, where the map keys are regions. +func newRegionalFakeRDSClientProvider(cs map[string]RDSClient) RDSClientProviderFunc { + return func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return cs[cfg.Region] + } +} diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index ded47035e96e3..9e7a2b2ab20f5 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -36,6 +36,7 @@ func newFakeRedshiftClientProvider(c RedshiftClient) RedshiftClientProviderFunc return c } } + func TestRedshiftFetcher(t *testing.T) { t.Parallel() diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 8d79bc2bb65bc..1c397d976ec7b 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -23,6 +23,7 @@ import ( "log/slog" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "golang.org/x/exp/maps" @@ -73,8 +74,10 @@ type AWSFetcherFactoryConfig struct { AWSConfigProvider awsconfig.Provider // CloudClients is an interface for retrieving AWS SDK v1 cloud clients. CloudClients cloud.AWSClients - // RedshiftClientProviderFn is an optional function that provides + // RedshiftClientProviderFn is an optional function that provides an AWS Redshift client. RedshiftClientProviderFn RedshiftClientProviderFunc + // RDSClientProviderFn is an optional function that provides an AWS RDS API client. + RDSClientProviderFn func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient } func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error { @@ -89,6 +92,11 @@ func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error { return redshift.NewFromConfig(cfg, optFns...) } } + if c.RDSClientProviderFn == nil { + c.RDSClientProviderFn = func(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return rds.NewFromConfig(cfg, optFns...) + } + } return nil } @@ -134,6 +142,7 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A DiscoveryConfigName: discoveryConfigName, AWSConfigProvider: f.cfg.AWSConfigProvider, redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, + rdsClientProviderFn: f.cfg.RDSClientProviderFn, }) if err != nil { return nil, trace.Wrap(err)