diff --git a/dowhy/causal_estimators/two_stage_regression_estimator.py b/dowhy/causal_estimators/two_stage_regression_estimator.py
index 4c2010ae89..ebb01558ed 100644
--- a/dowhy/causal_estimators/two_stage_regression_estimator.py
+++ b/dowhy/causal_estimators/two_stage_regression_estimator.py
@@ -84,6 +84,9 @@ def __init__(
         )
         self.logger.info("INFO: Using Two Stage Regression Estimator")
         # Check if the treatment is one-dimensional
+        if len(self._target_estimand.treatment_variable) > 1:
+            error_msg = str(self.__class__) + " cannot handle more than one treatment variable"
+            raise ValueError(error_msg)
         modified_target_estimand = copy.deepcopy(self._target_estimand)
         modified_target_estimand.identifier_method = "backdoor"
         modified_target_estimand.backdoor_variables = self._target_estimand.mediation_first_stage_confounders
@@ -173,23 +176,32 @@ def fit(
         if self._target_estimand.identifier_method == "frontdoor":
             self.logger.debug("Front-door variable used:" + ",".join(self._target_estimand.get_frontdoor_variables()))
             self._frontdoor_variables_names = self._target_estimand.get_frontdoor_variables()
-
             if self._frontdoor_variables_names:
+                if len(self._frontdoor_variables_names) > 1:
+                    raise ValueError(
+                        "Only singleton frontdoor variables are supported for estimation using TwoStageRegression."
+                    )
                 self._frontdoor_variables = data[self._frontdoor_variables_names]
             else:
                 self._frontdoor_variables = None
                 error_msg = "No front-door variable present. Two stage regression is not applicable"
                 self.logger.error(error_msg)
+                raise ValueError(error_msg)
         elif self._target_estimand.identifier_method == "mediation":
             self.logger.debug("Mediators used:" + ",".join(self._target_estimand.get_mediator_variables()))
             self._mediators_names = self._target_estimand.get_mediator_variables()
 
             if self._mediators_names:
+                if len(self._mediators_names) > 1:
+                    raise ValueError(
+                        "Only singleton mediator variables are supported for estimation using TwoStageRegression."
+                    )
                 self._mediators = data[self._mediators_names]
             else:
                 self._mediators = None
                 error_msg = "No mediator variable present. Two stage regression is not applicable"
                 self.logger.error(error_msg)
+                raise ValueError(error_msg)
         elif self._target_estimand.identifier_method == "iv":
             self.logger.debug(
                 "Instrumental variables used:" + ",".join(self._target_estimand.get_instrumental_variables())
diff --git a/dowhy/causal_graph.py b/dowhy/causal_graph.py
index 64e780baae..582e924a53 100755
--- a/dowhy/causal_graph.py
+++ b/dowhy/causal_graph.py
@@ -85,7 +85,7 @@ def __init__(
                 import pydot
 
                 P_list = pydot.graph_from_dot_data(graph)
-                self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
+                self._graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
             except Exception as e:
                 self.logger.error("Error: Pydot cannot be loaded. " + str(e))
                 raise e
diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py
index ca3fb259a3..4fbcff3c4d 100644
--- a/dowhy/causal_identifier/auto_identifier.py
+++ b/dowhy/causal_identifier/auto_identifier.py
@@ -271,7 +271,7 @@ def identify_ate_effect(
 
     ### 3. FRONTDOOR IDENTIFICATION
     # Now checking if there is a valid frontdoor variable
-    frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes)
+    frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes, observed_nodes)
     logger.info("Frontdoor variables for treatment and outcome:" + str(frontdoor_variables_names))
     if len(frontdoor_variables_names) > 0:
         frontdoor_estimand_expr = construct_frontdoor_estimand(
@@ -777,18 +777,31 @@ def build_backdoor_estimands_dict(
 
 
 def identify_frontdoor(
-    graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str], dseparation_algo: str = "default"
+    graph: nx.DiGraph,
+    action_nodes: List[str],
+    outcome_nodes: List[str],
+    observed_nodes: List[str],
+    dseparation_algo: str = "default",
 ):
-    """Find a valid frontdoor variable if it exists.
-
-    Currently only supports a single variable frontdoor set.
-    """
+    """Find a valid frontdoor variable set if it exists.
+
+    Candidate sets are built from the observed descendants of the action nodes,
+    excluding the outcome nodes and their descendants, and are tried in order of
+    increasing size. The first set that passes all three frontdoor checks is
+    returned, so the result is a smallest valid set.
+
+    :param graph: Causal graph in which to search.
+    :param action_nodes: Names of the action (treatment) nodes.
+    :param outcome_nodes: Names of the outcome nodes.
+    :param observed_nodes: Names of the observed nodes; candidates are limited to these.
+    :param dseparation_algo: Algorithm used for the checks; "default" and "naive" are handled here.
+    :returns: parse_state applied to the selected set, or to None when no valid set is found.
+    """
     frontdoor_var = None
     frontdoor_paths = None
     fdoor_graph = None
     if dseparation_algo == "default":
         cond1_graph = do_surgery(graph, action_nodes, remove_incoming_edges=True)
-        bdoor_graph1 = do_surgery(graph, action_nodes, remove_outgoing_edges=True)
     elif dseparation_algo == "naive":
         frontdoor_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes)
     else:
@@ -797,48 +810,54 @@ def identify_frontdoor(
     eligible_variables = (
         get_descendants(graph, action_nodes) - set(outcome_nodes) - set(get_descendants(graph, outcome_nodes))
     )
-    # For simplicity, assuming a one-variable frontdoor set
-    for candidate_var in eligible_variables:
-        # Cond 1: All directed paths intercepted by candidate_var
-        cond1 = check_valid_frontdoor_set(
-            graph,
-            action_nodes,
-            outcome_nodes,
-            parse_state(candidate_var),
-            frontdoor_paths=frontdoor_paths,
-            new_graph=cond1_graph,
-            dseparation_algo=dseparation_algo,
-        )
-        logger.debug("Candidate frontdoor set: {0}, is_dseparated: {1}".format(candidate_var, cond1))
-        if not cond1:
-            continue
-        # Cond 2: No confounding between treatment and candidate var
-        cond2 = check_valid_backdoor_set(
-            graph,
-            action_nodes,
-            parse_state(candidate_var),
-            set(),
-            backdoor_paths=None,
-            new_graph=bdoor_graph1,
-            dseparation_algo=dseparation_algo,
-        )
-        if not cond2:
-            continue
-        # Cond 3: treatment blocks all confounding between candidate_var and outcome
-        bdoor_graph2 = do_surgery(graph, candidate_var, remove_outgoing_edges=True)
-        cond3 = check_valid_backdoor_set(
-            graph,
-            parse_state(candidate_var),
-            outcome_nodes,
-            action_nodes,
-            backdoor_paths=None,
-            new_graph=bdoor_graph2,
-            dseparation_algo=dseparation_algo,
-        )
-        is_valid_frontdoor = cond1 and cond2 and cond3
-        if is_valid_frontdoor:
-            frontdoor_var = candidate_var
-            break
+    # Only observed variables may appear in a frontdoor set.
+    eligible_variables = eligible_variables.intersection(set(observed_nodes))
+    set_sizes = range(1, len(eligible_variables) + 1)
+    # Try candidate sets from smallest to largest so the first hit is a smallest valid set.
+    for size_candidate_set in set_sizes:
+        for candidate_set in itertools.combinations(eligible_variables, size_candidate_set):
+            candidate_set = list(candidate_set)
+            # Cond 1: All directed paths intercepted by candidate_set
+            cond1 = check_valid_frontdoor_set(
+                graph,
+                action_nodes,
+                outcome_nodes,
+                candidate_set,
+                frontdoor_paths=frontdoor_paths,
+                new_graph=cond1_graph,
+                dseparation_algo=dseparation_algo,
+            )
+            logger.debug("Candidate frontdoor set: {0}, Cond1: is_dseparated: {1}".format(candidate_set, cond1))
+            if not cond1:
+                continue
+            # Cond 2: No confounding between treatment and candidate_set
+            cond2 = check_valid_backdoor_set(
+                graph,
+                action_nodes,
+                candidate_set,
+                set(),
+                backdoor_paths=None,
+                dseparation_algo=dseparation_algo,
+            )["is_dseparated"]
+            if not cond2:
+                continue
+            # Cond 3: treatment blocks all confounding between candidate_set and outcome
+            bdoor_graph2 = do_surgery(graph, candidate_set, remove_outgoing_edges=True)
+            cond3 = check_valid_backdoor_set(
+                graph,
+                candidate_set,
+                outcome_nodes,
+                action_nodes,
+                backdoor_paths=None,
+                new_graph=bdoor_graph2,
+                dseparation_algo=dseparation_algo,
+            )["is_dseparated"]
+            is_valid_frontdoor = cond1 and cond2 and cond3
+            if is_valid_frontdoor:
+                # Return immediately: a bare `break` would only leave the inner loop, and the
+                # search would go on to larger sizes and could overwrite this result.
+                frontdoor_var = candidate_set
+                return parse_state(frontdoor_var)
 
     return parse_state(frontdoor_var)
 
diff --git a/dowhy/graph.py b/dowhy/graph.py
index f739ffc069..805c940ca7 100644
--- a/dowhy/graph.py
+++ b/dowhy/graph.py
@@ -94,6 +94,7 @@ def check_valid_backdoor_set(
         if new_graph is None:
             # Assume that nodes1 is the treatment
             new_graph = do_surgery(graph, nodes1, remove_outgoing_edges=True)
+
         dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(nodes3))
     elif dseparation_algo == "naive":
         # ignores new_graph parameter, always uses self._graph
@@ -417,7 +418,7 @@ def build_graph_from_str(graph_str: str) -> nx.DiGraph:
             import pydot
 
             P_list = pydot.graph_from_dot_data(graph_str)
-            return nx.drawing.nx_pydot.from_pydot(P_list[0])
+            return nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
         except Exception as e:
             _logger.error("Error: Pydot cannot be loaded. " + str(e))
             raise e
diff --git a/tests/causal_estimators/test_two_stage_regression_estimator.py b/tests/causal_estimators/test_two_stage_regression_estimator.py
index 5be515d0c5..e4849b4ca4 100644
--- a/tests/causal_estimators/test_two_stage_regression_estimator.py
+++ b/tests/causal_estimators/test_two_stage_regression_estimator.py
@@ -143,3 +143,37 @@ def test_frontdoor_estimator(self):
         # Estimate the effect with front-door
         estimate = model.estimate_effect(identified_estimand=estimand, method_name="frontdoor.two_stage_regression")
         assert estimate.value == pytest.approx(0.45, 0.025)
+
+    @mark.parametrize(
+        [
+            "Estimator",
+            "num_treatments",
+            "num_frontdoor_variables",
+        ],
+        [
+            (
+                TwoStageRegressionEstimator,
+                [2, 1],
+                [1, 2],
+            )
+        ],
+    )
+    def test_frontdoor_num_variables_error(self, Estimator, num_treatments, num_frontdoor_variables):
+        estimator_tester = TestEstimator(error_tolerance=0, Estimator=Estimator, identifier_method="frontdoor")
+        with pytest.raises((ValueError, Exception)):
+            estimator_tester.average_treatment_effect_testsuite(
+                num_common_causes=[1, 1],
+                num_instruments=[0, 0],
+                num_effect_modifiers=[0, 0],
+                num_treatments=num_treatments,
+                num_frontdoor_variables=num_frontdoor_variables,
+                treatment_is_binary=[True],
+                outcome_is_binary=[False],
+                confidence_intervals=[
+                    True,
+                ],
+                test_significance=[
+                    False,
+                ],
+                method_params={"num_simulations": 10, "num_null_simulations": 10},
+            )
diff --git a/tests/causal_identifiers/base.py b/tests/causal_identifiers/base.py
index 3cdbb110ca..d688182190 100644
--- a/tests/causal_identifiers/base.py
+++ b/tests/causal_identifiers/base.py
@@ -2,7 +2,7 @@
 
 from dowhy.graph import build_graph_from_str
 
-from .example_graphs import TEST_GRAPH_SOLUTIONS
+from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS
 
 
 class IdentificationTestGraphSolution(object):
@@ -27,6 +27,27 @@ def __init__(
         )
 
 
+class IdentificationTestFrontdoorGraphSolution(object):
+    def __init__(
+        self,
+        graph_str,
+        observed_variables,
+        valid_frontdoor_sets,
+        invalid_frontdoor_sets,
+    ):
+        self.graph = build_graph_from_str(graph_str)
+        self.action_nodes = ["X"]
+        self.outcome_nodes = ["Y"]
+        self.observed_nodes = observed_variables
+        self.valid_frontdoor_sets = valid_frontdoor_sets
+        self.invalid_frontdoor_sets = invalid_frontdoor_sets
+
+
 @pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
 def example_graph_solution(request):
     return IdentificationTestGraphSolution(**TEST_GRAPH_SOLUTIONS[request.param])
+
+
+@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
+def example_frontdoor_graph_solution(request):
+    return IdentificationTestFrontdoorGraphSolution(**TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param])
diff --git a/tests/causal_identifiers/example_graphs.py b/tests/causal_identifiers/example_graphs.py
index d46c098830..d678dfd11a 100644
--- a/tests/causal_identifiers/example_graphs.py
+++ b/tests/causal_identifiers/example_graphs.py
@@ -387,3 +387,39 @@
         direct_maximal_adjustment_sets=[{"W1", "M", "W2"}],
     ),
 }
