Skip to content

Commit

Permalink
V1 - cleanup config and add more tests (#3204)
Browse files Browse the repository at this point in the history
* Add more testing to increase coverage
* Create fixes and more tests
  • Loading branch information
kddejong authored May 2, 2024
1 parent ac4d66c commit b96c825
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 55 deletions.
94 changes: 42 additions & 52 deletions src/cfnlint/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class ConfigFileArgs:
"""

file_args: Dict = {}
__user_config_file = None
__project_config_file = None
__custom_config_file = None
_user_config_file = None
_project_config_file = None
_custom_config_file = None

def __init__(self, schema=None, config_file=None):
# self.file_args = self.get_config_file_defaults()
Expand All @@ -72,10 +72,10 @@ def __init__(self, schema=None, config_file=None):
self.schema = self.default_schema if not schema else schema

if config_file:
self.__custom_config_file = config_file
self._custom_config_file = config_file
else:
LOGGER.debug("Looking for CFLINTRC before attempting to load")
self.__user_config_file, self.__project_config_file = self._find_config()
self._user_config_file, self._project_config_file = self._find_config()

self.load()

Expand All @@ -91,24 +91,28 @@ def _find_config(self):
> user_config, project_config = self._find_config()
"""
config_file_name = ".cfnlintrc"
self.__user_config_file = Path.home().joinpath(config_file_name)

self.__project_config_file = Path.cwd().joinpath(config_file_name)
if self._has_file(config_file_name + ".yaml"):
self.__project_config_file = Path.cwd().joinpath(config_file_name + ".yaml")
elif self._has_file(config_file_name + ".yml"):
self.__project_config_file = Path.cwd().joinpath(config_file_name + ".yml")

user_config_path = ""
project_config_path = ""
home_path = Path.home()
for path in [
home_path.joinpath(config_file_name),
home_path.joinpath(f"{config_file_name}.yaml"),
home_path.joinpath(f"{config_file_name}.yml"),
]:
if self._has_file(path):
user_config_path = path
break

if self._has_file(self.__user_config_file):
LOGGER.debug("Found User CFNLINTRC")
user_config_path = self.__user_config_file

if self._has_file(self.__project_config_file):
LOGGER.debug("Found Project level CFNLINTRC")
project_config_path = self.__project_config_file
project_config_path = ""
cwd_path = Path.cwd()
for path in [
cwd_path.joinpath(config_file_name),
cwd_path.joinpath(f"{config_file_name}.yaml"),
cwd_path.joinpath(f"{config_file_name}.yml"),
]:
if self._has_file(path):
project_config_path = path
break

return user_config_path, project_config_path

Expand All @@ -133,20 +137,20 @@ def load(self):
CFLINTRC configuration
"""

if self.__custom_config_file:
custom_config = self._read_config(self.__custom_config_file)
if self._custom_config_file:
custom_config = self._read_config(self._custom_config_file)
LOGGER.debug("Validating Custom CFNLINTRC")
self.validate_config(custom_config, self.schema)
LOGGER.debug("Custom configuration loaded as")
LOGGER.debug("%s", custom_config)

self.file_args = custom_config
else:
user_config = self._read_config(self.__user_config_file)
user_config = self._read_config(self._user_config_file)
LOGGER.debug("Validating User CFNLINTRC")
self.validate_config(user_config, self.schema)

project_config = self._read_config(self.__project_config_file)
project_config = self._read_config(self._project_config_file)
LOGGER.debug("Validating Project CFNLINTRC")
self.validate_config(project_config, self.schema)

Expand Down Expand Up @@ -577,34 +581,20 @@ def set_template_args(self, template):
configs = template.get("Metadata", {}).get("cfn-lint", {}).get("config", {})

if isinstance(configs, dict):
for config_name, config_value in configs.items():
if config_name == "ignore_checks":
if isinstance(config_value, list):
defaults["ignore_checks"] = config_value
if config_name == "regions":
if isinstance(config_value, list):
defaults["regions"] = config_value
if config_name == "append_rules":
if isinstance(config_value, list):
defaults["append_rules"] = config_value
if config_name == "override_spec":
if isinstance(config_value, (str)):
defaults["override_spec"] = config_value
if config_name == "custom_rules":
if isinstance(config_value, (str)):
defaults["custom_rules"] = config_value
if config_name == "ignore_bad_template":
if isinstance(config_value, bool):
defaults["ignore_bad_template"] = config_value
if config_name == "include_checks":
if isinstance(config_value, list):
defaults["include_checks"] = config_value
if config_name == "configure_rules":
if isinstance(config_value, dict):
defaults["configure_rules"] = config_value
if config_name == "include_experimental":
if isinstance(config_value, bool):
defaults["include_experimental"] = config_value
for key, value in {
"ignore_checks": (list),
"regions": (list),
"append_rules": (list),
"override_spec": (str),
"custom_rules": (str),
"ignore_bad_template": (bool),
"include_checks": (list),
"configure_rules": (dict),
"include_experimental": (bool),
}.items():
if key in configs:
if isinstance(configs[key], value):
defaults[key] = configs[key]

self._template_args = defaults

Expand Down
10 changes: 8 additions & 2 deletions src/cfnlint/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,13 @@ class Parameter(_Ref):
type: str = field(init=False)
default: Any = field(init=False)
allowed_values: Any = field(init=False)
description: str = field(init=False)
description: str | None = field(init=False)

parameter: InitVar[Any]

def __post_init__(self, parameter) -> None:
if not isinstance(parameter, dict):
raise ValueError("Parameter must be a object")
self.default = None
self.allowed_values = []
self.min_value = None
Expand All @@ -249,7 +251,7 @@ def __post_init__(self, parameter) -> None:

if self.type == "CommaDelimitedList" or self.type.startswith("List<"):
if "Default" in parameter:
self.default = parameter.get("Default").split(",")
self.default = parameter.get("Default", "").split(",")
for allowed_value in parameter.get("AllowedValues", []):
self.allowed_values.append(allowed_value.split(","))
else:
Expand Down Expand Up @@ -294,6 +296,8 @@ class Resource(_Ref):
resource: InitVar[Any]

def __post_init__(self, resource) -> None:
if not isinstance(resource, dict):
raise ValueError("Resource must be a object")
t = resource.get("Type")
if not isinstance(t, str):
raise ValueError("Type must be a string")
Expand Down Expand Up @@ -325,6 +329,8 @@ def __post_init__(self, instance) -> None:
for k, v in instance.items():
if isinstance(v, (str, list, int, float)):
self.keys[k] = v
else:
raise ValueError("Third keys must not be an object")

def value(self, secondary_key: str):
if secondary_key not in self.keys:
Expand Down
15 changes: 15 additions & 0 deletions test/unit/module/conditions/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,18 @@ def test_test_condition(self):
{h_region: "us-east-1", h_environment: "dev"}
)
)

def test_build_scenerios_on_region_with_condition_dne(self):
"""Get condition and test"""
template = decode_str(
"""
Conditions:
IsUsEast1: !Equals [!Ref AWS::Region, "us-east-1"]
"""
)[0]

cfn = Template("", template)
self.assertListEqual(
list(cfn.conditions.build_scenerios_on_region("IsProd", "us-east-1")),
[True, False],
)
48 changes: 48 additions & 0 deletions test/unit/module/config/test_config_file_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,51 @@ def test_config_parser_fail_on_config_rules(self, yaml_mock):
self.assertEqual(
results.file_args, {"configure_rules": {"E3012": {"strict": False}}}
)

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both(self, is_file_mock):
calls = [
True,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc")
self.assertEqual(is_file_mock.call_count, len(calls))

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both_yaml(self, is_file_mock):
calls = [
False,
True,
False,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc.yaml")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc.yaml")
self.assertEqual(is_file_mock.call_count, len(calls))

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both_yml(self, is_file_mock):
calls = [
False,
False,
True,
False,
False,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc.yml")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc.yml")
self.assertEqual(is_file_mock.call_count, len(calls))
12 changes: 12 additions & 0 deletions test/unit/module/config/test_template_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,15 @@ def test_template_args_failure_good_and_bad_value(self):
)

self.assertEqual(config.template_args.get("configure_rules"), None)

def test_bad_template_structure(self):
"""test template args"""
config = cfnlint.config.TemplateArgs([])

self.assertEqual(config._template_args, {})

def test_bad_config_structure(self):
"""test template args"""
config = cfnlint.config.TemplateArgs({"Metadata": {"cfn-lint": {"config": []}}})

self.assertEqual(config._template_args, {})
13 changes: 13 additions & 0 deletions test/unit/module/context/test_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

import pytest

from cfnlint.context.context import _init_conditions


def test_conditions():
with pytest.raises(ValueError):
_init_conditions([])
104 changes: 104 additions & 0 deletions test/unit/module/context/test_create_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

from collections import namedtuple

import pytest

from cfnlint import Template
from cfnlint.context.context import create_context_for_template

_Counts = namedtuple("_Counts", ["resources", "parameters", "conditions", "mappings"])


@pytest.mark.parametrize(
"name,instance,counts",
[
(
"Valid template",
{
"Parameters": {
"Env": {
"Type": "String",
}
},
"Conditions": {
"IsUsEast1": {"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]}
},
"Mappings": {"Map": {"us-east-1": {"foo": "bar"}}},
"Resources": {
"Bucket": {
"Type": "AWS::S3::Bucket",
},
},
},
_Counts(resources=1, parameters=1, conditions=1, mappings=1),
),
(
"Bad types in template",
{
"Parameters": [],
"Conditions": [],
"Mappings": [],
"Resources": [],
},
_Counts(resources=0, parameters=0, conditions=0, mappings=0),
),
(
"Invalid type configurations",
{
"Parameters": {
"BusinessUnit": [],
"Env": {
"Type": "String",
},
},
"Mappings": {"AnotherMap": [], "Map": {"us-east-1": {"foo": "bar"}}},
"Resources": {
"Instance": [],
"Bucket": {
"Type": "AWS::S3::Bucket",
},
},
},
_Counts(resources=1, parameters=1, conditions=0, mappings=1),
),
(
"Invalid mapping second key",
{
"Mappings": {
"BadKey": {
"Foo": [],
},
"Map": {"us-east-1": {"foo": "bar"}},
},
},
_Counts(resources=0, parameters=0, conditions=0, mappings=1),
),
(
"Invalid mapping third key",
{
"Mappings": {
"BadKey": {
"Foo": {
"Bar": {},
},
},
"Map": {"us-east-1": {"foo": "bar"}},
},
},
_Counts(resources=0, parameters=0, conditions=0, mappings=1),
),
],
)
def test_create_context(name, instance, counts):
cfn = Template("", instance, ["us-east-1"])
context = create_context_for_template(cfn)

for i in counts._fields:
assert len(getattr(context, i)) == getattr(counts, i), (
f"Test {name} has {i} {len(getattr(context, i))} "
"and expected {getattr(counts, i)}"
)
Loading

0 comments on commit b96c825

Please sign in to comment.