diff --git a/pkg/provider/azure_loadbalancer.go b/pkg/provider/azure_loadbalancer.go index ea584b5b96..5b92ce0674 100644 --- a/pkg/provider/azure_loadbalancer.go +++ b/pkg/provider/azure_loadbalancer.go @@ -3076,6 +3076,11 @@ func (az *Cloud) reconcileSecurityRules(sg network.SecurityGroup, updatedRules = removeDuplicatedSecurityRules(updatedRules) + if dirtySg || shouldTidySecurityRules(updatedRules) { + updatedRules = tidySecurityRules(updatedRules) + dirtySg = true + } + for _, r := range updatedRules { klog.V(10).Infof("Updated security rule while processing %s: %s:%s -> %s:%s", service.Name, logSafe(r.SourceAddressPrefix), logSafe(r.SourcePortRange), logSafeCollection(r.DestinationAddressPrefix, r.DestinationAddressPrefixes), logSafe(r.DestinationPortRange)) } @@ -3083,6 +3088,96 @@ func (az *Cloud) reconcileSecurityRules(sg network.SecurityGroup, return dirtySg, updatedRules, nil } +// shouldTidySecurityRules returns true if the priorities of rules should be re-ordered. +// The rules should be tidied if there are any allow rules after deny rules which will not work. +func shouldTidySecurityRules(rules []network.SecurityRule) bool { + if len(rules) <= 1 { + return false + } + sort.Slice(rules, func(i, j int) bool { + return *rules[i].Priority < *rules[j].Priority + }) + + denyRuleFound := false + for i := range rules { + if *rules[i].Priority < consts.LoadBalancerMinimumPriority { + continue + } + if *rules[i].Priority > consts.LoadBalancerMaximumPriority { + break + } + + switch rules[i].Access { + case network.SecurityRuleAccessDeny: + denyRuleFound = true + case network.SecurityRuleAccessAllow: + if denyRuleFound { + // Allow rule after deny rule, should tidy + return true + } + } + } + return false +} + +// tidySecurityRules reorders the rules to make the order deterministic and to ensure that the rules are in the correct order. +func tidySecurityRules(rules []network.SecurityRule) []network.SecurityRule { + if len(rules) <= 1 { + return rules + } + var ( + allowRules []network.SecurityRule + denyAllRules []network.SecurityRule + unmanagedRules []network.SecurityRule // rules priority not in the range of cloud-provider; keep them untouched + ) + for _, rule := range rules { + p := *rule.Priority + if p < consts.LoadBalancerMinimumPriority || p > consts.LoadBalancerMaximumPriority { + unmanagedRules = append(unmanagedRules, rule) + continue + } + + switch rule.Access { + case network.SecurityRuleAccessAllow: + allowRules = append(allowRules, rule) + case network.SecurityRuleAccessDeny: + denyAllRules = append(denyAllRules, rule) + } + } + + // Tidy allow rules + { + sort.Slice(allowRules, func(i, j int) bool { + // Sort by name to make the order deterministic + return *allowRules[i].Name < *allowRules[j].Name + }) + p := int32(consts.LoadBalancerMinimumPriority) + for i := range allowRules { + allowRules[i].Priority = pointer.Int32(p) + p++ + } + } + // Tidy deny rules + { + sort.Slice(denyAllRules, func(i, j int) bool { + // Sort by name to make the order deterministic + return *denyAllRules[i].Name < *denyAllRules[j].Name + }) + p := int32(consts.LoadBalancerMaximumPriority) + for i := range denyAllRules { + denyAllRules[i].Priority = pointer.Int32(p) + p-- + } + } + + rv := append(append(allowRules, denyAllRules...), unmanagedRules...) + sort.Slice(rv, func(i, j int) bool { + return *rv[i].Priority < *rv[j].Priority + }) + + return rv +} + func (az *Cloud) getExpectedSecurityRules( wantLb bool, ports []v1.ServicePort, diff --git a/pkg/provider/azure_loadbalancer_test.go b/pkg/provider/azure_loadbalancer_test.go index a5555c1cba..4edcbc82b4 100644 --- a/pkg/provider/azure_loadbalancer_test.go +++ b/pkg/provider/azure_loadbalancer_test.go @@ -4762,12 +4762,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { }, }, { - Name: pointer.String("asvc-TCP-80-foo"), + Name: pointer.String("asvc-TCP-80-bar"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("foo"), + SourceAddressPrefix: pointer.String("bar"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(502), @@ -4775,12 +4775,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { }, }, { - Name: pointer.String("asvc-TCP-80-bar"), + Name: pointer.String("asvc-TCP-80-foo"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("bar"), + SourceAddressPrefix: pointer.String("foo"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(503), @@ -4839,12 +4839,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ SecurityRules: &[]network.SecurityRule{ { - Name: pointer.String("asvc-TCP-80-192.168.0.1_32"), + Name: pointer.String("asvc-TCP-80-10.10.10.0_24"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("192.168.0.1/32"), + SourceAddressPrefix: pointer.String("10.10.10.0/24"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(500), @@ -4852,12 +4852,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { }, }, { - Name: pointer.String("asvc-TCP-80-10.10.10.0_24"), + Name: pointer.String("asvc-TCP-80-192.168.0.1_32"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("10.10.10.0/24"), + SourceAddressPrefix: pointer.String("192.168.0.1/32"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(501), @@ -4865,12 +4865,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { }, }, { - Name: pointer.String("asvc-TCP-80-foo"), + Name: pointer.String("asvc-TCP-80-bar"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("foo"), + SourceAddressPrefix: pointer.String("bar"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(502), @@ -4878,12 +4878,12 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { }, }, { - Name: pointer.String("asvc-TCP-80-bar"), + Name: pointer.String("asvc-TCP-80-foo"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ Protocol: network.SecurityRuleProtocol("Tcp"), SourcePortRange: pointer.String("*"), DestinationPortRange: pointer.String(strconv.Itoa(80)), - SourceAddressPrefix: pointer.String("bar"), + SourceAddressPrefix: pointer.String("foo"), DestinationAddressPrefixes: &([]string{"10.0.0.1", "10.0.0.2"}), Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(503), @@ -4924,58 +4924,398 @@ func TestReconcileSecurityGroupCommon(t *testing.T) { } func TestReconcileSecurityGroupLoadBalancerSourceRanges(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - az := GetTestCloud(ctrl) - service := getTestService("test1", v1.ProtocolTCP, map[string]string{consts.ServiceAnnotationDenyAllExceptLoadBalancerSourceRanges: "true"}, false, 80) - service.Spec.LoadBalancerSourceRanges = []string{"1.2.3.4/32"} - existingSg := network.SecurityGroup{ - Name: pointer.String("nsg"), - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{}, - }, - } - lbIPs := &[]string{"1.1.1.1"} - expectedSg := network.SecurityGroup{ - Name: pointer.String("nsg"), - SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ - SecurityRules: &[]network.SecurityRule{ + tests := []struct { + name string + sourceRanges []string + originalRules []network.SecurityRule + expectedRules []network.SecurityRule + }{ + { + name: "should add deny-all rule if not exists #1", + sourceRanges: []string{"1.2.3.4/32"}, + originalRules: nil, + expectedRules: []network.SecurityRule{ { Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Protocol: network.SecurityRuleProtocol("Tcp"), + Protocol: network.SecurityRuleProtocolTCP, SourcePortRange: pointer.String("*"), SourceAddressPrefix: pointer.String("1.2.3.4/32"), DestinationPortRange: pointer.String("80"), DestinationAddressPrefix: pointer.String("1.1.1.1"), - Access: network.SecurityRuleAccess("Allow"), + Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(500), - Direction: network.SecurityRuleDirection("Inbound"), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(4096), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + }, + + { + name: "should add deny-all rule if not exists #2", + sourceRanges: []string{"1.2.3.4/32"}, + originalRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(505), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + expectedRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(500), // would be reset + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(4096), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + }, + { + name: "should move deny-all rule after allow rule: reorder priority", + sourceRanges: []string{"1.2.3.4/32"}, + originalRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(505), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(503), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + expectedRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(500), // would be reset + Direction: network.SecurityRuleDirectionInbound, }, }, { Name: pointer.String("atest1-TCP-80-deny_all"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ - Protocol: network.SecurityRuleProtocol("Tcp"), + Protocol: network.SecurityRuleProtocolTCP, SourcePortRange: pointer.String("*"), SourceAddressPrefix: pointer.String("*"), DestinationPortRange: pointer.String("80"), DestinationAddressPrefix: pointer.String("1.1.1.1"), - Access: network.SecurityRuleAccess("Deny"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(4096), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + }, + { + name: "should move deny-all rule after allow rule: add new source range", + sourceRanges: []string{"1.2.3.4/32", "10.0.0.0/16"}, + originalRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(505), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(503), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + expectedRules: []network.SecurityRule{ + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(500), // would be reset + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-10.0.0.0_16"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("10.0.0.0/16"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, Priority: pointer.Int32(501), - Direction: network.SecurityRuleDirection("Inbound"), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(4096), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + }, + { + name: "should move deny-all rule after allow rule: keep unmanaged rules unchanged", + sourceRanges: []string{"1.2.3.4/32"}, + originalRules: []network.SecurityRule{ + { + Name: pointer.String("unmanaged_rule_1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(499), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("unmanaged_rule_2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/16"), + DestinationPortRange: pointer.String("443"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(5000), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("unmanaged_rule_3"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/16"), + DestinationPortRange: pointer.String("443"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(400), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(505), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(503), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + }, + expectedRules: []network.SecurityRule{ + { + Name: pointer.String("unmanaged_rule_3"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/16"), + DestinationPortRange: pointer.String("443"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(400), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("unmanaged_rule_1"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(499), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-1.2.3.4_32"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/32"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(500), // would be reset + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("atest1-TCP-80-deny_all"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("*"), + DestinationPortRange: pointer.String("80"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessDeny, + Priority: pointer.Int32(4096), + Direction: network.SecurityRuleDirectionInbound, + }, + }, + { + Name: pointer.String("unmanaged_rule_2"), + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Protocol: network.SecurityRuleProtocolTCP, + SourcePortRange: pointer.String("*"), + SourceAddressPrefix: pointer.String("1.2.3.4/16"), + DestinationPortRange: pointer.String("443"), + DestinationAddressPrefix: pointer.String("1.1.1.1"), + Access: network.SecurityRuleAccessAllow, + Priority: pointer.Int32(5000), + Direction: network.SecurityRuleDirectionInbound, }, }, }, }, } - mockSGClient := az.SecurityGroupsClient.(*mocksecuritygroupclient.MockInterface) - mockSGClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(existingSg, nil) - mockSGClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) - sg, err := az.reconcileSecurityGroup("testCluster", &service, lbIPs, nil, true) - assert.NoError(t, err) - assert.Equal(t, expectedSg, *sg) + + for _, tt := range tests { + originalSg := network.SecurityGroup{ + Name: pointer.String("nsg"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &tt.originalRules, + }, + } + expectedSg := network.SecurityGroup{ + Name: pointer.String("nsg"), + SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{ + SecurityRules: &tt.expectedRules, + }, + } + + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + az := GetTestCloud(ctrl) + service := getTestService("test1", v1.ProtocolTCP, map[string]string{consts.ServiceAnnotationDenyAllExceptLoadBalancerSourceRanges: "true"}, false, 80) + service.Spec.LoadBalancerSourceRanges = tt.sourceRanges + lbIPs := &[]string{"1.1.1.1"} + + mockSGClient := az.SecurityGroupsClient.(*mocksecuritygroupclient.MockInterface) + mockSGClient.EXPECT().Get(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any()).Return(originalSg, nil) + mockSGClient.EXPECT().CreateOrUpdate(gomock.Any(), az.ResourceGroup, gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + sg, err := az.reconcileSecurityGroup("testCluster", &service, lbIPs, nil, true) + assert.NoError(t, err) + assert.Equal(t, expectedSg, *sg) + }) + } } func TestReconcileSecurityGroup(t *testing.T) { @@ -9635,3 +9975,262 @@ func fakeEnsureHostsInPool() func(*v1.Service, []*v1.Node, string, string, strin return nil } } + +func newAllowRule(name string, priority int32) network.SecurityRule { + return network.SecurityRule{ + Name: &name, + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Priority: &priority, + Access: network.SecurityRuleAccessAllow, + }, + } +} +func newDenyRule(name string, priority int32) network.SecurityRule { + return network.SecurityRule{ + Name: &name, + SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Priority: &priority, + Access: network.SecurityRuleAccessDeny, + }, + } +} + +func Test_shouldTidySecurityRules(t *testing.T) { + type args struct { + rules []network.SecurityRule + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "should return false if there are no rules #1", + args: args{ + rules: nil, + }, + want: false, + }, + { + name: "should return false if there are no rules #2", + args: args{ + rules: []network.SecurityRule{}, + }, + want: false, + }, + { + name: "should return false if there is 1 rule", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + }, + }, + want: false, + }, + { + name: "should return false if no allow rule after deny rule", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + newDenyRule("rule2", 505), + }, + }, + want: false, + }, + { + name: "should return false if no allow rule after deny rule: ignore unmanaged rules", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + newDenyRule("rule2", 505), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + want: false, + }, + { + name: "should return true if allow rule after deny rule", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + newDenyRule("rule4", 501), + newAllowRule("rule2", 502), + newDenyRule("rule5", 503), + newAllowRule("rule3", 504), + }, + }, + want: true, + }, + { + name: "should return true if allow rule after deny rule: with unmanaged rules", + args: args{ + rules: []network.SecurityRule{ + + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + newDenyRule("rule4", 501), + newAllowRule("rule2", 502), + newDenyRule("rule5", 503), + newAllowRule("rule3", 504), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, shouldTidySecurityRules(tt.args.rules), "shouldTidySecurityRules(%v)", tt.args.rules) + }) + } +} + +func Test_tidySecurityRules(t *testing.T) { + + type args struct { + rules []network.SecurityRule + } + tests := []struct { + name string + args args + want []network.SecurityRule + }{ + { + name: "should return the same rules if there are no rules #1", + args: args{ + rules: nil, + }, + want: nil, + }, + { + name: "should return the same rules if there are no rules #2", + args: args{ + rules: []network.SecurityRule{}, + }, + want: []network.SecurityRule{}, + }, + { + name: "should return the same rules if there is 1 rule", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + }, + }, + want: []network.SecurityRule{ + newAllowRule("rule1", 500), + }, + }, + { + name: "should return the same rules if there are only unmanaged rules", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + want: []network.SecurityRule{ + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + { + name: "should reorder managed rules #1", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + newAllowRule("rule2", 541), + newAllowRule("rule3", 652), + newDenyRule("rule4", 700), + newDenyRule("rule5", 1000), + }, + }, + want: []network.SecurityRule{ + newAllowRule("rule1", 500), + newAllowRule("rule2", 501), + newAllowRule("rule3", 502), + newDenyRule("rule5", 4095), + newDenyRule("rule4", 4096), + }, + }, + { + name: "should reorder managed rules #2", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("rule1", 500), + newDenyRule("rule4", 501), + newAllowRule("rule2", 502), + newDenyRule("rule5", 503), + newAllowRule("rule3", 504), + }, + }, + want: []network.SecurityRule{ + newAllowRule("rule1", 500), + newAllowRule("rule2", 501), + newAllowRule("rule3", 502), + newDenyRule("rule5", 4095), + newDenyRule("rule4", 4096), + }, + }, + { + name: "should reorder managed rules #3: ignore unmanaged rules", + args: args{ + rules: []network.SecurityRule{ + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + newDenyRule("rule4", 501), + newAllowRule("rule2", 502), + newDenyRule("rule5", 503), + newAllowRule("rule3", 504), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + want: []network.SecurityRule{ + newAllowRule("unmanaged_rule1", 300), + newAllowRule("unmanaged_rule2", 400), + newDenyRule("unmanaged_rule3", 450), + + newAllowRule("rule1", 500), + newAllowRule("rule2", 501), + newAllowRule("rule3", 502), + newDenyRule("rule5", 4095), + newDenyRule("rule4", 4096), + + newDenyRule("unmanaged_rule4", 5000), + newAllowRule("unmanaged_rule5", 6000), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, tidySecurityRules(tt.args.rules), "tidySecurityRules(%v)", tt.args.rules) + }) + } +} diff --git a/pkg/provider/azure_test.go b/pkg/provider/azure_test.go index 804033c8a5..c0ab0e7c0a 100644 --- a/pkg/provider/azure_test.go +++ b/pkg/provider/azure_test.go @@ -1970,6 +1970,7 @@ func getTestSecurityGroupCommon(az *Cloud, v4Enabled, v6Enabled bool, services . return network.SecurityRule{ Name: pointer.String(ruleName), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Priority: pointer.Int32(0), SourceAddressPrefix: pointer.String(src), DestinationPortRange: pointer.String(fmt.Sprintf("%d", port.Port)), }, @@ -2264,6 +2265,7 @@ func securityRuleMatches(serviceSourceRange string, servicePort v1.ServicePort, } func validateSecurityGroupCommon(t *testing.T, az *Cloud, securityGroup *network.SecurityGroup, v4Enabled, v6Enabled bool, services ...v1.Service) { + t.Helper() expectedRules := make(map[string]bool) for _, svc := range services { svc := svc @@ -2803,6 +2805,7 @@ func TestIfServiceSpecifiesSharedRuleAndRuleExistsThenTheServicesPortAndAddressA rule := network.SecurityRule{ Name: &ruleName, SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ + Priority: pointer.Int32(500), Protocol: network.SecurityRuleProtocolTCP, SourcePortRange: pointer.String("*"), SourceAddressPrefix: pointer.String("Internet"),