From 5f3bc5bb7662dddc11997f53983cb41054d04539 Mon Sep 17 00:00:00 2001 From: Nicholas Parente Date: Mon, 30 Dec 2024 11:23:39 -0500 Subject: [PATCH] adding default case Signed-off-by: Nicholas Parente --- dowhy/causal_identifier/adjustment_set.py | 2 +- dowhy/causal_identifier/auto_identifier.py | 22 +++++++++++- dowhy/graph.py | 41 ++++++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/dowhy/causal_identifier/adjustment_set.py b/dowhy/causal_identifier/adjustment_set.py index 13c728598..48a5f7877 100644 --- a/dowhy/causal_identifier/adjustment_set.py +++ b/dowhy/causal_identifier/adjustment_set.py @@ -23,5 +23,5 @@ def get_variables(self): return self.variables def get_num_paths_blocked_by_observed_nodes(self): - """Return the number of paths blocked by the observed nodes (optional)""" + """Return the number of paths blocked by observed nodes (optional)""" return self.num_paths_blocked_by_observed_nodes diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py index 74c653f25..7d9622fdc 100644 --- a/dowhy/causal_identifier/auto_identifier.py +++ b/dowhy/causal_identifier/auto_identifier.py @@ -2,6 +2,7 @@ import logging from enum import Enum from typing import Dict, List, Optional, Union +import copy import networkx as nx import sympy as sp @@ -21,6 +22,8 @@ get_descendants, get_instruments, has_directed_path, + get_proper_causal_path_nodes, + get_proper_backdoor_graph ) from dowhy.utils.api import parse_state @@ -884,7 +887,24 @@ def identify_complete_adjustment_set( observed_nodes: List[str], covariate_adjustment: CovariateAdjustment = CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT ) -> List[AdjustmentSet]: - # TODO: Implement this. Must return a list of AdjustmentSet objects. + + graph_pbd = get_proper_backdoor_graph(graph, action_nodes, outcome_nodes) + pcp_nodes = get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes) + + if covariate_adjustment == CovariateAdjustment.COVARIATE_ADJUSTMENT_DEFAULT: + # In default case, we don't find all exhaustive adjustment sets + adjustment_set = nx.algorithms.find_minimal_d_separator( + graph_pbd, + action_nodes, + outcome_nodes, + # Require the adjustment set to consist only of observed nodes + restricted=((set(graph.nodes) - set(pcp_nodes)) & set(observed_nodes)) + ) + if adjustment_set is None: + logger.info("No adjustment sets found.") + return [] + return [AdjustmentSet(AdjustmentSet.GENERAL, adjustment_set)] + return [AdjustmentSet(AdjustmentSet.GENERAL, [])] diff --git a/dowhy/graph.py b/dowhy/graph.py index f2258927a..84f274ea2 100644 --- a/dowhy/graph.py +++ b/dowhy/graph.py @@ -5,6 +5,7 @@ import re from abc import abstractmethod from typing import Any, List, Protocol +import copy import networkx as nx from networkx.algorithms.dag import has_cycle @@ -187,6 +188,13 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes): return False +def get_ancestors(graph: nx.DiGraph, nodes): + ancestors = set() + for node_name in nodes: + ancestors = ancestors.union(set(nx.ancestors(graph, node_name))) + return ancestors + + def get_descendants(graph: nx.DiGraph, nodes): descendants = set() for node_name in nodes: @@ -194,6 +202,39 @@ def get_descendants(graph: nx.DiGraph, nodes): return descendants +def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes): + # Process is described in Van Der Zander et al. "Constructing Separators and + # Adjustment Sets in Ancestral Graphs", Section 4.1 + + # 1) Create modified graphs removing inbound and outbound arrows from the action nodes, respectively. + graph_post_interv = copy.deepcopy(graph) # remove incoming arrows to our action nodes + edges_to_remove = [(u, v) for u, v in graph_post_interv.in_edges(action_nodes)] + graph_post_interv.remove_edges_from(edges_to_remove) + graph_with_action_nodes_as_sinks = copy.deepcopy(graph) # remove outbound arrows from our action nodes + edges_to_remove = [(u, v) for u, v in graph_with_action_nodes_as_sinks.out_edges(action_nodes)] + graph_with_action_nodes_as_sinks.remove_edges_from(edges_to_remove) + + # 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the + # action nodes to the outcome nodes. + de_x = get_descendants(graph_post_interv, action_nodes) + an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes) + return (set(de_x) - set(action_nodes)) & an_y + + +def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes): + # Process is described in Van Der Zander et al. "Constructing Separators and + # Adjustment Sets in Ancestral Graphs", Section 4.1 + + # First we can just call get_proper_causal_path_nodes, then + # we remove edges from the action_nodes to the proper causal path nodes + graph_pbd = copy.deepcopy(graph) + graph_pbd.remove_edges_from( + [(u, v) for u in action_nodes for v in get_proper_causal_path_nodes(graph, action_nodes, outcome_nodes)] + ) + return graph_pbd + + + def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"): if dseparation_algo == "default": if new_graph is None: