Skip to content

Commit

Permalink
Add feature to delete Transit Gateway and its attachments (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelleonardocruz authored Mar 31, 2021
1 parent 3f1d764 commit 67a87cf
Show file tree
Hide file tree
Showing 13 changed files with 670 additions and 17 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The currently supported functionality includes:

- Deleting all Auto scaling groups in an AWS account
- Deleting all Elastic Load Balancers (Classic and V2) in an AWS account
- Deleting all Transit Gateways in an AWS account
- Deleting all EBS Volumes in an AWS account
- Deleting all unprotected EC2 instances in an AWS account
- Deleting all AMIs in an AWS account
Expand Down
45 changes: 45 additions & 0 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,48 @@ func GetAllResources(targetRegions []string, excludeAfter time.Time, resourceTyp
}
// End LoadBalancerV2 Arns

// TransitGatewayVpcAttachment
transitGatewayVpcAttachments := TransitGatewaysVpcAttachment{}
if IsNukeable(transitGatewayVpcAttachments.ResourceName(), resourceTypes) {
transitGatewayVpcAttachmentIds, err := getAllTransitGatewayVpcAttachments(session, region, excludeAfter)
if err != nil {
return nil, errors.WithStackTrace(err)
}
if len(transitGatewayVpcAttachmentIds) > 0 {
transitGatewayVpcAttachments.Ids = awsgo.StringValueSlice(transitGatewayVpcAttachmentIds)
resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGatewayVpcAttachments)
}
}
// End TransitGatewayVpcAttachment

// TransitGatewayRouteTable
transitGatewayRouteTables := TransitGatewaysRouteTables{}
if IsNukeable(transitGatewayRouteTables.ResourceName(), resourceTypes) {
transitGatewayRouteTableIds, err := getAllTransitGatewayRouteTables(session, region, excludeAfter)
if err != nil {
return nil, errors.WithStackTrace(err)
}
if len(transitGatewayRouteTableIds) > 0 {
transitGatewayRouteTables.Ids = awsgo.StringValueSlice(transitGatewayRouteTableIds)
resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGatewayRouteTables)
}
}
// End TransitGatewayRouteTable

// TransitGateway
transitGateways := TransitGateways{}
if IsNukeable(transitGateways.ResourceName(), resourceTypes) {
transitGatewayIds, err := getAllTransitGatewayInstances(session, region, excludeAfter)
if err != nil {
return nil, errors.WithStackTrace(err)
}
if len(transitGatewayIds) > 0 {
transitGateways.Ids = awsgo.StringValueSlice(transitGatewayIds)
resourcesInRegion.Resources = append(resourcesInRegion.Resources, transitGateways)
}
}
// End TransitGateway

// EC2 Instances
ec2Instances := EC2Instances{}
if IsNukeable(ec2Instances.ResourceName(), resourceTypes) {
Expand Down Expand Up @@ -535,6 +577,9 @@ func ListResourceTypes() []string {
LaunchConfigs{}.ResourceName(),
LoadBalancers{}.ResourceName(),
LoadBalancersV2{}.ResourceName(),
TransitGatewaysVpcAttachment{}.ResourceName(),
TransitGatewaysRouteTables{}.ResourceName(),
TransitGateways{}.ResourceName(),
EC2Instances{}.ResourceName(),
EBSVolumes{}.ResourceName(),
EIPAddresses{}.ResourceName(),
Expand Down
34 changes: 34 additions & 0 deletions aws/ec2_utils_for_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package aws

import (
"testing"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/stretchr/testify/require"
)

func getVpcSubnets(t *testing.T, session *session.Session, vpcId string) []string {
svc := ec2.New(session)

param := &ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
{
Name: aws.String("vpc-id"),
Values: aws.StringSlice([]string{vpcId}),
},
},
}

result, err := svc.DescribeSubnets(param)
require.NoError(t, err)

var subnets []string

for _, v := range result.Subnets {
subnets = append(subnets, awsgo.StringValue(v.SubnetId))
}
return subnets
}
9 changes: 4 additions & 5 deletions aws/ecs_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestCanTagEcsClusters(t *testing.T) {
})
require.NoError(t, err)

cluster := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID())
cluster := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID())
defer deleteEcsCluster(awsSession, cluster)

tagValue := time.Now().UTC()
Expand Down Expand Up @@ -83,11 +83,11 @@ func TestCanNukeAllEcsClustersOlderThan24Hours(t *testing.T) {
})
require.NoError(t, err)

cluster1 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID())
cluster1 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID())
defer deleteEcsCluster(awsSession, cluster1)
cluster2 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID())
cluster2 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID())
defer deleteEcsCluster(awsSession, cluster2)
cluster3 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-" + util.UniqueID())
cluster3 := createEcsFargateCluster(t, awsSession, "cloud-nuke-test-"+util.UniqueID())
defer deleteEcsCluster(awsSession, cluster3)

