Skip to content

Commit

Permalink
migrate lib/srv/discovery eks to aws sdk v2 - tests wip
Browse files Browse the repository at this point in the history
  • Loading branch information
creack committed Dec 29, 2024
1 parent 8c5273e commit 6432b9a
Show file tree
Hide file tree
Showing 19 changed files with 359 additions and 277 deletions.
2 changes: 0 additions & 2 deletions lib/cloud/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ type AWSClients interface {
GetAWSIAMClient(ctx context.Context, region string, opts ...AWSOptionsFn) (iamiface.IAMAPI, error)
// GetAWSSTSClient returns AWS STS client for the specified region.
GetAWSSTSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (stsiface.STSAPI, error)
// GetAWSEKSClient returns AWS EKS client for the specified region.
GetAWSEKSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (eksiface.EKSAPI, error)
// GetAWSKMSClient returns AWS KMS client for the specified region.
GetAWSKMSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (kmsiface.KMSAPI, error)
// GetAWSS3Client returns AWS S3 client.
Expand Down
19 changes: 9 additions & 10 deletions lib/cloud/mocks/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ import (
"github.com/gravitational/trace"
)

// STSMock mocks AWS STS API.
type STSMock struct {
// STSClientV1 mocks AWS STS API.
type STSClientV1 struct {
stsiface.STSAPI
ARN string
URL *url.URL
Expand All @@ -47,36 +47,36 @@ type STSMock struct {
mu sync.Mutex
}

func (m *STSMock) GetAssumedRoleARNs() []string {
func (m *STSClientV1) GetAssumedRoleARNs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleARNs
}

func (m *STSMock) GetAssumedRoleExternalIDs() []string {
func (m *STSClientV1) GetAssumedRoleExternalIDs() []string {
m.mu.Lock()
defer m.mu.Unlock()
return m.assumedRoleExternalIDs
}

func (m *STSMock) ResetAssumeRoleHistory() {
func (m *STSClientV1) ResetAssumeRoleHistory() {
m.mu.Lock()
defer m.mu.Unlock()
m.assumedRoleARNs = nil
m.assumedRoleExternalIDs = nil
}

func (m *STSMock) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) {
func (m *STSClientV1) GetCallerIdentityWithContext(aws.Context, *sts.GetCallerIdentityInput, ...request.Option) (*sts.GetCallerIdentityOutput, error) {
return &sts.GetCallerIdentityOutput{
Arn: aws.String(m.ARN),
}, nil
}

func (m *STSMock) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
func (m *STSClientV1) AssumeRole(in *sts.AssumeRoleInput) (*sts.AssumeRoleOutput, error) {
return m.AssumeRoleWithContext(context.Background(), in)
}

func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) {
func (m *STSClientV1) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput, _ ...request.Option) (*sts.AssumeRoleOutput, error) {
m.mu.Lock()
defer m.mu.Unlock()
if !slices.Contains(m.assumedRoleARNs, aws.StringValue(in.RoleArn)) {
Expand All @@ -94,7 +94,7 @@ func (m *STSMock) AssumeRoleWithContext(ctx aws.Context, in *sts.AssumeRoleInput
}, nil
}

func (m *STSMock) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) {
func (m *STSClientV1) GetCallerIdentityRequest(req *sts.GetCallerIdentityInput) (*request.Request, *sts.GetCallerIdentityOutput) {
return &request.Request{
HTTPRequest: &http.Request{
Header: http.Header{},
Expand Down Expand Up @@ -369,5 +369,4 @@ func (e *EKSMock) ListAssociatedAccessPoliciesPagesWithContext(_ aws.Context, _
AssociatedAccessPolicies: e.AssociatedPolicies,
}, true)
return nil

}
16 changes: 9 additions & 7 deletions lib/integrations/awsoidc/eks_enroll_clusters.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ const (
concurrentEKSEnrollingLimit = 5
)

var agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
var agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
var (
agentRepoURL = url.URL{Scheme: "https", Host: "charts.releases.teleport.dev"}
agentStagingRepoURL = url.URL{Scheme: "https", Host: "charts.releases.development.teleport.dev"}
)

// EnrollEKSClusterResult contains result for a single EKS cluster enrollment, if it was successful 'Error' will be nil
// otherwise it will contain an error happened during enrollment.
Expand Down Expand Up @@ -462,7 +464,6 @@ func enrollEKSCluster(ctx context.Context, log *slog.Logger, clock clockwork.Clo
return "",
issueTypeFromCheckAgentInstalledError(err),
trace.Wrap(err, "could not check if teleport-kube-agent is already installed.")

} else if alreadyInstalled {
return "",
// When using EKS Auto Discovery, after the Kube Agent connects to the Teleport cluster, it is ignored in next discovery iterations.
Expand Down Expand Up @@ -708,19 +709,20 @@ func installKubeAgent(ctx context.Context, cfg installKubeAgentParams) error {
if cfg.req.IsCloud && cfg.req.EnableAutoUpgrades {
vals["updater"] = map[string]any{"enabled": true, "releaseChannel": "stable/cloud"}

vals["highAvailability"] = map[string]any{"replicaCount": 2,
vals["highAvailability"] = map[string]any{
"replicaCount": 2,
"podDisruptionBudget": map[string]any{"enabled": true, "minAvailable": 1},
}
}
if modules.GetModules().BuildType() == modules.BuildEnterprise {
vals["enterprise"] = true
}

eksTags := make(map[string]*string, len(cfg.eksCluster.Tags))
eksTags := make(map[string]string, len(cfg.eksCluster.Tags))
for k, v := range cfg.eksCluster.Tags {
eksTags[k] = aws.String(v)
eksTags[k] = v
}
eksTags[types.OriginLabel] = aws.String(types.OriginCloud)
eksTags[types.OriginLabel] = types.OriginCloud
kubeCluster, err := common.NewKubeClusterFromAWSEKS(aws.ToString(cfg.eksCluster.Name), aws.ToString(cfg.eksCluster.Arn), eksTags)
if err != nil {
return trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/kube_creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func Test_DynamicKubeCreds(t *testing.T) {
Host: "sts.amazonaws.com",
Path: "/?Action=GetCallerIdentity&Version=2011-06-15",
}
sts := &mocks.STSMock{
sts := &mocks.STSClientV1{
// u is used to presign the request
// here we just verify the pre-signed request includes this url.
URL: u,
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2481,7 +2481,7 @@ func (p *agentParams) setDefaults(c *testContext) {

if p.CloudClients == nil {
p.CloudClients = &clients.TestCloudClients{
STS: &mocks.STSMock{},
STS: &mocks.STSClientV1{},
RDS: &mocks.RDSMock{},
Redshift: &mocks.RedshiftMock{},
RedshiftServerless: &mocks.RedshiftServerlessMock{},
Expand Down
6 changes: 4 additions & 2 deletions lib/srv/db/cloud/iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestAWSIAM(t *testing.T) {
}

// Configure mocks.
stsClient := &mocks.STSMock{
stsClient := &mocks.STSClientV1{
ARN: "arn:aws:iam::123456789012:role/test-role",
}

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestAWSIAMNoPermissions(t *testing.T) {
t.Cleanup(cancel)

// Create unauthorized mocks for AWS services.
stsClient := &mocks.STSMock{
stsClient := &mocks.STSClientV1{
ARN: "arn:aws:iam::123456789012:role/test-role",
}
// Make configurator.
Expand Down Expand Up @@ -429,6 +429,7 @@ func (m *mockAccessPoint) GetClusterName(opts ...services.MarshalOption) (types.
ClusterID: "cluster-id",
})
}

func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) {
return &types.SemaphoreLease{
SemaphoreKind: params.SemaphoreKind,
Expand All @@ -437,6 +438,7 @@ func (m *mockAccessPoint) AcquireSemaphore(ctx context.Context, params types.Acq
Expires: params.Expires,
}, nil
}

func (m *mockAccessPoint) CancelSemaphoreLease(ctx context.Context, lease types.SemaphoreLease) error {
return nil
}
4 changes: 2 additions & 2 deletions lib/srv/db/cloud/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestAWSMetadata(t *testing.T) {
},
}

stsMock := &mocks.STSMock{}
stsMock := &mocks.STSClientV1{}

// Configure Redshift Serverless API mock.
redshiftServerlessWorkgroup := mocks.RedshiftServerlessWorkgroup("my-workgroup", "us-west-1")
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestAWSMetadataNoPermissions(t *testing.T) {
rds := &mocks.RDSMockUnauth{}
redshift := &mocks.RedshiftMockUnauth{}

stsMock := &mocks.STSMock{}
stsMock := &mocks.STSClientV1{}

// Create metadata fetcher.
metadata, err := NewMetadata(MetadataConfig{
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/db/cloud/resource_checker_url_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestURLChecker_AWS(t *testing.T) {
OpenSearch: &mocks.OpenSearchMock{
Domains: []*opensearchservice.DomainStatus{openSearchDomain, openSearchVPCDomain},
},
STS: &mocks.STSMock{},
STS: &mocks.STSClientV1{},
}
mockClientsUnauth := &cloud.TestCloudClients{
RDS: &mocks.RDSMockUnauth{},
Expand All @@ -151,7 +151,7 @@ func TestURLChecker_AWS(t *testing.T) {
ElastiCache: &mocks.ElastiCacheMock{Unauth: true},
MemoryDB: &mocks.MemoryDBMock{Unauth: true},
OpenSearch: &mocks.OpenSearchMock{Unauth: true},
STS: &mocks.STSMock{},
STS: &mocks.STSClientV1{},
}

// Test both check methods.
Expand Down
32 changes: 15 additions & 17 deletions lib/srv/db/common/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestAuthGetRedshiftServerlessAuthToken(t *testing.T) {
t.Parallel()

// setup mock aws sessions.
stsMock := &mocks.STSMock{}
stsMock := &mocks.STSClientV1{}
clock := clockwork.NewFakeClock()
auth, err := NewAuth(AuthConfig{
Clock: clock,
Expand Down Expand Up @@ -466,7 +466,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
t.Cleanup(cancel)
tests := map[string]struct {
checkGetAuthFn func(t *testing.T, auth Auth)
checkSTS func(t *testing.T, stsMock *mocks.STSMock)
checkSTS func(t *testing.T, stsMock *mocks.STSClientV1)
}{
"Redshift": {
checkGetAuthFn: func(t *testing.T, auth Auth) {
Expand All @@ -485,7 +485,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
require.Equal(t, "IAM:some-user", dbUser)
require.Equal(t, "some-password", dbPassword)
},
checkSTS: func(t *testing.T, stsMock *mocks.STSMock) {
checkSTS: func(t *testing.T, stsMock *mocks.STSClientV1) {
t.Helper()
require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftRole")
require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshift")
Expand All @@ -508,7 +508,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
require.Equal(t, "IAM:some-role", dbUser)
require.Equal(t, "some-password-for-some-role", dbPassword)
},
checkSTS: func(t *testing.T, stsMock *mocks.STSMock) {
checkSTS: func(t *testing.T, stsMock *mocks.STSClientV1) {
t.Helper()
require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftRole")
require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshift")
Expand All @@ -530,7 +530,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
require.Equal(t, "IAM:some-user", dbUser)
require.Equal(t, "some-password", dbPassword)
},
checkSTS: func(t *testing.T, stsMock *mocks.STSMock) {
checkSTS: func(t *testing.T, stsMock *mocks.STSClientV1) {
t.Helper()
require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedshiftServerlessRole")
require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRedshiftServerless")
Expand All @@ -550,7 +550,7 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
require.NoError(t, err)
require.Contains(t, token, "DBUser=some-user")
},
checkSTS: func(t *testing.T, stsMock *mocks.STSMock) {
checkSTS: func(t *testing.T, stsMock *mocks.STSClientV1) {
t.Helper()
require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RDSProxyRole")
require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalRDSProxy")
Expand Down Expand Up @@ -578,15 +578,15 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) {
require.Equal(t, "arn:aws:iam::123456789012:role/RedisRole/20010203/ca-central-1/elasticache/aws4_request",
query.Get("X-Amz-Credential"))
},
checkSTS: func(t *testing.T, stsMock *mocks.STSMock) {
checkSTS: func(t *testing.T, stsMock *mocks.STSClientV1) {
t.Helper()
require.Contains(t, stsMock.GetAssumedRoleARNs(), "arn:aws:iam::123456789012:role/RedisRole")
require.Contains(t, stsMock.GetAssumedRoleExternalIDs(), "externalElastiCacheRedis")
},
},
}

stsMock := &mocks.STSMock{}
stsMock := &mocks.STSClientV1{}
clock := clockwork.NewFakeClockAt(time.Date(2001, time.February, 3, 0, 0, 0, 0, time.UTC))
auth, err := NewAuth(AuthConfig{
Clock: clock,
Expand Down Expand Up @@ -623,7 +623,7 @@ func TestGetAWSIAMCreds(t *testing.T) {

for name, tt := range map[string]struct {
db types.Database
stsMock *mocks.STSMock
stsMock *mocks.STSClientV1
username string
expectedKeyId string
expectedAssumedRoles []string
Expand All @@ -632,7 +632,7 @@ func TestGetAWSIAMCreds(t *testing.T) {
}{
"username is full role ARN": {
db: newMongoAtlasDatabase(t, types.AWS{}),
stsMock: &mocks.STSMock{},
stsMock: &mocks.STSClientV1{},
username: "arn:aws:iam::123456789012:role/role-name",
expectedKeyId: "arn:aws:iam::123456789012:role/role-name",
expectedAssumedRoles: []string{"arn:aws:iam::123456789012:role/role-name"},
Expand All @@ -641,7 +641,7 @@ func TestGetAWSIAMCreds(t *testing.T) {
},
"username is partial role ARN": {
db: newMongoAtlasDatabase(t, types.AWS{}),
stsMock: &mocks.STSMock{
stsMock: &mocks.STSClientV1{
// This is the role returned by the STS GetCallerIdentity.
ARN: "arn:aws:iam::222222222222:role/teleport-service-role",
},
Expand All @@ -653,7 +653,7 @@ func TestGetAWSIAMCreds(t *testing.T) {
},
"unable to fetch account ID": {
db: newMongoAtlasDatabase(t, types.AWS{}),
stsMock: &mocks.STSMock{
stsMock: &mocks.STSClientV1{
ARN: "",
},
username: "role/role-name",
Expand All @@ -664,7 +664,7 @@ func TestGetAWSIAMCreds(t *testing.T) {
ExternalID: "123123",
AssumeRoleARN: "arn:aws:iam::222222222222:role/teleport-service-role-external",
}),
stsMock: &mocks.STSMock{
stsMock: &mocks.STSClientV1{
ARN: "arn:aws:iam::111111111111:role/teleport-service-role",
},
username: "role/role-name",
Expand Down Expand Up @@ -938,8 +938,7 @@ func generateAzureVM(t *testing.T, identities []string) armcompute.VirtualMachin
}

// authClientMock is a mock that implements AuthClient interface.
type authClientMock struct {
}
type authClientMock struct{}

// GenerateDatabaseCert generates a cert using fixtures TLS CA.
func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.DatabaseCertRequest) (*proto.DatabaseCertResponse, error) {
Expand Down Expand Up @@ -977,8 +976,7 @@ func (m *authClientMock) GenerateDatabaseCert(ctx context.Context, req *proto.Da
}, nil
}

type accessPointMock struct {
}
type accessPointMock struct{}

// GetAuthPreference always returns types.DefaultAuthPreference().
func (m accessPointMock) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
Expand Down
1 change: 1 addition & 0 deletions lib/srv/discovery/access_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M
ctx,
aws_sync.Config{
CloudClients: s.CloudClients,
GetEKSClient: s.GetAWSSyncEKSClient,
GetEC2Client: s.GetEC2Client,
AssumeRole: assumeRole,
Regions: awsFetcher.Regions,
Expand Down
Loading

0 comments on commit 6432b9a

Please sign in to comment.