Skip to content

Commit

Permalink
adding default case
Browse files Browse the repository at this point in the history
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
  • Loading branch information
nparent1 committed Dec 30, 2024
1 parent 0261ad0 commit 5f3bc5b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dowhy/causal_identifier/adjustment_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ def get_variables(self):
return self.variables

def get_num_paths_blocked_by_observed_nodes(self):
"""Return the number of paths blocked by the observed nodes (optional)"""
"""Return the number of paths blocked by observed nodes (optional)"""
return self.num_paths_blocked_by_observed_nodes
22 changes: 21 additions & 1 deletion dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from enum import Enum
from typing import Dict, List, Optional, Union
import copy

import networkx as nx
import sympy as sp
Expand All @@ -21,6 +22,8 @@
get_descendants,
get_instruments,
has_directed_path,
get_proper_causal_path_nodes,
get_proper_backdoor_graph
)
from dowhy.utils.api import parse_state

Expand Down Expand Up @@ -884,7 +887,24 @@ def identify_complete_adjustment_set(
observed_nodes: List[str],
covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT
) -> List[AdjustmentSet]:
# TODO: Implement this. Must return a list of AdjustmentSet objects.

graph_pbd = get_proper_backdoor_graph(graph, action_nodes, outcome_nodes)
pcp_nodes = get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)

if covariate_adjustment == CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT:
# In default case, we don't find all exhaustive adjustment sets
adjustment_set = nx.algorithms.find_minimal_d_separator(
graph_pbd,
action_nodes,
outcome_nodes,
# Require the adjustment set to consist only of observed nodes
restricted=((set(graph.nodes) - set(pcp_nodes)) & set(observed_nodes))
)
if adjustment_set is None:
logger.info("No adjustment sets found.")
return []
return [AdjustmentSet(AdjustmentSet.GENERAL, adjustment_set)]

return [AdjustmentSet(AdjustmentSet.GENERAL, [])]


Expand Down
41 changes: 41 additions & 0 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from abc import abstractmethod
from typing import Any, List, Protocol
import copy

import networkx as nx
from networkx.algorithms.dag import has_cycle
Expand Down Expand Up @@ -187,13 +188,53 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes):
return False


def get_ancestors(graph: nx.DiGraph, nodes):
ancestors = set()
for node_name in nodes:
ancestors = ancestors.union(set(nx.ancestors(graph, node_name)))
return ancestors


def get_descendants(graph: nx.DiGraph, nodes):
descendants = set()
for node_name in nodes:
descendants = descendants.union(set(nx.descendants(graph, node_name)))
return descendants


def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1

# 1) Create modified graphs removing inbound and outbound arrows from the action nodes, respectively.
graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes
edges_to_remove = [(u, v) for u, v in graph_post_interv.in_edges(action_nodes)]
graph_post_interv.remove_edges_from(edges_to_remove)
graph_with_action_nodes_as_sinks = copy.deepcopy(graph) # remove outbound arrows from our action nodes
edges_to_remove = [(u, v) for u, v in graph_with_action_nodes_as_sinks.out_edges(action_nodes)]
graph_with_action_nodes_as_sinks.remove_edges_from(edges_to_remove)

# 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the
# action nodes to the outcome nodes.
de_x = get_descendants(graph_post_interv, action_nodes)
an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes)
return (set(de_x) - set(action_nodes)) & an_y


def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
# Process is described in Van Der Zander et al. "Constructing Separators and
# Adjustment Sets in Ancestral Graphs", Section 4.1

# First we can just call get_proper_causal_path_nodes, then
# we remove edges from the action_nodes to the proper causal path nodes
graph_pbd = copy.deepcopy(graph)
graph_pbd.remove_edges_from(
[(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)]
)
return graph_pbd



def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
if dseparation_algo == "default":
if new_graph is None:
Expand Down

0 comments on commit 5f3bc5b

Please sign in to comment.