now := time.Now().UTC()
Expand All @@ -114,4 +114,3 @@ func TestCanNukeAllEcsClustersOlderThan24Hours(t *testing.T) {

assert.Contains(t, awsgo.StringValueSlice(allLeftClusterArns), awsgo.StringValue(cluster2.ClusterArn))
}

7 changes: 3 additions & 4 deletions aws/rds_cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func waitUntilRdsClusterDeleted(svc *rds.RDS, input *rds.DescribeDBClustersInput
for i := 0; i < 90; i++ {
_, err := svc.DescribeDBClusters(input)
if err != nil {
if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == rds.ErrCodeDBClusterNotFoundFault {
if awsErr, isAwsErr := err.(awserr.Error); isAwsErr && awsErr.Code() == rds.ErrCodeDBClusterNotFoundFault {
return nil
}

Expand All @@ -30,7 +30,6 @@ func waitUntilRdsClusterDeleted(svc *rds.RDS, input *rds.DescribeDBClustersInput
return RdsDeleteError{name: *input.DBClusterIdentifier}
}


func getAllRdsClusters(session *session.Session, excludeAfter time.Time) ([]*string, error) {
svc := rds.New(session)

Expand All @@ -44,7 +43,7 @@ func getAllRdsClusters(session *session.Session, excludeAfter time.Time) ([]*str

for _, database := range result.DBClusters {
if excludeAfter.After(*database.ClusterCreateTime) {
names = append(names, database.DBClusterIdentifier)
names = append(names, database.DBClusterIdentifier)
}
}

Expand All @@ -65,7 +64,7 @@ func nukeAllRdsClusters(session *session.Session, names []*string) error {
for _, name := range names {
params := &rds.DeleteDBClusterInput{
DBClusterIdentifier: name,
SkipFinalSnapshot: awsgo.Bool(true),
SkipFinalSnapshot: awsgo.Bool(true),
}

_, err := svc.DeleteDBCluster(params)
Expand Down
8 changes: 4 additions & 4 deletions aws/rds_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ func createTestRDSCluster(t *testing.T, session *session.Session, name string) {
svc := rds.New(session)
params := &rds.CreateDBClusterInput{
DBClusterIdentifier: awsgo.String(name),
Engine: awsgo.String("aurora-mysql"),
MasterUsername: awsgo.String("gruntwork"),
MasterUserPassword: awsgo.String("password"),
Engine: awsgo.String("aurora-mysql"),
MasterUsername: awsgo.String("gruntwork"),
MasterUserPassword: awsgo.String("password"),
}

_, err := svc.CreateDBCluster(params)
Expand All @@ -39,7 +39,7 @@ func TestNukeRDSCluster(t *testing.T) {
)

rdsName := "cloud-nuke-test" + util.UniqueID()
excludeAfter := time.Now().Add(1*time.Hour)
excludeAfter := time.Now().Add(1 * time.Hour)

createTestRDSCluster(t, session, rdsName)

Expand Down
1 change: 0 additions & 1 deletion aws/rds_cluster_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,3 @@ func (instance DBClusters) Nuke(session *session.Session, identifiers []string)

return nil
}

2 changes: 1 addition & 1 deletion aws/rds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func TestNukeRDSInstance(t *testing.T) {
)

rdsName := "cloud-nuke-test-" + util.UniqueID()
excludeAfter := time.Now().Add(1*time.Hour)
excludeAfter := time.Now().Add(1 * time.Hour)

createTestRDSInstance(t, session, rdsName)

Expand Down
2 changes: 1 addition & 1 deletion aws/rds_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (instance DBInstances) Nuke(session *session.Session, identifiers []string)
return nil
}

type RdsDeleteError struct{
type RdsDeleteError struct {
name string
}

Expand Down
177 changes: 177 additions & 0 deletions aws/transit_gateway.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package aws

import (
"time"

"github.com/aws/aws-sdk-go/aws"
awsgo "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/gruntwork-io/cloud-nuke/logging"
"github.com/gruntwork-io/gruntwork-cli/errors"
)

func sleepWithMessage(duration time.Duration, whySleepMessage string) {
logging.Logger.Infof("Sleeping %v: %s", duration, whySleepMessage)
time.Sleep(duration)
}

// Returns a formatted string of TransitGateway IDs
func getAllTransitGatewayInstances(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)
result, err := svc.DescribeTransitGateways(&ec2.DescribeTransitGatewaysInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, transitGateway := range result.TransitGateways {
if excludeAfter.After(*transitGateway.CreationTime) && awsgo.StringValue(transitGateway.State) != "deleted" && awsgo.StringValue(transitGateway.State) != "deleting" {
ids = append(ids, transitGateway.TransitGatewayId)
}
}

return ids, nil
}

// Delete all TransitGateways
func nukeAllTransitGatewayInstances(session *session.Session, ids []*string) error {
svc := ec2.New(session)

if len(ids) == 0 {
logging.Logger.Infof("No Transit Gateways to nuke in region %s", *session.Config.Region)
return nil
}

logging.Logger.Infof("Deleting all Transit Gateways in region %s", *session.Config.Region)
var deletedIds []*string

for _, id := range ids {
params := &ec2.DeleteTransitGatewayInput{
TransitGatewayId: id,
}

_, err := svc.DeleteTransitGateway(params)
if err != nil {
logging.Logger.Errorf("[Failed] %s", err)
} else {
deletedIds = append(deletedIds, id)
logging.Logger.Infof("Deleted Transit Gateway: %s", *id)
}
}

logging.Logger.Infof("[OK] %d Transit Gateway(s) deleted in %s", len(deletedIds), *session.Config.Region)
return nil
}

// Returns a formatted string of TranstGatewayRouteTable IDs
func getAllTransitGatewayRouteTables(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)

// Remove defalt route table, that will be deleted along with its TransitGateway
param := &ec2.DescribeTransitGatewayRouteTablesInput{
Filters: []*ec2.Filter{
{
Name: aws.String("default-association-route-table"),
Values: []*string{
aws.String("false"),
}},
},
}

result, err := svc.DescribeTransitGatewayRouteTables(param)
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, transitGatewayRouteTable := range result.TransitGatewayRouteTables {
if excludeAfter.After(*transitGatewayRouteTable.CreationTime) && awsgo.StringValue(transitGatewayRouteTable.State) != "deleted" && awsgo.StringValue(transitGatewayRouteTable.State) != "deleting" {
ids = append(ids, transitGatewayRouteTable.TransitGatewayRouteTableId)
}
}

return ids, nil
}

// Delete all TransitGatewayRouteTables
func nukeAllTransitGatewayRouteTables(session *session.Session, ids []*string) error {
svc := ec2.New(session)

if len(ids) == 0 {
logging.Logger.Infof("No Transit Gateway Route Tables to nuke in region %s", *session.Config.Region)
return nil
}

logging.Logger.Infof("Deleting all Transit Gateway Route Tables in region %s", *session.Config.Region)
var deletedIds []*string

for _, id := range ids {
param := &ec2.DeleteTransitGatewayRouteTableInput{
TransitGatewayRouteTableId: id,
}

_, err := svc.DeleteTransitGatewayRouteTable(param)
if err != nil {
logging.Logger.Errorf("[Failed] %s", err)
} else {
deletedIds = append(deletedIds, id)
logging.Logger.Infof("Deleted Transit Gateway Route Table: %s", *id)
}
}

logging.Logger.Infof("[OK] %d Transit Gateway Route Table(s) deleted in %s", len(deletedIds), *session.Config.Region)
return nil
}

// Returns a formated string of TransitGatewayVpcAttachment IDs
func getAllTransitGatewayVpcAttachments(session *session.Session, region string, excludeAfter time.Time) ([]*string, error) {
svc := ec2.New(session)
result, err := svc.DescribeTransitGatewayVpcAttachments(&ec2.DescribeTransitGatewayVpcAttachmentsInput{})
if err != nil {
return nil, errors.WithStackTrace(err)
}

var ids []*string
for _, tgwVpcAttachment := range result.TransitGatewayVpcAttachments {
if excludeAfter.After(*tgwVpcAttachment.CreationTime) && awsgo.StringValue(tgwVpcAttachment.State) != "deleted" && awsgo.StringValue(tgwVpcAttachment.State) != "deleting" {
ids = append(ids, tgwVpcAttachment.TransitGatewayAttachmentId)
}
}

return ids, nil
}

// Delete all TransitGatewayVpcAttachments
func nukeAllTransitGatewayVpcAttachments(session *session.Session, ids []*string) error {
svc := ec2.New(session)

if len(ids) == 0 {
logging.Logger.Infof("No Transit Gateway Vpc Attachments to nuke in region %s", *session.Config.Region)
return nil
}

logging.Logger.Infof("Deleting all Transit Gateway Vpc Attachments in region %s", *session.Config.Region)
var deletedIds []*string

for _, id := range ids {
param := &ec2.DeleteTransitGatewayVpcAttachmentInput{
TransitGatewayAttachmentId: id,
}

_, err := svc.DeleteTransitGatewayVpcAttachment(param)
if err != nil {
logging.Logger.Errorf("[Failed] %s", err)
} else {
deletedIds = append(deletedIds, id)
logging.Logger.Infof("Deleted Transit Gateway Vpc Attachment: %s", *id)
}
}

sleepMessage := "TransitGateway Vpc Attachments takes some time to create, and since there is no waiter available, we sleep instead."
sleepFor := 180 * time.Second
sleepWithMessage(sleepFor, sleepMessage)

logging.Logger.Infof(("[OK] %d Transit Gateway Vpc Attachment(s) deleted in %s"), len(deletedIds), *session.Config.Region)
return nil
}
Loading

0 comments on commit 67a87cf

Please sign in to comment.