From 9be439e578a1dba4e81034d3c58a73a52590b2e6 Mon Sep 17 00:00:00 2001 From: Nicholas Parente Date: Wed, 2 Oct 2024 21:36:01 -0400 Subject: [PATCH] Cleaning up and refactoring causal_learns has_directed_path Signed-off-by: Nicholas Parente --- dowhy/causal_graph.py | 9 +++++---- dowhy/graph.py | 2 -- .../causal_identifiers/test_auto_identifier.py | 17 +++++++---------- tests/test_causal_graph.py | 1 + 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/dowhy/causal_graph.py b/dowhy/causal_graph.py index 098be0af95..a7f01b6bf5 100755 --- a/dowhy/causal_graph.py +++ b/dowhy/causal_graph.py @@ -5,6 +5,7 @@ import networkx as nx from dowhy.gcm.causal_models import ProbabilisticCausalModel +from dowhy.graph import has_directed_path from dowhy.utils.api import parse_state from dowhy.utils.graph_operations import daggity_to_dot from dowhy.utils.plotting import plot @@ -449,11 +450,11 @@ def get_all_directed_paths(self, nodes1, nodes2): def has_directed_path(self, nodes1, nodes2): """Checks if there is any directed path between two sets of nodes. - Currently only supports singleton sets. + Returns True if and only if every one of the treatments has at least one direct + path to one of the outcomes. And, every one of the outcomes has a direct path from + at least one of the treatments. """ - # dpaths = self.get_all_directed_paths(nodes1, nodes2) - # return len(dpaths) > 0 - return nx.has_path(self._graph, nodes1[0], nodes2[0]) + return has_directed_path(self._graph, nodes1[0], nodes2[0]) def get_adjacency_matrix(self, *args, **kwargs): """ diff --git a/dowhy/graph.py b/dowhy/graph.py index 3588b2c1da..55551ed0d3 100644 --- a/dowhy/graph.py +++ b/dowhy/graph.py @@ -277,8 +277,6 @@ def has_directed_path(graph: nx.DiGraph, nodes1, nodes2): path to one of the outcomes. And, every one of the outcomes has a direct path from at least one of the treatments. """ - # dpaths = self.get_all_directed_paths(nodes1, nodes2) - # return len(dpaths) > 0 outcome_node_candidates = set() action_node_candidates = set() for node in nodes1: diff --git a/tests/causal_identifiers/test_auto_identifier.py b/tests/causal_identifiers/test_auto_identifier.py index c13203bdb6..51c04daa6e 100644 --- a/tests/causal_identifiers/test_auto_identifier.py +++ b/tests/causal_identifiers/test_auto_identifier.py @@ -4,17 +4,14 @@ class TestAutoIdentification(object): - def test_auto_identify_consistently_checks_for_directed_paths(self): + def test_auto_identify_identifies_no_directed_path(self): + # Test added for issue #1250 graph = build_graph_from_str("digraph{T->Y;A->Y;A->B;}") identifier = AutoIdentifier(estimand_type=EstimandType.NONPARAMETRIC_ATE) - identified_estimand = identifier.identify_effect( + + assert identifier.identify_effect( graph, action_nodes=["T", "B"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"] - ) - identified_estimand_swapped_action_order = identifier.identify_effect( + ).no_directed_path + assert identifier.identify_effect( graph, action_nodes=["B", "T"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"] - ) - backdoor_vars = identified_estimand.get_backdoor_variables() - backdoor_vars_swapped_action_order = identified_estimand_swapped_action_order.get_backdoor_variables() - - assert len(backdoor_vars) == 0 - assert len(backdoor_vars_swapped_action_order) == 0 + ).no_directed_path diff --git a/tests/test_causal_graph.py b/tests/test_causal_graph.py index a9949bd57f..b386f11057 100644 --- a/tests/test_causal_graph.py +++ b/tests/test_causal_graph.py @@ -123,4 +123,5 @@ def test_build_graph_from_str(self): def test_has_path(self): assert has_directed_path(self.nx_graph, ["X0"], ["y"]) assert has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0"]) + assert not has_directed_path(self.nx_graph, [], ["y"]) assert not has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0", "Z0"])