Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V1 - cleanup config and add more tests #3204

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 42 additions & 52 deletions src/cfnlint/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ class ConfigFileArgs:
"""

file_args: Dict = {}
__user_config_file = None
__project_config_file = None
__custom_config_file = None
_user_config_file = None
_project_config_file = None
_custom_config_file = None

def __init__(self, schema=None, config_file=None):
# self.file_args = self.get_config_file_defaults()
Expand All @@ -72,10 +72,10 @@ def __init__(self, schema=None, config_file=None):
self.schema = self.default_schema if not schema else schema

if config_file:
self.__custom_config_file = config_file
self._custom_config_file = config_file
else:
LOGGER.debug("Looking for CFLINTRC before attempting to load")
self.__user_config_file, self.__project_config_file = self._find_config()
self._user_config_file, self._project_config_file = self._find_config()

self.load()

Expand All @@ -91,24 +91,28 @@ def _find_config(self):
> user_config, project_config = self._find_config()
"""
config_file_name = ".cfnlintrc"
self.__user_config_file = Path.home().joinpath(config_file_name)

self.__project_config_file = Path.cwd().joinpath(config_file_name)
if self._has_file(config_file_name + ".yaml"):
self.__project_config_file = Path.cwd().joinpath(config_file_name + ".yaml")
elif self._has_file(config_file_name + ".yml"):
self.__project_config_file = Path.cwd().joinpath(config_file_name + ".yml")

user_config_path = ""
project_config_path = ""
home_path = Path.home()
for path in [
home_path.joinpath(config_file_name),
home_path.joinpath(f"{config_file_name}.yaml"),
home_path.joinpath(f"{config_file_name}.yml"),
]:
if self._has_file(path):
user_config_path = path
break

if self._has_file(self.__user_config_file):
LOGGER.debug("Found User CFNLINTRC")
user_config_path = self.__user_config_file

if self._has_file(self.__project_config_file):
LOGGER.debug("Found Project level CFNLINTRC")
project_config_path = self.__project_config_file
project_config_path = ""
cwd_path = Path.cwd()
for path in [
cwd_path.joinpath(config_file_name),
cwd_path.joinpath(f"{config_file_name}.yaml"),
cwd_path.joinpath(f"{config_file_name}.yml"),
]:
if self._has_file(path):
project_config_path = path
break

return user_config_path, project_config_path