+
+
+TEST_FRONTDOOR_GRAPH_SOLUTIONS = {
+    "valid_singleton": dict(
+        graph_str="digraph {X; Y; Z; M1; X->M1; Z->X; Z->Y; M1->Y;}",
+        observed_variables=["X", "Y", "M1"],
+        valid_frontdoor_sets=[{"M1"}],
+        invalid_frontdoor_sets=[{"Z"}],
+    ),
+    "valid_doubleton": dict(
+        graph_str="digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
+        observed_variables=["X", "Y", "M1", "M2"],
+        valid_frontdoor_sets=[
+            {"M1", "M2"},
+        ],
+        invalid_frontdoor_sets=[{"Z"}, {"M1"}, {"M2"}],
+    ),
+    "no_frontdoor": dict(
+        graph_str="digraph{E;X;R;M;Y; E->X;E->R;X->M;R->M;M->Y}",
+        observed_variables=["X", "R", "M", "Y"],
+        valid_frontdoor_sets=[],
+        invalid_frontdoor_sets=[{"M"}, {"M", "E"}, {"M", "R"}],
+    ),
+    "no_frontdoor_simple": dict(
+        graph_str="digraph{X;Y;D;B;X->B;D->B;B->Y;D->Y}",
+        observed_variables=["X", "B", "Y"],
+        valid_frontdoor_sets=[],
+        invalid_frontdoor_sets=[{"B"}, {"D"}],
+    ),
+    "no_frontdoor_in_obs": dict(
+        graph_str="digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
+        observed_variables=["X", "Y", "M1", "Z"],
+        valid_frontdoor_sets=[],
+        invalid_frontdoor_sets=[{"Z"}, {"M1"}, {"M2"}, {"M1", "M2"}],
+    ),
+}
diff --git a/tests/causal_identifiers/test_frontdoor_identifier.py b/tests/causal_identifiers/test_frontdoor_identifier.py
new file mode 100644
index 0000000000..a6409b3aa8
--- /dev/null
+++ b/tests/causal_identifiers/test_frontdoor_identifier.py
@@ -0,0 +1,50 @@
+import numpy as np
+import pandas as pd
+import pytest
+
+from dowhy import CausalModel
+from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment
+from dowhy.causal_identifier.auto_identifier import identify_frontdoor
+from dowhy.causal_identifier.identify_effect import EstimandType
+
+from .base import IdentificationTestFrontdoorGraphSolution, example_frontdoor_graph_solution
+
+
+class TestFrontdoorIdentification(object):
+    def test_identify_frontdoor_functional_api(
+        self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
+    ):
+        graph = example_frontdoor_graph_solution.graph
+        expected_sets = example_frontdoor_graph_solution.valid_frontdoor_sets
+        invalid_sets = example_frontdoor_graph_solution.invalid_frontdoor_sets
+        frontdoor_set = identify_frontdoor(
+            graph,
+            observed_nodes=example_frontdoor_graph_solution.observed_nodes,
+            action_nodes=example_frontdoor_graph_solution.action_nodes,
+            outcome_nodes=example_frontdoor_graph_solution.outcome_nodes,
+        )
+
+        assert (
+            (len(frontdoor_set) == 0) and (len(expected_sets) == 0)
+        ) or (  # No adjustments exist and that's expected.
+            set(frontdoor_set) in expected_sets and set(frontdoor_set) not in invalid_sets
+        )
+
+    def test_identify_frontdoor_causal_model(
+        self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
+    ):
+        graph = example_frontdoor_graph_solution.graph
+        expected_sets = example_frontdoor_graph_solution.valid_frontdoor_sets
+        invalid_sets = example_frontdoor_graph_solution.invalid_frontdoor_sets
+        observed_nodes = example_frontdoor_graph_solution.observed_nodes
+        # Building the causal model
+        num_samples = 10
+        df = pd.DataFrame(np.random.random((num_samples, len(observed_nodes))), columns=observed_nodes)
+        model = CausalModel(data=df, treatment="X", outcome="Y", graph=graph)
+        estimand = model.identify_effect()
+        frontdoor_set = estimand.frontdoor_variables
+        assert (
+            (len(frontdoor_set) == 0) and (len(expected_sets) == 0)
+        ) or (  # No adjustments exist and that's expected.
+            (set(frontdoor_set) in expected_sets) and (set(frontdoor_set) not in invalid_sets)
+        )