Skip to content

Commit

Permalink
Add equal vars for Rules to cnf building (#3714)
Browse files Browse the repository at this point in the history
  • Loading branch information
kddejong authored Sep 24, 2024
1 parent 9782e80 commit c1f8b15
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 21 deletions.
11 changes: 5 additions & 6 deletions src/cfnlint/conditions/_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Mapping, Tuple

from sympy import Symbol
from sympy.logic.boolalg import BooleanFalse, BooleanFunction, BooleanTrue
from sympy.logic.boolalg import BooleanFalse, BooleanTrue

from cfnlint.conditions._utils import get_hash
from cfnlint.helpers import is_function
Expand Down Expand Up @@ -145,7 +145,9 @@ def left(self):
def right(self):
return self._right

def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction:
def build_cnf(
self, params: dict[str, Symbol]
) -> BooleanTrue | BooleanFalse | Symbol:
"""Build a SymPy CNF solver based on the provided params
Args:
params dict[str, Symbol]: params is a dict that represents
Expand All @@ -158,10 +160,7 @@ def build_cnf(self, params: dict[str, Symbol]) -> BooleanFunction:
return BooleanTrue()
return BooleanFalse()

if self.hash in params:
return params.get(self.hash)

return Symbol(self.hash)
return params.get(self.hash, Symbol(self.hash))

def test(self, scenarios: Mapping[str, str]) -> bool:
"""Do an equals based on the provided scenario"""
Expand Down
13 changes: 9 additions & 4 deletions src/cfnlint/conditions/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,12 @@ def _build_cnf(
cnf = EncodedCNF()

# build parameters and equals into solver
equal_vars: dict[str, Symbol] = {}
equal_vars: dict[str, Symbol | BooleanFalse | BooleanTrue] = {}

equals: dict[str, Equal] = {}
for condition_name in condition_names:
c_equals = self._conditions[condition_name].equals

def _build_equal_vars(c_equals: list[Equal]):
for c_equal in c_equals:
# check to see if equals already matches another one
if c_equal.hash in equal_vars:
continue

Expand All @@ -139,6 +138,12 @@ def _build_cnf(
)
equals[c_equal.hash] = c_equal

for rule in self._rules:
_build_equal_vars(rule.equals)

for condition_name in condition_names:
_build_equal_vars(self._conditions[condition_name].equals)

# Determine if a set of conditions can never be all false
allowed_values = self._parameters.copy()
if allowed_values:
Expand Down
111 changes: 100 additions & 11 deletions test/unit/module/conditions/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def test_conditions_with_rules(self):
Assertions:
- Assert:
Fn::And:
- !Condition IsProd
- !Condition IsUsEast1
- !Equals [!Ref Environment, "prod"]
- !Equals [!Ref "AWS::Region", "us-east-1"]
Rule2:
Assertions:
- Assert:
Fn::Or:
- !Condition IsProd
- !Condition IsUsEast1
- !Equals [!Ref Environment, "prod"]
- !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]

Expand Down Expand Up @@ -79,9 +79,9 @@ def test_conditions_with_rules_implies(self):
IsUsEast1: !Equals [!Ref "AWS::Region", "us-east-1"]
Rules:
Rule:
RuleCondition: !Condition IsProd
RuleCondition: !Equals [!Ref Environment, "prod"]
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]
Expand Down Expand Up @@ -143,11 +143,11 @@ def test_conditions_with_multiple_rules(self):
Rule1:
RuleCondition: !Equals [!Ref Environment, "prod"]
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
Rule2:
RuleCondition: !Equals [!Ref Environment, "dev"]
Assertions:
- Assert: !Not [!Condition IsUsEast1]
- Assert: !Not [!Equals [!Ref "AWS::Region", "us-east-1"]]
"""
)[0]

Expand Down Expand Up @@ -366,6 +366,95 @@ def test_fn_equals_assertions_ref_never_satisfiable(self):
)
)

def test_conditions_with_rules_and_parameters(self):
template = decode_str(
"""
Conditions:
DeployGateway: !Equals
- !Ref 'DeployGateway'
- 'true'
DeployVpc: !Equals
- !Ref 'DeployVpc'
- 'true'
Parameters:
DeployAnything:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
DeployGateway:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
DeployVpc:
AllowedValues:
- 'false'
- 'true'
Type: 'String'
Rules:
DeployGateway:
Assertions:
- Assert: !Or
- !Equals
- !Ref 'DeployAnything'
- 'true'
- !Equals
- !Ref 'DeployGateway'
- 'false'
DeployVpc:
Assertions:
- Assert: !Or
- !Equals
- !Ref 'DeployGateway'
- 'true'
- !Equals
- !Ref 'DeployVpc'
- 'false'
Resources:
InternetGateway:
Condition: 'DeployGateway'
Type: 'AWS::EC2::InternetGateway'
InternetGatewayAttachment:
Condition: 'DeployVpc'
Type: 'AWS::EC2::VPCGatewayAttachment'
Properties:
InternetGatewayId: !Ref 'InternetGateway'
VpcId: !Ref 'Vpc'
"""
)[0]

cfn = Template("", template)
self.assertEqual(len(cfn.conditions._conditions), 2)
self.assertEqual(len(cfn.conditions._rules), 2)

self.assertListEqual(
[equal.hash for equal in cfn.conditions._rules[0].equals],
[
"d0d70a1e66dc83d7a0fce24c2eca396af1f34e53",
"bbf5c94c1a4b5a79c7a7863fe9463884cb422450",
],
)

self.assertTrue(
cfn.conditions.satisfiable(
{},
{},
)
)

self.assertTrue(
cfn.conditions.check_implies({"DeployVpc": True}, "DeployGateway")
)

self.assertFalse(
cfn.conditions.check_implies({"DeployVpc": False}, "DeployGateway")
)

self.assertFalse(
cfn.conditions.check_implies({"DeployGateway": False}, "DeployVpc")
)


class TestAssertion(TestCase):
def test_assertion_errors(self):
Expand Down Expand Up @@ -405,7 +494,7 @@ def test_init_rules_with_wrong_assertions_type(self):
Assertions: {"Foo": "Bar"}
Rule2:
Assertions:
- Assert: !Condition IsUsEast1
- Assert: !Equals [!Ref "AWS::Region", "us-east-1"]
"""
)[0]

Expand All @@ -425,8 +514,8 @@ def test_init_rules_with_no_keys(self):
Assertions:
- Assert:
Fn::Or:
- !Condition IsNotUsEast1
- !Condition IsUsEast1
- !Not [!Equals [!Ref "AWS::Region", "us-east-1"]]
- !Equals [!Ref "AWS::Region", "us-east-1"]
Rule3: []
"""
)[0]
Expand Down

0 comments on commit c1f8b15

Please sign in to comment.