diff --git a/controllers/awscluster_controller_test.go b/controllers/awscluster_controller_test.go index d0f51fa4e4..f1354ec840 100644 --- a/controllers/awscluster_controller_test.go +++ b/controllers/awscluster_controller_test.go @@ -776,6 +776,17 @@ func mockedDeleteInstanceCalls(m *mocks.MockEC2APIMockRecorder) { } func mockedVPCCallsForExistingVPCAndSubnets(m *mocks.MockEC2APIMockRecorder) { + m.DescribeNatGatewaysPagesWithContext(context.TODO(), gomock.Eq(&ec2.DescribeNatGatewaysInput{ + Filter: []*ec2.Filter{ + { + Name: aws.String("vpc-id"), + Values: []*string{aws.String("vpc-exists")}, + }, + { + Name: aws.String("state"), + Values: aws.StringSlice([]string{ec2.VpcStatePending, ec2.VpcStateAvailable}), + }, + }}), gomock.Any()).Return(nil) m.CreateTagsWithContext(context.TODO(), gomock.Eq(&ec2.CreateTagsInput{ Resources: aws.StringSlice([]string{"subnet-1"}), Tags: []*ec2.Tag{ diff --git a/pkg/cloud/services/network/natgateways.go b/pkg/cloud/services/network/natgateways.go index 665f5cc250..6b14c6e712 100644 --- a/pkg/cloud/services/network/natgateways.go +++ b/pkg/cloud/services/network/natgateways.go @@ -41,6 +41,10 @@ import ( func (s *Service) reconcileNatGateways() error { if s.scope.VPC().IsUnmanaged(s.scope.Name()) { s.scope.Trace("Skipping NAT gateway reconcile in unmanaged mode") + _, err := s.updateNatGatewayIPs(false) + if err != nil { + return err + } return nil } @@ -66,44 +70,11 @@ func (s *Service) reconcileNatGateways() error { return nil } - existing, err := s.describeNatGatewaysBySubnet() + subnetIDs, err := s.updateNatGatewayIPs(true) if err != nil { return err } - natGatewaysIPs := []string{} - subnetIDs := []string{} - - for _, sn := range s.scope.Subnets().FilterPublic().FilterNonCni() { - if sn.GetResourceID() == "" { - continue - } - - if ngw, ok := existing[sn.GetResourceID()]; ok { - if len(ngw.NatGatewayAddresses) > 0 && ngw.NatGatewayAddresses[0].PublicIp != nil { - natGatewaysIPs = append(natGatewaysIPs, *ngw.NatGatewayAddresses[0].PublicIp) - } - // Make sure tags are up to date. - if err := wait.WaitForWithRetryable(wait.NewBackoff(), func() (bool, error) { - buildParams := s.getNatGatewayTagParams(*ngw.NatGatewayId) - tagsBuilder := tags.New(&buildParams, tags.WithEC2(s.EC2Client)) - if err := tagsBuilder.Ensure(converters.TagsToMap(ngw.Tags)); err != nil { - return false, err - } - return true, nil - }, awserrors.ResourceNotFound); err != nil { - record.Warnf(s.scope.InfraCluster(), "FailedTagNATGateway", "Failed to tag managed NAT Gateway %q: %v", *ngw.NatGatewayId, err) - return errors.Wrapf(err, "failed to tag nat gateway %q", *ngw.NatGatewayId) - } - - continue - } - - subnetIDs = append(subnetIDs, sn.GetResourceID()) - } - - s.scope.SetNatGatewaysIPs(natGatewaysIPs) - // Batch the creation of NAT gateways if len(subnetIDs) > 0 { // set NatGatewayCreationStarted if the condition has never been set before @@ -133,6 +104,49 @@ func (s *Service) reconcileNatGateways() error { return nil } +func (s *Service) updateNatGatewayIPs(updateTags bool) ([]string, error) { + existing, err := s.describeNatGatewaysBySubnet() + if err != nil { + return nil, err + } + + natGatewaysIPs := []string{} + subnetIDs := []string{} + + for _, sn := range s.scope.Subnets().FilterPublic().FilterNonCni() { + if sn.GetResourceID() == "" { + continue + } + + if ngw, ok := existing[sn.GetResourceID()]; ok { + if len(ngw.NatGatewayAddresses) > 0 && ngw.NatGatewayAddresses[0].PublicIp != nil { + natGatewaysIPs = append(natGatewaysIPs, *ngw.NatGatewayAddresses[0].PublicIp) + } + if updateTags { + // Make sure tags are up to date. + if err := wait.WaitForWithRetryable(wait.NewBackoff(), func() (bool, error) { + buildParams := s.getNatGatewayTagParams(*ngw.NatGatewayId) + tagsBuilder := tags.New(&buildParams, tags.WithEC2(s.EC2Client)) + if err := tagsBuilder.Ensure(converters.TagsToMap(ngw.Tags)); err != nil { + return false, err + } + return true, nil + }, awserrors.ResourceNotFound); err != nil { + record.Warnf(s.scope.InfraCluster(), "FailedTagNATGateway", "Failed to tag managed NAT Gateway %q: %v", *ngw.NatGatewayId, err) + return nil, errors.Wrapf(err, "failed to tag nat gateway %q", *ngw.NatGatewayId) + } + } + + continue + } + + subnetIDs = append(subnetIDs, sn.GetResourceID()) + } + + s.scope.SetNatGatewaysIPs(natGatewaysIPs) + return subnetIDs, nil +} + func (s *Service) deleteNatGateways() error { if s.scope.VPC().IsUnmanaged(s.scope.Name()) { s.scope.Trace("Skipping NAT gateway deletion in unmanaged mode")