Expand All @@ -133,20 +137,20 @@ def load(self):
CFLINTRC configuration
"""

if self.__custom_config_file:
custom_config = self._read_config(self.__custom_config_file)
if self._custom_config_file:
custom_config = self._read_config(self._custom_config_file)
LOGGER.debug("Validating Custom CFNLINTRC")
self.validate_config(custom_config, self.schema)
LOGGER.debug("Custom configuration loaded as")
LOGGER.debug("%s", custom_config)

self.file_args = custom_config
else:
user_config = self._read_config(self.__user_config_file)
user_config = self._read_config(self._user_config_file)
LOGGER.debug("Validating User CFNLINTRC")
self.validate_config(user_config, self.schema)

project_config = self._read_config(self.__project_config_file)
project_config = self._read_config(self._project_config_file)
LOGGER.debug("Validating Project CFNLINTRC")
self.validate_config(project_config, self.schema)

Expand Down Expand Up @@ -577,34 +581,20 @@ def set_template_args(self, template):
configs = template.get("Metadata", {}).get("cfn-lint", {}).get("config", {})

if isinstance(configs, dict):
for config_name, config_value in configs.items():
if config_name == "ignore_checks":
if isinstance(config_value, list):
defaults["ignore_checks"] = config_value
if config_name == "regions":
if isinstance(config_value, list):
defaults["regions"] = config_value
if config_name == "append_rules":
if isinstance(config_value, list):
defaults["append_rules"] = config_value
if config_name == "override_spec":
if isinstance(config_value, (str)):
defaults["override_spec"] = config_value
if config_name == "custom_rules":
if isinstance(config_value, (str)):
defaults["custom_rules"] = config_value
if config_name == "ignore_bad_template":
if isinstance(config_value, bool):
defaults["ignore_bad_template"] = config_value
if config_name == "include_checks":
if isinstance(config_value, list):
defaults["include_checks"] = config_value
if config_name == "configure_rules":
if isinstance(config_value, dict):
defaults["configure_rules"] = config_value
if config_name == "include_experimental":
if isinstance(config_value, bool):
defaults["include_experimental"] = config_value
for key, value in {
"ignore_checks": (list),
"regions": (list),
"append_rules": (list),
"override_spec": (str),
"custom_rules": (str),
"ignore_bad_template": (bool),
"include_checks": (list),
"configure_rules": (dict),
"include_experimental": (bool),
}.items():
if key in configs:
if isinstance(configs[key], value):
defaults[key] = configs[key]

self._template_args = defaults

Expand Down
10 changes: 8 additions & 2 deletions src/cfnlint/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,13 @@ class Parameter(_Ref):
type: str = field(init=False)
default: Any = field(init=False)
allowed_values: Any = field(init=False)
description: str = field(init=False)
description: str | None = field(init=False)

parameter: InitVar[Any]

def __post_init__(self, parameter) -> None:
if not isinstance(parameter, dict):
raise ValueError("Parameter must be a object")
self.default = None
self.allowed_values = []
self.min_value = None
Expand All @@ -249,7 +251,7 @@ def __post_init__(self, parameter) -> None:

if self.type == "CommaDelimitedList" or self.type.startswith("List<"):
if "Default" in parameter:
self.default = parameter.get("Default").split(",")
self.default = parameter.get("Default", "").split(",")
for allowed_value in parameter.get("AllowedValues", []):
self.allowed_values.append(allowed_value.split(","))
else:
Expand Down Expand Up @@ -294,6 +296,8 @@ class Resource(_Ref):
resource: InitVar[Any]

def __post_init__(self, resource) -> None:
if not isinstance(resource, dict):
raise ValueError("Resource must be a object")
t = resource.get("Type")
if not isinstance(t, str):
raise ValueError("Type must be a string")
Expand Down Expand Up @@ -325,6 +329,8 @@ def __post_init__(self, instance) -> None:
for k, v in instance.items():
if isinstance(v, (str, list, int, float)):
self.keys[k] = v
else:
raise ValueError("Third keys must not be an object")

def value(self, secondary_key: str):
if secondary_key not in self.keys:
Expand Down
15 changes: 15 additions & 0 deletions test/unit/module/conditions/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,18 @@ def test_test_condition(self):
{h_region: "us-east-1", h_environment: "dev"}
)
)

def test_build_scenerios_on_region_with_condition_dne(self):
"""Get condition and test"""
template = decode_str(
"""
Conditions:
IsUsEast1: !Equals [!Ref AWS::Region, "us-east-1"]
"""
)[0]

cfn = Template("", template)
self.assertListEqual(
list(cfn.conditions.build_scenerios_on_region("IsProd", "us-east-1")),
[True, False],
)
48 changes: 48 additions & 0 deletions test/unit/module/config/test_config_file_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,51 @@ def test_config_parser_fail_on_config_rules(self, yaml_mock):
self.assertEqual(
results.file_args, {"configure_rules": {"E3012": {"strict": False}}}
)

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both(self, is_file_mock):
calls = [
True,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc")
self.assertEqual(is_file_mock.call_count, len(calls))

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both_yaml(self, is_file_mock):
calls = [
False,
True,
False,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc.yaml")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc.yaml")
self.assertEqual(is_file_mock.call_count, len(calls))

@patch("pathlib.Path.is_file", create=True)
def test_config_parser_is_file_both_yml(self, is_file_mock):
calls = [
False,
False,
True,
False,
False,
True,
False,
False,
]
is_file_mock.side_effect = calls
my_config = cfnlint.config.ConfigFileArgs()
self.assertEqual(my_config._user_config_file.name, ".cfnlintrc.yml")
self.assertEqual(my_config._project_config_file.name, ".cfnlintrc.yml")
self.assertEqual(is_file_mock.call_count, len(calls))
12 changes: 12 additions & 0 deletions test/unit/module/config/test_template_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,15 @@ def test_template_args_failure_good_and_bad_value(self):
)

self.assertEqual(config.template_args.get("configure_rules"), None)

def test_bad_template_structure(self):
"""test template args"""
config = cfnlint.config.TemplateArgs([])

self.assertEqual(config._template_args, {})

def test_bad_config_structure(self):
"""test template args"""
config = cfnlint.config.TemplateArgs({"Metadata": {"cfn-lint": {"config": []}}})

self.assertEqual(config._template_args, {})
13 changes: 13 additions & 0 deletions test/unit/module/context/test_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

import pytest

from cfnlint.context.context import _init_conditions


def test_conditions():
with pytest.raises(ValueError):
_init_conditions([])
104 changes: 104 additions & 0 deletions test/unit/module/context/test_create_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""

from collections import namedtuple

import pytest

from cfnlint import Template
from cfnlint.context.context import create_context_for_template

_Counts = namedtuple("_Counts", ["resources", "parameters", "conditions", "mappings"])


@pytest.mark.parametrize(
"name,instance,counts",
[
(
"Valid template",
{
"Parameters": {
"Env": {
"Type": "String",
}
},
"Conditions": {
"IsUsEast1": {"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]}
},
"Mappings": {"Map": {"us-east-1": {"foo": "bar"}}},
"Resources": {
"Bucket": {
"Type": "AWS::S3::Bucket",
},
},
},
_Counts(resources=1, parameters=1, conditions=1, mappings=1),
),
(
"Bad types in template",
{
"Parameters": [],
"Conditions": [],
"Mappings": [],
"Resources": [],
},
_Counts(resources=0, parameters=0, conditions=0, mappings=0),
),
(
"Invalid type configurations",
{
"Parameters": {
"BusinessUnit": [],
"Env": {
"Type": "String",
},
},
"Mappings": {"AnotherMap": [], "Map": {"us-east-1": {"foo": "bar"}}},
"Resources": {
"Instance": [],
"Bucket": {
"Type": "AWS::S3::Bucket",
},
},
},
_Counts(resources=1, parameters=1, conditions=0, mappings=1),
),
(
"Invalid mapping second key",
{
"Mappings": {
"BadKey": {
"Foo": [],
},
"Map": {"us-east-1": {"foo": "bar"}},
},
},
_Counts(resources=0, parameters=0, conditions=0, mappings=1),
),
(
"Invalid mapping third key",
{
"Mappings": {
"BadKey": {
"Foo": {
"Bar": {},
},
},
"Map": {"us-east-1": {"foo": "bar"}},
},
},
_Counts(resources=0, parameters=0, conditions=0, mappings=1),
),
],
)
def test_create_context(name, instance, counts):
cfn = Template("", instance, ["us-east-1"])
context = create_context_for_template(cfn)

for i in counts._fields:
assert len(getattr(context, i)) == getattr(counts, i), (
f"Test {name} has {i} {len(getattr(context, i))} "
"and expected {getattr(counts, i)}"
)
Loading
Loading