Skip to content

Commit

Permalink
Cleaning up and refactoring causal_learns has_directed_path
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 committed Oct 3, 2024
1 parent daa868f commit 9be439e
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
9 changes: 5 additions & 4 deletions dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import networkx as nx

from dowhy.gcm.causal_models import ProbabilisticCausalModel
from dowhy.graph import has_directed_path
from dowhy.utils.api import parse_state
from dowhy.utils.graph_operations import daggity_to_dot
from dowhy.utils.plotting import plot
Expand Down Expand Up @@ -449,11 +450,11 @@ def get_all_directed_paths(self, nodes1, nodes2):
def has_directed_path(self, nodes1, nodes2):
"""Checks if there is any directed path between two sets of nodes.
Currently only supports singleton sets.
Returns True if and only if every one of the treatments has at least one direct
path to one of the outcomes. And, every one of the outcomes has a direct path from
at least one of the treatments.
"""
# dpaths = self.get_all_directed_paths(nodes1, nodes2)
# return len(dpaths) > 0
return nx.has_path(self._graph, nodes1[0], nodes2[0])
return has_directed_path(self._graph, nodes1[0], nodes2[0])

def get_adjacency_matrix(self, *args, **kwargs):
"""
Expand Down
2 changes: 0 additions & 2 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,6 @@ def has_directed_path(graph: nx.DiGraph, nodes1, nodes2):
path to one of the outcomes. And, every one of the outcomes has a direct path from
at least one of the treatments.
"""
# dpaths = self.get_all_directed_paths(nodes1, nodes2)
# return len(dpaths) > 0
outcome_node_candidates = set()
action_node_candidates = set()
for node in nodes1:
Expand Down
17 changes: 7 additions & 10 deletions tests/causal_identifiers/test_auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@


class TestAutoIdentification(object):
def test_auto_identify_consistently_checks_for_directed_paths(self):
def test_auto_identify_identifies_no_directed_path(self):
# Test added for issue #1250
graph = build_graph_from_str("digraph{T->Y;A->Y;A->B;}")
identifier = AutoIdentifier(estimand_type=EstimandType.NONPARAMETRIC_ATE)
identified_estimand = identifier.identify_effect(

assert identifier.identify_effect(
graph, action_nodes=["T", "B"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"]
)
identified_estimand_swapped_action_order = identifier.identify_effect(
).no_directed_path
assert identifier.identify_effect(
graph, action_nodes=["B", "T"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"]
)
backdoor_vars = identified_estimand.get_backdoor_variables()
backdoor_vars_swapped_action_order = identified_estimand_swapped_action_order.get_backdoor_variables()

assert len(backdoor_vars) == 0
assert len(backdoor_vars_swapped_action_order) == 0
).no_directed_path
1 change: 1 addition & 0 deletions tests/test_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,5 @@ def test_build_graph_from_str(self):
def test_has_path(self):
assert has_directed_path(self.nx_graph, ["X0"], ["y"])
assert has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0"])
assert not has_directed_path(self.nx_graph, [], ["y"])
assert not has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0", "Z0"])

0 comments on commit 9be439e

Please sign in to comment.