Skip to content

Commit

Permalink
Adds checks to make sure that leaf nodes meet these criteria
Browse files Browse the repository at this point in the history
- Each leaf has a unique value
- "if_night" values do not point to leaves that have "if_night" values
  • Loading branch information
MoseleyS committed Sep 22, 2023
1 parent ec4a11b commit bf4bdaa
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
18 changes: 17 additions & 1 deletion improver/categorical/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,16 @@ def check_tree(

all_key_words = REQUIRED_KEY_WORDS + OPTIONAL_KEY_WORDS
all_leaf_key_words = LEAF_REQUIRED_KEY_WORDS + LEAF_OPTIONAL_KEY_WORDS

# Check that all leaves have a unique "leaf" value
all_leaves = [v["leaf"] for v in decision_tree.values() if "leaf" in v.keys()]
unique_leaves = set()
duplicates = [x for x in all_leaves if x in unique_leaves or unique_leaves.add(x)]
if duplicates:
issues.append(
f"These leaf categories are used more than once: {sorted(list(set(duplicates)))}"
)

for node, items in decision_tree.items():
if "leaf" in items.keys():
# Check the leaf only contains expected keys
Expand All @@ -408,7 +418,8 @@ def check_tree(
if not isinstance(leaf_target, int):
issues.append(f"Leaf '{node}' has non-int target: {leaf_target}")

# If leaf has "if_night", check it points to another leaf.
# If leaf has "if_night", check it points to another leaf
# AND that the other leaf does NOT have "if_night".
if "if_night" in items.keys():
target = decision_tree.get(items["if_night"], None)
if not target:
Expand All @@ -419,6 +430,11 @@ def check_tree(
issues.append(
f"Target '{items['if_night']}' of leaf '{node}' is not a leaf."
)
elif "if_night" in target.keys():
issues.append(
f"Night target '{items['if_night']}' of leaf '{node}' "
"also has a night target."
)
# If leaf has "group", check the group contains at least two members.
if "group" in items.keys():
members = [
Expand Down
58 changes: 46 additions & 12 deletions improver_tests/categorical/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,18 +530,6 @@ def modify_tree_fixture(node, key, value):
),
),
("Thunder", "leaf", 10.2, "Leaf 'Thunder' has non-int target: 10.2",),
(
"Clear_Night",
"if_night",
"kittens",
"Leaf 'Clear_Night' does not point to a valid target (kittens).",
),
(
"Clear_Night",
"if_night",
"lightning_shower",
"Target 'lightning_shower' of leaf 'Clear_Night' is not a leaf.",
),
(
"Clear_Night",
"pets",
Expand Down Expand Up @@ -644,6 +632,52 @@ def test_check_tree(modify_tree, expected):
assert result == expected


@pytest.mark.parametrize("node, key", (("Sunny_Day", "if_night"),))
@pytest.mark.parametrize(
"value, expected",
(
("kittens", "Leaf 'Sunny_Day' does not point to a valid target (kittens).",),
(
"Partly_Cloudy_Day",
"Night target 'Partly_Cloudy_Day' of leaf 'Sunny_Day' also has a night target.",
),
(
"lightning_shower",
"Target 'lightning_shower' of leaf 'Sunny_Day' is not a leaf.",
),
),
)
def test_check_tree_if_night(modify_tree, expected):
"""Test that the various possible decision tree problems related to if_night are identified.
These are separated out because we need to mark the night leaf as unreachable"""
modify_tree["Clear_Night"]["is_unreachable"] = True
result = check_tree(modify_tree)
assert result == expected


@pytest.mark.parametrize(
"nodes, expected",
(
({"Thunder": 28}, "These leaf categories are used more than once: [28]",),
(
{"Thunder": 28, "Thunder_Shower_Day": 28},
"These leaf categories are used more than once: [28]",
),
(
{"Thunder": 28, "Heavy_Snow": 26},
"These leaf categories are used more than once: [26, 28]",
),
),
)
def test_check_tree_duplicate_leaves(nodes, expected):
"""Test that the various possible leaf duplicates are identified."""
tree = wxcode_decision_tree()
for node, value in nodes.items():
tree[node]["leaf"] = value
result = check_tree(tree)
assert result == expected


def test_check_tree_non_dictionary():
"""Check ValueError is raised if non-dictionary is passed to check_tree."""
expected = "Decision tree is not a dictionary"
Expand Down

0 comments on commit bf4bdaa

Please sign in to comment.