diff --git a/improver/categorical/utilities.py b/improver/categorical/utilities.py index 51ba0dc896..6505345dd3 100644 --- a/improver/categorical/utilities.py +++ b/improver/categorical/utilities.py @@ -388,6 +388,16 @@ def check_tree( all_key_words = REQUIRED_KEY_WORDS + OPTIONAL_KEY_WORDS all_leaf_key_words = LEAF_REQUIRED_KEY_WORDS + LEAF_OPTIONAL_KEY_WORDS + + # Check that all leaves have a unique "leaf" value + all_leaves = [v["leaf"] for v in decision_tree.values() if "leaf" in v.keys()] + unique_leaves = set() + duplicates = [x for x in all_leaves if x in unique_leaves or unique_leaves.add(x)] + if duplicates: + issues.append( + f"These leaf categories are used more than once: {sorted(list(set(duplicates)))}" + ) + for node, items in decision_tree.items(): if "leaf" in items.keys(): # Check the leaf only contains expected keys @@ -408,7 +418,8 @@ def check_tree( if not isinstance(leaf_target, int): issues.append(f"Leaf '{node}' has non-int target: {leaf_target}") - # If leaf has "if_night", check it points to another leaf. + # If leaf has "if_night", check it points to another leaf + # AND that the other leaf does NOT have "if_night". if "if_night" in items.keys(): target = decision_tree.get(items["if_night"], None) if not target: @@ -419,6 +430,11 @@ def check_tree( issues.append( f"Target '{items['if_night']}' of leaf '{node}' is not a leaf." ) + elif "if_night" in target.keys(): + issues.append( + f"Night target '{items['if_night']}' of leaf '{node}' " + "also has a night target." + ) # If leaf has "group", check the group contains at least two members. if "group" in items.keys(): members = [ diff --git a/improver_tests/categorical/test_utilities.py b/improver_tests/categorical/test_utilities.py index 04c6482b8c..8116a6e90b 100644 --- a/improver_tests/categorical/test_utilities.py +++ b/improver_tests/categorical/test_utilities.py @@ -530,18 +530,6 @@ def modify_tree_fixture(node, key, value): ), ), ("Thunder", "leaf", 10.2, "Leaf 'Thunder' has non-int target: 10.2",), - ( - "Clear_Night", - "if_night", - "kittens", - "Leaf 'Clear_Night' does not point to a valid target (kittens).", - ), - ( - "Clear_Night", - "if_night", - "lightning_shower", - "Target 'lightning_shower' of leaf 'Clear_Night' is not a leaf.", - ), ( "Clear_Night", "pets", @@ -644,6 +632,52 @@ def test_check_tree(modify_tree, expected): assert result == expected +@pytest.mark.parametrize("node, key", (("Sunny_Day", "if_night"),)) +@pytest.mark.parametrize( + "value, expected", + ( + ("kittens", "Leaf 'Sunny_Day' does not point to a valid target (kittens).",), + ( + "Partly_Cloudy_Day", + "Night target 'Partly_Cloudy_Day' of leaf 'Sunny_Day' also has a night target.", + ), + ( + "lightning_shower", + "Target 'lightning_shower' of leaf 'Sunny_Day' is not a leaf.", + ), + ), +) +def test_check_tree_if_night(modify_tree, expected): + """Test that the various possible decision tree problems related to if_night are identified. + These are separated out because we need to mark the night leaf as unreachable""" + modify_tree["Clear_Night"]["is_unreachable"] = True + result = check_tree(modify_tree) + assert result == expected + + +@pytest.mark.parametrize( + "nodes, expected", + ( + ({"Thunder": 28}, "These leaf categories are used more than once: [28]",), + ( + {"Thunder": 28, "Thunder_Shower_Day": 28}, + "These leaf categories are used more than once: [28]", + ), + ( + {"Thunder": 28, "Heavy_Snow": 26}, + "These leaf categories are used more than once: [26, 28]", + ), + ), +) +def test_check_tree_duplicate_leaves(nodes, expected): + """Test that the various possible leaf duplicates are identified.""" + tree = wxcode_decision_tree() + for node, value in nodes.items(): + tree[node]["leaf"] = value + result = check_tree(tree) + assert result == expected + + def test_check_tree_non_dictionary(): """Check ValueError is raised if non-dictionary is passed to check_tree.""" expected = "Decision tree is not a dictionary"