diff --git a/README.md b/README.md index 0704cad1..fa945f75 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ The currently supported functionality includes: - Deleting all Auto scaling groups in an AWS account - Deleting all Elastic Load Balancers (Classic and V2) in an AWS account +- Deleting all Transit Gateways in an AWS account - Deleting all EBS Volumes in an AWS account - Deleting all unprotected EC2 instances in an AWS account - Deleting all AMIs in an AWS account diff --git a/aws/aws.go b/aws/aws.go index 7bf3a779..ed7dc48d 100644 --- a/aws/aws.go +++ b/aws/aws.go @@ -276,6 +276,48 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp } // End LoadBalancerV2 Arns + // TransitGatewayVpcAttachment + transitGatewayVpcAttachments := TransitGatewaysVpcAttachment{} + if IsNukeable(transitGatewayVpcAttachments.ResourceName(), resourceTypes) { + transitGatewayVpcAttachmentIds, err := getAllTransitGatewayVpcAttachments(session, region, excludeAfter) + if err != nil { + return nil, errors.WithStackTrace(err) + } + if len(transitGatewayVpcAttachmentIds) > 0 { + transitGatewayVpcAttachments.Ids = awsgo.StringValueSlice(transitGatewayVpcAttachmentIds) + resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGatewayVpcAttachments) + } + } + // End TransitGatewayVpcAttachment + + // TransitGatewayRouteTable + transitGatewayRouteTables := TransitGatewaysRouteTables{} + if IsNukeable(transitGatewayRouteTables.ResourceName(), resourceTypes) { + transitGatewayRouteTableIds, err := getAllTransitGatewayRouteTables(session, region, excludeAfter) + if err != nil { + return nil, errors.WithStackTrace(err) + } + if len(transitGatewayRouteTableIds) > 0 { + transitGatewayRouteTables.Ids = awsgo.StringValueSlice(transitGatewayRouteTableIds) + resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGatewayRouteTables) + } + } + // End TransitGatewayRouteTable + + // TransitGateway + transitGateways := TransitGateways{} + if IsNukeable(transitGateways.ResourceName(), resourceTypes) { + transitGatewayIds, err := getAllTransitGatewayInstances(session, region, excludeAfter) + if err != nil { + return nil, errors.WithStackTrace(err) + } + if len(transitGatewayIds) > 0 { + transitGateways.Ids = awsgo.StringValueSlice(transitGatewayIds) + resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGateways) + } + } + // End TransitGateway + // EC2 Instances ec2Instances := EC2Instances{} if IsNukeable(ec2Instances.ResourceName(), resourceTypes) { @@ -535,6 +577,9 @@ func ListResourceTypes() []string { LaunchConfigs{}.ResourceName(), LoadBalancers{}.ResourceName(), LoadBalancersV2{}.ResourceName(), + TransitGatewaysVpcAttachment{}.ResourceName(), + TransitGatewaysRouteTables{}.ResourceName(), + TransitGateways{}.ResourceName(), EC2Instances{}.ResourceName(), EBSVolumes{}.ResourceName(), EIPAddresses{}.ResourceName(), diff --git a/aws/ec2_utils_for_test.go b/aws/ec2_utils_for_test.go new file mode 100644 index 00000000..629e3e72 --- /dev/null +++ b/aws/ec2_utils_for_test.go @@ -0,0 +1,34 @@ +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + awsgo "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/stretchr/testify/require" +) + +func getVpcSubnets(t *testing.T, session *session.Session, vpcId string) []string { + svc := ec2.New(session) + + param := &ec2.DescribeSubnetsInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: aws.StringSlice([]string{vpcId}), + }, + }, + } + + result, err := svc.DescribeSubnets(param) + require.NoError(t, err) + + var subnets []string + + for _, v := range result.Subnets { + subnets = append(subnets, awsgo.StringValue(v.SubnetId)) + } + return subnets +} diff --git a/aws/ecs_cluster_test.go b/aws/ecs_cluster_test.go index 53202c80..98a86238 100644 --- a/aws/ecs_cluster_test.go +++ b/aws/ecs_cluster_test.go @@ -21,7 +21,7 @@ func TestCanTagEcsClusters(t *testing.T) { }) require.NoError(t, err) - cluster := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID()) + cluster := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID()) defer deleteEcsCluster(awsSession, cluster) tagValue := time.Now().UTC() @@ -83,11 +83,11 @@ func TestCanNukeAllEcsClustersOlderThan24Hours(t *testing.T) { }) require.NoError(t, err) - cluster1 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID()) + cluster1 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID()) defer deleteEcsCluster(awsSession, cluster1) - cluster2 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID()) + cluster2 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID()) defer deleteEcsCluster(awsSession, cluster2) - cluster3 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID()) + cluster3 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID()) defer deleteEcsCluster(awsSession, cluster3) now := time.Now().UTC() @@ -114,4 +114,3 @@ func TestCanNukeAllEcsClustersOlderThan24Hours(t *testing.T) { assert.Contains(t, awsgo.StringValueSlice(allLeftClusterArns), awsgo.StringValue(cluster2.ClusterArn)) } - diff --git a/aws/rds_cluster.go b/aws/rds_cluster.go index 46339d24..502c9857 100644 --- a/aws/rds_cluster.go +++ b/aws/rds_cluster.go @@ -16,7 +16,7 @@ func waitUntilRdsClusterDeleted(svc *rds.RDS, input *rds.DescribeDBClustersInput for i := 0; i < 90; i++ { _, err := svc.DescribeDBClusters(input) if err != nil { - if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == rds.ErrCodeDBClusterNotFoundFault { + if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == rds.ErrCodeDBClusterNotFoundFault { return nil } @@ -30,7 +30,6 @@ func waitUntilRdsClusterDeleted(svc *rds.RDS, input *rds.DescribeDBClustersInput return RdsDeleteError{name: *input.DBClusterIdentifier} } - func getAllRdsClusters(session *session.Session, excludeAfter time.Time) ([]*string, error) { svc := rds.New(session) @@ -44,7 +43,7 @@ func getAllRdsClusters(session *session.Session, excludeAfter time.Time) ([]*str for _, database := range result.DBClusters { if excludeAfter.After(*database.ClusterCreateTime) { - names = append(names, database.DBClusterIdentifier) + names = append(names, database.DBClusterIdentifier) } } @@ -65,7 +64,7 @@ func nukeAllRdsClusters(session *session.Session, names []*string) error { for _, name := range names { params := &rds.DeleteDBClusterInput{ DBClusterIdentifier: name, - SkipFinalSnapshot: awsgo.Bool(true), + SkipFinalSnapshot: awsgo.Bool(true), } _, err := svc.DeleteDBCluster(params) diff --git a/aws/rds_cluster_test.go b/aws/rds_cluster_test.go index 694b5de1..a0104b6e 100644 --- a/aws/rds_cluster_test.go +++ b/aws/rds_cluster_test.go @@ -19,9 +19,9 @@ func createTestRDSCluster(t *testing.T, session *session.Session, name string) { svc := rds.New(session) params := &rds.CreateDBClusterInput{ DBClusterIdentifier: awsgo.String(name), - Engine: awsgo.String("aurora-mysql"), - MasterUsername: awsgo.String("gruntwork"), - MasterUserPassword: awsgo.String("password"), + Engine: awsgo.String("aurora-mysql"), + MasterUsername: awsgo.String("gruntwork"), + MasterUserPassword: awsgo.String("password"), } _, err := svc.CreateDBCluster(params) @@ -39,7 +39,7 @@ func TestNukeRDSCluster(t *testing.T) { ) rdsName := "cloud-nuke-test" + util.UniqueID() - excludeAfter := time.Now().Add(1*time.Hour) + excludeAfter := time.Now().Add(1 * time.Hour) createTestRDSCluster(t, session, rdsName) diff --git a/aws/rds_cluster_types.go b/aws/rds_cluster_types.go index af4398d5..ab3bcfc3 100644 --- a/aws/rds_cluster_types.go +++ b/aws/rds_cluster_types.go @@ -32,4 +32,3 @@ func (instance DBClusters) Nuke(session *session.Session, identifiers []string) return nil } - diff --git a/aws/rds_test.go b/aws/rds_test.go index 319552f7..7ae8d331 100644 --- a/aws/rds_test.go +++ b/aws/rds_test.go @@ -73,7 +73,7 @@ func TestNukeRDSInstance(t *testing.T) { ) rdsName := "cloud-nuke-test-" + util.UniqueID() - excludeAfter := time.Now().Add(1*time.Hour) + excludeAfter := time.Now().Add(1 * time.Hour) createTestRDSInstance(t, session, rdsName) diff --git a/aws/rds_types.go b/aws/rds_types.go index 03e50601..3fd091f6 100644 --- a/aws/rds_types.go +++ b/aws/rds_types.go @@ -33,7 +33,7 @@ func (instance DBInstances) Nuke(session *session.Session, identifiers []string) return nil } -type RdsDeleteError struct{ +type RdsDeleteError struct { name string } diff --git a/aws/transit_gateway.go b/aws/transit_gateway.go new file mode 100644 index 00000000..0c623ae3 --- /dev/null +++ b/aws/transit_gateway.go @@ -0,0 +1,177 @@ +package aws + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws" + awsgo "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/gruntwork-io/cloud-nuke/logging" + "github.com/gruntwork-io/gruntwork-cli/errors" +) + +func sleepWithMessage(duration time.Duration, whySleepMessage string) { + logging.Logger.Infof("Sleeping %v: %s", duration, whySleepMessage) + time.Sleep(duration) +} + +// Returns a formatted string of TransitGateway IDs +func getAllTransitGatewayInstances(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) { + svc := ec2.New(session) + result, err := svc.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{}) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + var ids []*string + for _, transitGateway := range result.TransitGateways { + if excludeAfter.After(*transitGateway.CreationTime) && awsgo.StringValue(transitGateway.State) != "deleted" && awsgo.StringValue(transitGateway.State) != "deleting" { + ids = append(ids, transitGateway.TransitGatewayId) + } + } + + return ids, nil +} + +// Delete all TransitGateways +func nukeAllTransitGatewayInstances(session *session.Session, ids []*string) error { + svc := ec2.New(session) + + if len(ids) == 0 { + logging.Logger.Infof("No Transit Gateways to nuke in region %s", *session.Config.Region) + return nil + } + + logging.Logger.Infof("Deleting all Transit Gateways in region %s", *session.Config.Region) + var deletedIds []*string + + for _, id := range ids { + params := &ec2.DeleteTransitGatewayInput{ + TransitGatewayId: id, + } + + _, err := svc.DeleteTransitGateway(params) + if err != nil { + logging.Logger.Errorf("[Failed] %s", err) + } else { + deletedIds = append(deletedIds, id) + logging.Logger.Infof("Deleted Transit Gateway: %s", *id) + } + } + + logging.Logger.Infof("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), *session.Config.Region) + return nil +} + +// Returns a formatted string of TranstGatewayRouteTable IDs +func getAllTransitGatewayRouteTables(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) { + svc := ec2.New(session) + + // Remove defalt route table, that will be deleted along with its TransitGateway + param := &ec2.DescribeTransitGatewayRouteTablesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("default-association-route-table"), + Values: []*string{ + aws.String("false"), + }}, + }, + } + + result, err := svc.DescribeTransitGatewayRouteTables(param) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + var ids []*string + for _, transitGatewayRouteTable := range result.TransitGatewayRouteTables { + if excludeAfter.After(*transitGatewayRouteTable.CreationTime) && awsgo.StringValue(transitGatewayRouteTable.State) != "deleted" && awsgo.StringValue(transitGatewayRouteTable.State) != "deleting" { + ids = append(ids, transitGatewayRouteTable.TransitGatewayRouteTableId) + } + } + + return ids, nil +} + +// Delete all TransitGatewayRouteTables +func nukeAllTransitGatewayRouteTables(session *session.Session, ids []*string) error { + svc := ec2.New(session) + + if len(ids) == 0 { + logging.Logger.Infof("No Transit Gateway Route Tables to nuke in region %s", *session.Config.Region) + return nil + } + + logging.Logger.Infof("Deleting all Transit Gateway Route Tables in region %s", *session.Config.Region) + var deletedIds []*string + + for _, id := range ids { + param := &ec2.DeleteTransitGatewayRouteTableInput{ + TransitGatewayRouteTableId: id, + } + + _, err := svc.DeleteTransitGatewayRouteTable(param) + if err != nil { + logging.Logger.Errorf("[Failed] %s", err) + } else { + deletedIds = append(deletedIds, id) + logging.Logger.Infof("Deleted Transit Gateway Route Table: %s", *id) + } + } + + logging.Logger.Infof("[OK] %d Transit Gateway Route Table(s) deleted in %s", len(deletedIds), *session.Config.Region) + return nil +} + +// Returns a formated string of TransitGatewayVpcAttachment IDs +func getAllTransitGatewayVpcAttachments(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) { + svc := ec2.New(session) + result, err := svc.DescribeTransitGatewayVpcAttachments(&ec2.DescribeTransitGatewayVpcAttachmentsInput{}) + if err != nil { + return nil, errors.WithStackTrace(err) + } + + var ids []*string + for _, tgwVpcAttachment := range result.TransitGatewayVpcAttachments { + if excludeAfter.After(*tgwVpcAttachment.CreationTime) && awsgo.StringValue(tgwVpcAttachment.State) != "deleted" && awsgo.StringValue(tgwVpcAttachment.State) != "deleting" { + ids = append(ids, tgwVpcAttachment.TransitGatewayAttachmentId) + } + } + + return ids, nil +} + +// Delete all TransitGatewayVpcAttachments +func nukeAllTransitGatewayVpcAttachments(session *session.Session, ids []*string) error { + svc := ec2.New(session) + + if len(ids) == 0 { + logging.Logger.Infof("No Transit Gateway Vpc Attachments to nuke in region %s", *session.Config.Region) + return nil + } + + logging.Logger.Infof("Deleting all Transit Gateway Vpc Attachments in region %s", *session.Config.Region) + var deletedIds []*string + + for _, id := range ids { + param := &ec2.DeleteTransitGatewayVpcAttachmentInput{ + TransitGatewayAttachmentId: id, + } + + _, err := svc.DeleteTransitGatewayVpcAttachment(param) + if err != nil { + logging.Logger.Errorf("[Failed] %s", err) + } else { + deletedIds = append(deletedIds, id) + logging.Logger.Infof("Deleted Transit Gateway Vpc Attachment: %s", *id) + } + } + + sleepMessage := "TransitGateway Vpc Attachments takes some time to create, and since there is no waiter available, we sleep instead." + sleepFor := 180 * time.Second + sleepWithMessage(sleepFor, sleepMessage) + + logging.Logger.Infof(("[OK] %d Transit Gateway Vpc Attachment(s) deleted in %s"), len(deletedIds), *session.Config.Region) + return nil +} diff --git a/aws/transit_gateway_test.go b/aws/transit_gateway_test.go new file mode 100644 index 00000000..01a15991 --- /dev/null +++ b/aws/transit_gateway_test.go @@ -0,0 +1,304 @@ +package aws + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/aws/aws-sdk-go/aws" + awsgo "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/gruntwork-io/cloud-nuke/util" + "github.com/gruntwork-io/gruntwork-cli/errors" +) + +func createTestTransitGateway(t *testing.T, session *session.Session, name string) ec2.TransitGateway { + svc := ec2.New(session) + + tgwName := ec2.TagSpecification{ + ResourceType: awsgo.String(ec2.ResourceTypeTransitGateway), + Tags: []*ec2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(name), + }, + }, + } + + param := &ec2.CreateTransitGatewayInput{ + TagSpecifications: []*ec2.TagSpecification{&tgwName}, + } + + result, err := svc.CreateTransitGateway(param) + require.NoError(t, err) + require.True(t, len(aws.StringValue(result.TransitGateway.TransitGatewayId)) > 0, "Could not create test TransitGateway") + + sleepMessage := "TransitGateway takes some time to create, and since there is no waiter available, we sleep instead." + sleepFor := 180 * time.Second + sleepWithMessage(sleepFor, sleepMessage) + + return *result.TransitGateway +} + +func TestGetAllTransitGatewayInstances(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + tgwName := "cloud-nuke-test-" + util.UniqueID() + tgw := createTestTransitGateway(t, session, tgwName) + + defer nukeAllTransitGatewayInstances(session, []*string{tgw.TransitGatewayId}) + + ids, err := getAllTransitGatewayInstances(session, region, time.Now().Add(1*time.Hour*-1)) + require.NoError(t, err) + assert.NotContains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgw.TransitGatewayId)) + + ids, err = getAllTransitGatewayInstances(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + assert.Contains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgw.TransitGatewayId)) +} + +func TestNukeTransitGateway(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + if err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) + } + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + svc := ec2.New(session) + + tgwName := "cloud-nuke-test-" + util.UniqueID() + tgw := createTestTransitGateway(t, session, tgwName) + + _, err = svc.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{ + TransitGatewayIds: []*string{ + tgw.TransitGatewayId, + }, + }) + require.NoError(t, err) + + err = nukeAllTransitGatewayInstances(session, []*string{tgw.TransitGatewayId}) + require.NoError(t, err) + + ids, err := getAllTransitGatewayInstances(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + + assert.NotContains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgw.TransitGatewayId)) +} + +func createTestTransitGatewayRouteTable(t *testing.T, session *session.Session, name string) ec2.TransitGatewayRouteTable { + svc := ec2.New(session) + + transitGateway := createTestTransitGateway(t, session, name) + + tgwRouteTableName := ec2.TagSpecification{ + ResourceType: awsgo.String(ec2.ResourceTypeTransitGatewayRouteTable), + Tags: []*ec2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(name), + }, + }, + } + + param := &ec2.CreateTransitGatewayRouteTableInput{ + TagSpecifications: []*ec2.TagSpecification{&tgwRouteTableName}, + TransitGatewayId: transitGateway.TransitGatewayId, + } + + result, err := svc.CreateTransitGatewayRouteTable(param) + require.NoError(t, err) + require.True(t, len(aws.StringValue(result.TransitGatewayRouteTable.TransitGatewayRouteTableId)) > 0, "Could not create test TransitGateway Route Table") + + sleepMessage := "TransitGateway Route Tables takes some time to create, and since there is no waiter available, we sleep instead." + sleepFor := 180 * time.Second + sleepWithMessage(sleepFor, sleepMessage) + + return *result.TransitGatewayRouteTable +} + +func TestGetAllTransitGatewayRouteTableInstances(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + tgwRouteTableName := "cloud-nuke-test-" + util.UniqueID() + tgwRouteTable := createTestTransitGatewayRouteTable(t, session, tgwRouteTableName) + + defer nukeAllTransitGatewayRouteTables(session, []*string{tgwRouteTable.TransitGatewayRouteTableId}) + defer nukeAllTransitGatewayInstances(session, []*string{tgwRouteTable.TransitGatewayId}) + + ids, err := getAllTransitGatewayRouteTables(session, region, time.Now().Add(1*time.Hour*-1)) + require.NoError(t, err) + assert.NotContains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgwRouteTable.TransitGatewayRouteTableId)) + + ids, err = getAllTransitGatewayRouteTables(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + assert.Contains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgwRouteTable.TransitGatewayRouteTableId)) +} + +func TestNukeTransitGatewayRouteTable(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + svc := ec2.New(session) + + tgwRouteTableName := "cloud-nuke-test-" + util.UniqueID() + tgwRouteTable := createTestTransitGatewayRouteTable(t, session, tgwRouteTableName) + defer nukeAllTransitGatewayInstances(session, []*string{tgwRouteTable.TransitGatewayId}) + + _, err = svc.DescribeTransitGatewayRouteTables(&ec2.DescribeTransitGatewayRouteTablesInput{ + TransitGatewayRouteTableIds: []*string{ + tgwRouteTable.TransitGatewayRouteTableId, + }, + }) + require.NoError(t, err) + + err = nukeAllTransitGatewayRouteTables(session, []*string{tgwRouteTable.TransitGatewayRouteTableId}) + require.NoError(t, err) + + ids, err := getAllTransitGatewayRouteTables(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + assert.NotContains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgwRouteTable.TransitGatewayRouteTableId)) +} + +func createTestTransitGatewayVpcAttachment(t *testing.T, session *session.Session, name string) ec2.TransitGatewayVpcAttachment { + svc := ec2.New(session) + + transitGateway := createTestTransitGateway(t, session, name) + + input := &ec2.DescribeVpcsInput{ + Filters: []*ec2.Filter{ + { + Name: awsgo.String("isDefault"), + Values: []*string{awsgo.String("true")}, + }, + }, + } + + vpcs, err := svc.DescribeVpcs(input) + assert.NoError(t, err) + require.NoError(t, err) + require.Len(t, vpcs.Vpcs, 1) + + vpc := vpcs.Vpcs[0] + + subnets := getVpcSubnets(t, session, awsgo.StringValue(vpc.VpcId)) + + tgwVpctAttachmentName := ec2.TagSpecification{ + ResourceType: awsgo.String(ec2.ResourceTypeTransitGatewayAttachment), + Tags: []*ec2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String(name), + }, + }, + } + + param := &ec2.CreateTransitGatewayVpcAttachmentInput{ + TagSpecifications: []*ec2.TagSpecification{&tgwVpctAttachmentName}, + TransitGatewayId: transitGateway.TransitGatewayId, + VpcId: vpc.VpcId, + SubnetIds: awsgo.StringSlice(subnets), + } + + result, err := svc.CreateTransitGatewayVpcAttachment(param) + require.NoError(t, err) + require.True(t, len(aws.StringValue(result.TransitGatewayVpcAttachment.TransitGatewayAttachmentId)) > 0, "Could not create test Transitgateway Vpc Attachment") + + sleepMessage := "TransitGateway Vpc Attachment takes some time to create, and since there is no waiter available, we sleep instead." + sleepFor := 180 * time.Second + sleepWithMessage(sleepFor, sleepMessage) + + return *result.TransitGatewayVpcAttachment +} + +func TestGetAllTransitGatewayVpcAttachment(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + tgwName := "cloud-nuke-test-" + util.UniqueID() + tgwAttachment := createTestTransitGatewayVpcAttachment(t, session, tgwName) + + defer nukeAllTransitGatewayVpcAttachments(session, []*string{tgwAttachment.TransitGatewayAttachmentId}) + defer nukeAllTransitGatewayInstances(session, []*string{tgwAttachment.TransitGatewayId}) + + ids, err := getAllTransitGatewayVpcAttachments(session, region, time.Now().Add(1*time.Hour*-1)) + require.NoError(t, err) + assert.NotContains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgwAttachment.TransitGatewayAttachmentId)) + + ids, err = getAllTransitGatewayVpcAttachments(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + assert.Contains(t, awsgo.StringValueSlice(ids), awsgo.StringValue(tgwAttachment.TransitGatewayAttachmentId)) +} + +func TestNukeTransitGatewayVpcAttachment(t *testing.T) { + t.Parallel() + + region, err := getRandomRegion() + if err != nil { + assert.Fail(t, errors.WithStackTrace(err).Error()) + } + require.NoError(t, err) + + session, err := session.NewSession(&awsgo.Config{ + Region: awsgo.String(region)}, + ) + require.NoError(t, err) + + svc := ec2.New(session) + + tgwVpcAttachmentName := "cloud-nuke-test-" + util.UniqueID() + tgwVpcAttachment := createTestTransitGatewayVpcAttachment(t, session, tgwVpcAttachmentName) + _, err = svc.DescribeTransitGatewayVpcAttachments(&ec2.DescribeTransitGatewayVpcAttachmentsInput{ + TransitGatewayAttachmentIds: []*string{ + tgwVpcAttachment.TransitGatewayAttachmentId, + }, + }) + require.NoError(t, err) + defer nukeAllTransitGatewayInstances(session, []*string{tgwVpcAttachment.TransitGatewayId}) + + err = nukeAllTransitGatewayVpcAttachments(session, []*string{tgwVpcAttachment.TransitGatewayAttachmentId}) + require.NoError(t, err) + + ids, err := getAllTransitGatewayVpcAttachments(session, region, time.Now().Add(1*time.Hour)) + require.NoError(t, err) + assert.NotContains(t, awsgo.StringValueSlice(ids), aws.StringValue(tgwVpcAttachment.TransitGatewayAttachmentId)) +} diff --git a/aws/transit_gateway_types.go b/aws/transit_gateway_types.go new file mode 100644 index 00000000..efff3e26 --- /dev/null +++ b/aws/transit_gateway_types.go @@ -0,0 +1,94 @@ +package aws + +import ( + awsgo "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/gruntwork-io/gruntwork-cli/errors" +) + +// TransitGatewaysVpcAttachment - represents all transit gateways vpc attachments +type TransitGatewaysVpcAttachment struct { + Ids []string +} + +// ResourceName - the simple name of the aws resource +func (tgw TransitGatewaysVpcAttachment) ResourceName() string { + return "transit-gateway-attachment" +} + +// MaxBatchSize - Tentative batch size to ensure AWS doesn't throttle +func (tgw TransitGatewaysVpcAttachment) MaxBatchSize() int { + return maxBatchSize +} + +// ResourceIdentifiers - The Ids of the transit gateways +func (tgw TransitGatewaysVpcAttachment) ResourceIdentifiers() []string { + return tgw.Ids +} + +// Nuke - nuke 'em all!!! +func (tgw TransitGatewaysVpcAttachment) Nuke(session *session.Session, identifiers []string) error { + if err := nukeAllTransitGatewayVpcAttachments(session, awsgo.StringSlice(identifiers)); err != nil { + return errors.WithStackTrace(err) + } + + return nil +} + +// TransitGatewaysRouteTables - represents all transit gateways route tables +type TransitGatewaysRouteTables struct { + Ids []string +} + +// ResourceName - the simple name of the aws resource +func (tgw TransitGatewaysRouteTables) ResourceName() string { + return "transit-gateway-route-table" +} + +// MaxBatchSize - Tentative batch size to ensure AWS doesn't throttle +func (tgw TransitGatewaysRouteTables) MaxBatchSize() int { + return maxBatchSize +} + +// ResourceIdentifiers - The arns of the transit gateways route tables +func (tgw TransitGatewaysRouteTables) ResourceIdentifiers() []string { + return tgw.Ids +} + +// Nuke - nuke 'em all!!! +func (tgw TransitGatewaysRouteTables) Nuke(session *session.Session, identifiers []string) error { + if err := nukeAllTransitGatewayRouteTables(session, awsgo.StringSlice(identifiers)); err != nil { + return errors.WithStackTrace(err) + } + + return nil +} + +// TransitGateways - represents all transit gateways +type TransitGateways struct { + Ids []string +} + +// ResourceName - the simple name of the aws resource +func (tgw TransitGateways) ResourceName() string { + return "transit-gateway" +} + +// MaxBatchSize - Tentative batch size to ensure AWS doesn't throttle +func (tgw TransitGateways) MaxBatchSize() int { + return maxBatchSize +} + +// ResourceIdentifiers - The Ids of the transit gateways +func (tgw TransitGateways) ResourceIdentifiers() []string { + return tgw.Ids +} + +// Nuke - nuke 'em all!!! +func (tgw TransitGateways) Nuke(session *session.Session, identifiers []string) error { + if err := nukeAllTransitGatewayInstances(session, awsgo.StringSlice(identifiers)); err != nil { + return errors.WithStackTrace(err) + } + + return nil +} diff --git a/go.sum b/go.sum index 2746f931..a38c08f9 100644 --- a/go.sum +++ b/go.sum @@ -232,7 +232,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t github.com/gruntwork-io/go-commons v0.8.2 h1:2jrQH6ou6GxShXpNmxhVuVktp5E2so115nSESbbDOj0= github.com/gruntwork-io/go-commons v0.8.2/go.mod h1:aH1kYhkEgb7+RRMDVVKFXBBX0KfECzEhp1UYmU12oO4= github.com/gruntwork-io/gruntwork-cli v0.7.0 h1:YgSAmfCj9c61H+zuvHwKfYUwlMhu5arnQQLM4RH+CYs= -github.com/gruntwork-io/gruntwork-cli v0.7.0/go.mod h1:jp6Z7NcLF2avpY8v71fBx6hds9eOFPELSuD/VPv7w00= +github.com/gruntwork-io/gruntwork-cli v0.7.1 h1:F/GEuj3NiBY+qV+1RvFW7J7CN+8bzXvaYGRCCw7Hq7Y= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/gruntwork-io/kubergrunt v0.6.10/go.mod h1:AjSwJPP107t8pihDgJCWCG/RG92Q1oiRXL/OdR6OiaQ= github.com/gruntwork-io/terratest v0.30.0/go.mod h1:7dNmTD2zDKUEVqfmvcUU5c9mZi+986mcXNzhzqPYPg8= github.com/gruntwork-io/terratest v0.32.9 h1:ciWWJxISk06LAYImn6h1Vvir8hUz13VtwT2//fYCDcA=