Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
creack committed Jan 5, 2025
1 parent 6852c7f commit 7174944
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 161 deletions.
40 changes: 15 additions & 25 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,15 +594,6 @@ func (c *cloudClients) GetAWSSTSClient(ctx context.Context, region string, opts
return sts.New(session), nil
}

// // GetAWSEKSClient returns AWS EKS client for the specified region.
// func (c *cloudClients) GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error) {
// session, err := c.GetAWSSession(ctx, region, opts...)
// if err != nil {
// return nil, trace.Wrap(err)
// }
// return eks.New(session), nil
// }

// GetAWSKMSClient returns AWS KMS client for the specified region.
func (c *cloudClients) GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error) {
session, err := c.GetAWSSession(ctx, region, opts...)
Expand Down Expand Up @@ -1027,22 +1018,21 @@ var _ Clients = (*TestCloudClients)(nil)

// TestCloudClients are used in tests.
type TestCloudClients struct {
RDS rdsiface.RDSAPI
RDSPerRegion map[string]rdsiface.RDSAPI
Redshift redshiftiface.RedshiftAPI
RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI
ElastiCache elasticacheiface.ElastiCacheAPI
OpenSearch opensearchserviceiface.OpenSearchServiceAPI
MemoryDB memorydbiface.MemoryDBAPI
SecretsManager secretsmanageriface.SecretsManagerAPI
IAM iamiface.IAMAPI
STS stsiface.STSAPI
GCPSQL gcp.SQLAdminClient
GCPGKE gcp.GKEClient
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
InstanceMetadata imds.Client
// EKS eksiface.EKSAPI
RDS rdsiface.RDSAPI
RDSPerRegion map[string]rdsiface.RDSAPI
Redshift redshiftiface.RedshiftAPI
RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI
ElastiCache elasticacheiface.ElastiCacheAPI
OpenSearch opensearchserviceiface.OpenSearchServiceAPI
MemoryDB memorydbiface.MemoryDBAPI
SecretsManager secretsmanageriface.SecretsManagerAPI
IAM iamiface.IAMAPI
STS stsiface.STSAPI
GCPSQL gcp.SQLAdminClient
GCPGKE gcp.GKEClient
GCPProjects gcp.ProjectsClient
GCPInstances gcp.InstancesClient
InstanceMetadata imds.Client
KMS kmsiface.KMSAPI
S3 s3iface.S3API
AzureMySQL azure.DBServersClient
Expand Down
100 changes: 9 additions & 91 deletions lib/cloud/mocks/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ import (
"github.com/gravitational/trace"
)

// STSClientV1 mocks AWS STS API.
type STSClientV1 struct {
// STSMock mocks AWS STS API.
type STSMock struct {
stsiface.STSAPI
ARN string
URL *url.URL
Expand All @@ -45,36 +45,36 @@ type STSClientV1 struct {
mu sync.Mutex
}

func (m *STSClientV1) GetAssumedRoleARNs() []string {
func (m *STSMock) GetAssumedRoleARNs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleARNs
}

func (m *STSClientV1) GetAssumedRoleExternalIDs() []string {
func (m *STSMock) GetAssumedRoleExternalIDs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleExternalIDs
}

func (m *STSClientV1) ResetAssumeRoleHistory() {
func (m *STSMock) ResetAssumeRoleHistory() {
m.mu.Lock()
defer m.mu.Unlock()
m.assumedRoleARNs = nil
m.assumedRoleExternalIDs = nil
}

func (m *STSClientV1) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) {
func (m *STSMock) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

func (m *STSClientV1) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
func (m *STSMock) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
return m.AssumeRoleWithContext(context.Background(), in)
}

func (m *STSClientV1) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) {
func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) {
m.mu.Lock()
defer m.mu.Unlock()
if !slices.Contains(m.assumedRoleARNs, aws.StringValue(in.RoleArn)) {
Expand All @@ -92,7 +92,7 @@ func (m *STSClientV1) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleI
}, nil
}

func (m *STSClientV1) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) {
func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) {
return &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
Expand Down Expand Up @@ -286,85 +286,3 @@ func (m *IAMErrorMock) PutUserPolicyWithContext(ctx aws.Context, input *iam.PutU
}
return nil, trace.AccessDenied("unauthorized")
}

// // EKSMock is a mock EKS client.
// type EKSMock struct {
// eksiface.EKSAPI
// Clusters []*eks.Cluster
// AccessEntries []*eks.AccessEntry
// AssociatedPolicies []*eks.AssociatedAccessPolicy
// Notify chan struct{}
// }

// func (e *EKSMock) DescribeClusterWithContext(_ aws.Context, req *eks.DescribeClusterInput, _ ...request.Option) (*eks.DescribeClusterOutput, error) {
// defer func() {
// if e.Notify != nil {
// e.Notify <- struct{}{}
// }
// }()
// for _, cluster := range e.Clusters {
// if aws.StringValue(req.Name) == aws.StringValue(cluster.Name) {
// return &eks.DescribeClusterOutput{Cluster: cluster}, nil
// }
// }
// return nil, trace.NotFound("cluster %v not found", aws.StringValue(req.Name))
// }

// func (e *EKSMock) ListClustersPagesWithContext(_ aws.Context, _ *eks.ListClustersInput, f func(*eks.ListClustersOutput, bool) bool, _ ...request.Option) error {
// defer func() {
// if e.Notify != nil {
// e.Notify <- struct{}{}
// }
// }()
// clusters := make([]*string, 0, len(e.Clusters))
// for _, cluster := range e.Clusters {
// clusters = append(clusters, cluster.Name)
// }
// f(&eks.ListClustersOutput{
// Clusters: clusters,
// }, true)
// return nil
// }

// func (e *EKSMock) ListAccessEntriesPagesWithContext(_ aws.Context, _ *eks.ListAccessEntriesInput, f func(*eks.ListAccessEntriesOutput, bool) bool, _ ...request.Option) error {
// defer func() {
// if e.Notify != nil {
// e.Notify <- struct{}{}
// }
// }()
// accessEntries := make([]*string, 0, len(e.Clusters))
// for _, a := range e.AccessEntries {
// accessEntries = append(accessEntries, a.PrincipalArn)
// }
// f(&eks.ListAccessEntriesOutput{
// AccessEntries: accessEntries,
// }, true)
// return nil
// }

// func (e *EKSMock) DescribeAccessEntryWithContext(_ aws.Context, req *eks.DescribeAccessEntryInput, _ ...request.Option) (*eks.DescribeAccessEntryOutput, error) {
// defer func() {
// if e.Notify != nil {
// e.Notify <- struct{}{}
// }
// }()
// for _, a := range e.AccessEntries {
// if aws.StringValue(req.PrincipalArn) == aws.StringValue(a.PrincipalArn) && aws.StringValue(a.ClusterName) == aws.StringValue(req.ClusterName) {
// return &eks.DescribeAccessEntryOutput{AccessEntry: a}, nil
// }
// }
// return nil, trace.NotFound("access entry %v not found", aws.StringValue(req.PrincipalArn))
// }

// func (e *EKSMock) ListAssociatedAccessPoliciesPagesWithContext(_ aws.Context, _ *eks.ListAssociatedAccessPoliciesInput, f func(*eks.ListAssociatedAccessPoliciesOutput, bool) bool, _ ...request.Option) error {
// defer func() {
// if e.Notify != nil {
// e.Notify <- struct{}{}
// }
// }()

// f(&eks.ListAssociatedAccessPoliciesOutput{
// AssociatedAccessPolicies: e.AssociatedPolicies,
// }, true)
// return nil
// }
18 changes: 13 additions & 5 deletions lib/kube/proxy/cluster_details.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
// kubeDetails contain the cluster-related details including authentication.
type kubeDetails struct {
kubeCreds

// dynamicLabels is the dynamic labels executor for this cluster.
dynamicLabels *labels.Dynamic
// kubeCluster is the dynamic kube_cluster or a static generated from kubeconfig and that only has the name populated.
Expand Down Expand Up @@ -268,8 +269,14 @@ func (k *kubeDetails) getObjectGVK(resource apiResource) *schema.GroupVersionKin

// getKubeClusterCredentials generates kube credentials for dynamic clusters.
func getKubeClusterCredentials(ctx context.Context, cfg clusterDetailsConfig) (kubeCreds, error) {
dynCredsCfg := dynamicCredsConfig{kubeCluster: cfg.cluster, log: cfg.log, checker: cfg.checker, resourceMatchers: cfg.resourceMatchers, clock: cfg.clock, component: cfg.component}
switch {
switch dynCredsCfg := (dynamicCredsConfig{
kubeCluster: cfg.cluster,
log: cfg.log,
checker: cfg.checker,
resourceMatchers: cfg.resourceMatchers,
clock: cfg.clock,
component: cfg.component,
}); {
case cfg.cluster.IsKubeconfig():
return getStaticCredentialsFromKubeconfig(ctx, cfg.component, cfg.cluster, cfg.log, cfg.checker)
case cfg.cluster.IsAzure():
Expand Down Expand Up @@ -333,19 +340,20 @@ func getAWSResourceMatcherToCluster(kubeCluster types.KubeCluster, resourceMatch
if match, _, _ := services.MatchLabels(matcher.Labels, kubeCluster.GetAllLabels()); !match {
continue
}

return &(matcher.AWS)
return &matcher.AWS
}
return nil
}

// STSPresignClient is the subset of the STS presign interface we use in fetchers.
type STSPresignClient = kubeutils.STSPresignClient

// EKSClient is the subset of the EKS Client interface we use.
type EKSClient interface {
eks.DescribeClusterAPIClient
}

// STSClient is the subset of the STS Client interface we use.
type STSClient interface {
stscreds.AssumeRoleAPIClient
}
Expand All @@ -366,7 +374,7 @@ func getAWSClientRestConfig(cloudClients ClientGetter, clock clockwork.Clock, re
region := cluster.GetAWSConfig().Region
opts := []awsconfig.OptionsFn{
awsconfig.WithAmbientCredentials(),
// TODO(@creack): Re-enable this when session cache v2 gets merged (#50561).
// TODO(@GavinFrazar): Re-enable this when session cache v2 gets merged (#50561).
// awsconfig.WithoutSessionCache(),
}
stsClient, err := cloudClients.GetAWSSTSClient(ctx, region, opts...)
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2481,7 +2481,7 @@ func (p *agentParams) setDefaults(c *testContext) {

if p.CloudClients == nil {
p.CloudClients = &clients.TestCloudClients{
STS: &mocks.STSClientV1{},
STS: &mocks.STSMock{},
RDS: &mocks.RDSMock{},
Redshift: &mocks.RedshiftMock{},
RedshiftServerless: &mocks.RedshiftServerlessMock{},
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/cloud/iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestAWSIAM(t *testing.T) {
}

// Configure mocks.
stsClient := &mocks.STSClientV1{
stsClient := &mocks.STSMock{
ARN: "arn:aws:iam::123456789012:role/test-role",
}

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestAWSIAMNoPermissions(t *testing.T) {
t.Cleanup(cancel)

// Create unauthorized mocks for AWS services.
stsClient := &mocks.STSClientV1{
stsClient := &mocks.STSMock{
ARN: "arn:aws:iam::123456789012:role/test-role",
}
// Make configurator.
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/cloud/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestAWSMetadata(t *testing.T) {
},
}

stsMock := &mocks.STSClientV1{}
stsMock := &mocks.STSMock{}

// Configure Redshift Serverless API mock.
redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "us-west-1")
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestAWSMetadataNoPermissions(t *testing.T) {
rds := &mocks.RDSMockUnauth{}
redshift := &mocks.RedshiftMockUnauth{}

stsMock := &mocks.STSClientV1{}
stsMock := &mocks.STSMock{}

// Create metadata fetcher.
metadata, err := NewMetadata(MetadataConfig{
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/cloud/resource_checker_url_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestURLChecker_AWS(t *testing.T) {
OpenSearch: &mocks.OpenSearchMock{
Domains: []*opensearchservice.DomainStatus{openSearchDomain, openSearchVPCDomain},
},
STS: &mocks.STSClientV1{},
STS: &mocks.STSMock{},
}
mockClientsUnauth := &cloud.TestCloudClients{
RDS: &mocks.RDSMockUnauth{},
Expand All @@ -151,7 +151,7 @@ func TestURLChecker_AWS(t *testing.T) {
ElastiCache: &mocks.ElastiCacheMock{Unauth: true},
MemoryDB: &mocks.MemoryDBMock{Unauth: true},
OpenSearch: &mocks.OpenSearchMock{Unauth: true},
STS: &mocks.STSClientV1{},
STS: &mocks.STSMock{},
}

// Test both check methods.
Expand Down
Loading

0 comments on commit 7174944

Please sign in to comment.