Skip to content

Commit

Permalink
Bug fix for frontdoor identification and a new set of tests (#1093)
Browse files Browse the repository at this point in the history
* fixed frontdoor bug and added tests

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* updated docstring

Signed-off-by: Amit Sharma <amit_sharma@live.com>

* reformatted file

Signed-off-by: Amit Sharma <amit_sharma@live.com>

---------

Signed-off-by: Amit Sharma <amit_sharma@live.com>
  • Loading branch information
amit-sharma authored Dec 3, 2023
1 parent 1d050f0 commit 048117c
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 53 deletions.
14 changes: 13 additions & 1 deletion dowhy/causal_estimators/two_stage_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(
)
self.logger.info("INFO: Using Two Stage Regression Estimator")
# Check if the treatment is one-dimensional
if len(self._target_estimand.treatment_variable) > 1:
error_msg = str(self.__class__) + "cannot handle more than one treatment variable"
raise Exception(error_msg)
modified_target_estimand = copy.deepcopy(self._target_estimand)
modified_target_estimand.identifier_method = "backdoor"
modified_target_estimand.backdoor_variables = self._target_estimand.mediation_first_stage_confounders
Expand Down Expand Up @@ -173,23 +176,32 @@ def fit(
if self._target_estimand.identifier_method == "frontdoor":
self.logger.debug("Front-door variable used:" + ",".join(self._target_estimand.get_frontdoor_variables()))
self._frontdoor_variables_names = self._target_estimand.get_frontdoor_variables()

if self._frontdoor_variables_names:
if len(self._frontdoor_variables_names) > 1:
raise ValueError(
"Only singleton frontdoor variables are supported for estimation using TwoStageRegression."
)
self._frontdoor_variables = data[self._frontdoor_variables_names]
else:
self._frontdoor_variables = None
error_msg = "No front-door variable present. Two stage regression is not applicable"
self.logger.error(error_msg)
raise ValueError(error_msg)
elif self._target_estimand.identifier_method == "mediation":
self.logger.debug("Mediators used:" + ",".join(self._target_estimand.get_mediator_variables()))
self._mediators_names = self._target_estimand.get_mediator_variables()

if self._mediators_names:
if len(self._mediators_names) > 1:
raise ValueError(
"Only singleton mediator variables are supported for estimation using TwoStageRegression."
)
self._mediators = data[self._mediators_names]
else:
self._mediators = None
error_msg = "No mediator variable present. Two stage regression is not applicable"
self.logger.error(error_msg)
raise ValueError(error_msg)
elif self._target_estimand.identifier_method == "iv":
self.logger.debug(
"Instrumental variables used:" + ",".join(self._target_estimand.get_instrumental_variables())
Expand Down
2 changes: 1 addition & 1 deletion dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
import pydot

P_list = pydot.graph_from_dot_data(graph)
self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
self._graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
except Exception as e:
self.logger.error("Error: Pydot cannot be loaded. " + str(e))
raise e
Expand Down
100 changes: 51 additions & 49 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def identify_ate_effect(

### 3. FRONTDOOR IDENTIFICATION
# Now checking if there is a valid frontdoor variable
frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes)
frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes, observed_nodes)
logger.info("Frontdoor variables for treatment and outcome:" + str(frontdoor_variables_names))
if len(frontdoor_variables_names) > 0:
frontdoor_estimand_expr = construct_frontdoor_estimand(
Expand Down Expand Up @@ -777,18 +777,18 @@ def build_backdoor_estimands_dict(


def identify_frontdoor(
graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str], dseparation_algo: str = "default"
graph: nx.DiGraph,
action_nodes: List[str],
outcome_nodes: List[str],
observed_nodes: List[str],
dseparation_algo: str = "default",
):
"""Find a valid frontdoor variable if it exists.
Currently only supports a single variable frontdoor set.
"""
"""Find a valid frontdoor variable set if it exists."""
frontdoor_var = None
frontdoor_paths = None
fdoor_graph = None
if dseparation_algo == "default":
cond1_graph = do_surgery(graph, action_nodes, remove_incoming_edges=True)
bdoor_graph1 = do_surgery(graph, action_nodes, remove_outgoing_edges=True)
elif dseparation_algo == "naive":
frontdoor_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes)
else:
Expand All @@ -797,48 +797,50 @@ def identify_frontdoor(
eligible_variables = (
get_descendants(graph, action_nodes) - set(outcome_nodes) - set(get_descendants(graph, outcome_nodes))
)
# For simplicity, assuming a one-variable frontdoor set
for candidate_var in eligible_variables:
# Cond 1: All directed paths intercepted by candidate_var
cond1 = check_valid_frontdoor_set(
graph,
action_nodes,
outcome_nodes,
parse_state(candidate_var),
frontdoor_paths=frontdoor_paths,
new_graph=cond1_graph,
dseparation_algo=dseparation_algo,
)
logger.debug("Candidate frontdoor set: {0}, is_dseparated: {1}".format(candidate_var, cond1))
if not cond1:
continue
# Cond 2: No confounding between treatment and candidate var
cond2 = check_valid_backdoor_set(
graph,
action_nodes,
parse_state(candidate_var),
set(),
backdoor_paths=None,
new_graph=bdoor_graph1,
dseparation_algo=dseparation_algo,
)
if not cond2:
continue
# Cond 3: treatment blocks all confounding between candidate_var and outcome
bdoor_graph2 = do_surgery(graph, candidate_var, remove_outgoing_edges=True)
cond3 = check_valid_backdoor_set(
graph,
parse_state(candidate_var),
outcome_nodes,
action_nodes,
backdoor_paths=None,
new_graph=bdoor_graph2,
dseparation_algo=dseparation_algo,
)
is_valid_frontdoor = cond1 and cond2 and cond3
if is_valid_frontdoor:
frontdoor_var = candidate_var
break
eligible_variables = eligible_variables.intersection(set(observed_nodes))
set_sizes = range(1, len(eligible_variables) + 1, 1)
for size_candidate_set in set_sizes:
for candidate_set in itertools.combinations(eligible_variables, size_candidate_set):
candidate_set = list(candidate_set)
# Cond 1: All directed paths intercepted by candidate_var
cond1 = check_valid_frontdoor_set(
graph,
action_nodes,
outcome_nodes,
candidate_set,
frontdoor_paths=frontdoor_paths,
new_graph=cond1_graph,
dseparation_algo=dseparation_algo,
)
logger.debug("Candidate frontdoor set: {0}, Cond1: is_dseparated: {1}".format(candidate_set, cond1))
if not cond1:
continue
# Cond 2: No confounding between treatment and candidate var
cond2 = check_valid_backdoor_set(
graph,
action_nodes,
candidate_set,
set(),
backdoor_paths=None,
dseparation_algo=dseparation_algo,
)["is_dseparated"]
if not cond2:
continue
# Cond 3: treatment blocks all confounding between candidate_var and outcome
bdoor_graph2 = do_surgery(graph, candidate_set, remove_outgoing_edges=True)
cond3 = check_valid_backdoor_set(
graph,
candidate_set,
outcome_nodes,
action_nodes,
backdoor_paths=None,
new_graph=bdoor_graph2,
dseparation_algo=dseparation_algo,
)["is_dseparated"]
is_valid_frontdoor = cond1 and cond2 and cond3
if is_valid_frontdoor:
frontdoor_var = candidate_set
break
return parse_state(frontdoor_var)


Expand Down
3 changes: 2 additions & 1 deletion dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def check_valid_backdoor_set(
if new_graph is None:
# Assume that nodes1 is the treatment
new_graph = do_surgery(graph, nodes1, remove_outgoing_edges=True)

dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(nodes3))
elif dseparation_algo == "naive":
# ignores new_graph parameter, always uses self._graph
Expand Down Expand Up @@ -417,7 +418,7 @@ def build_graph_from_str(graph_str: str) -> nx.DiGraph:
import pydot

P_list = pydot.graph_from_dot_data(graph_str)
return nx.drawing.nx_pydot.from_pydot(P_list[0])
return nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
except Exception as e:
_logger.error("Error: Pydot cannot be loaded. " + str(e))
raise e
Expand Down
34 changes: 34 additions & 0 deletions tests/causal_estimators/test_two_stage_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,37 @@ def test_frontdoor_estimator(self):
# Estimate the effect with front-door
estimate = model.estimate_effect(identified_estimand=estimand, method_name="frontdoor.two_stage_regression")
assert estimate.value == pytest.approx(0.45, 0.025)

@mark.parametrize(
    ["Estimator", "num_treatments", "num_frontdoor_variables"],
    [(TwoStageRegressionEstimator, [2, 1], [1, 2])],
)
def test_frontdoor_num_variables_error(self, Estimator, num_treatments, num_frontdoor_variables):
    """Expect the frontdoor test suite to raise for the parametrized variable counts.

    The suite is run with ``identifier_method="frontdoor"`` and the given
    ``num_treatments`` / ``num_frontdoor_variables`` lists; the test passes
    when the call raises.
    """
    # NOTE(review): ``Exception`` already subsumes ``ValueError`` and
    # ``pytest.raises`` is satisfied by the first exception raised, so this
    # presumably only confirms that *some* configuration fails -- confirm that
    # every invalid count is really exercised by the suite.
    expected_errors = (ValueError, Exception)
    suite_arguments = dict(
        num_common_causes=[1, 1],
        num_instruments=[0, 0],
        num_effect_modifiers=[0, 0],
        num_treatments=num_treatments,
        num_frontdoor_variables=num_frontdoor_variables,
        treatment_is_binary=[True],
        outcome_is_binary=[False],
        confidence_intervals=[True],
        test_significance=[False],
        method_params={"num_simulations": 10, "num_null_simulations": 10},
    )
    # The tester is built outside the ``raises`` block so that a failure while
    # constructing it is not mistaken for the expected error.
    tester = TestEstimator(error_tolerance=0, Estimator=Estimator, identifier_method="frontdoor")
    with pytest.raises(expected_errors):
        tester.average_treatment_effect_testsuite(**suite_arguments)
23 changes: 22 additions & 1 deletion tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dowhy.graph import build_graph_from_str

from .example_graphs import TEST_GRAPH_SOLUTIONS
from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS


class IdentificationTestGraphSolution(object):
Expand All @@ -27,6 +27,27 @@ def __init__(
)


class IdentificationTestFrontdoorGraphSolution(object):
    """One frontdoor-identification example: a graph plus its expected answers.

    The treatment and outcome are always the nodes ``"X"`` and ``"Y"``.

    :param graph_str: graph description handed to ``build_graph_from_str``.
    :param observed_variables: node names stored as ``observed_nodes``.
    :param valid_frontdoor_sets: frontdoor sets considered correct for the graph.
    :param invalid_frontdoor_sets: sets that must not be reported as frontdoor sets.
    """

    def __init__(
        self,
        graph_str,
        observed_variables,
        valid_frontdoor_sets,
        invalid_frontdoor_sets,
    ):
        self.graph = build_graph_from_str(graph_str)
        # Every example graph uses the same treatment / outcome names.
        self.action_nodes, self.outcome_nodes = ["X"], ["Y"]
        self.observed_nodes = observed_variables
        self.valid_frontdoor_sets, self.invalid_frontdoor_sets = (
            valid_frontdoor_sets,
            invalid_frontdoor_sets,
        )


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
def example_graph_solution(request):
    """Yield one ``IdentificationTestGraphSolution`` per key of ``TEST_GRAPH_SOLUTIONS``."""
    solution_kwargs = TEST_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestGraphSolution(**solution_kwargs)


@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
def example_frontdoor_graph_solution(request):
    """Yield one ``IdentificationTestFrontdoorGraphSolution`` per key of ``TEST_FRONTDOOR_GRAPH_SOLUTIONS``."""
    solution_kwargs = TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestFrontdoorGraphSolution(**solution_kwargs)
36 changes: 36 additions & 0 deletions tests/causal_identifiers/example_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,39 @@
direct_maximal_adjustment_sets=[{"W1", "M", "W2"}],
),
}


# Example graphs for frontdoor identification, keyed by scenario name.
# Each value holds the keyword arguments for one example: the graph in DOT
# syntax, the observed node names, the frontdoor sets accepted as correct,
# and sets that must never be reported.
TEST_FRONTDOOR_GRAPH_SOLUTIONS = {
    # Single mediator M1 between X and Y; Z is not in the observed list.
    "valid_singleton": {
        "graph_str": "digraph {X; Y; Z; M1; X->M1; Z->X; Z->Y; M1->Y;}",
        "observed_variables": ["X", "Y", "M1"],
        "valid_frontdoor_sets": [{"M1"}],
        "invalid_frontdoor_sets": [{"Z"}],
    },
    # Two parallel mediators, both observed: only the pair is accepted.
    "valid_doubleton": {
        "graph_str": "digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
        "observed_variables": ["X", "Y", "M1", "M2"],
        "valid_frontdoor_sets": [{"M1", "M2"}],
        "invalid_frontdoor_sets": [{"Z"}, {"M1"}, {"M2"}],
    },
    # E (not observed) points into both X and R, and R points into M.
    "no_frontdoor": {
        "graph_str": "digraph{E;X;R;M;Y; E->X;E->R;X->M;R->M;M->Y}",
        "observed_variables": ["X", "R", "M", "Y"],
        "valid_frontdoor_sets": [],
        "invalid_frontdoor_sets": [{"M"}, {"M", "E"}, {"M", "R"}],
    },
    # D (not observed) points into both B and Y.
    "no_frontdoor_simple": {
        "graph_str": "digraph{X;Y;D;B;X->B;D->B;B->Y;D->Y}",
        "observed_variables": ["X", "B", "Y"],
        "valid_frontdoor_sets": [],
        "invalid_frontdoor_sets": [{"B"}, {"D"}],
    },
    # Same graph as "valid_doubleton", but M2 is missing from the observed list.
    "no_frontdoor_in_obs": {
        "graph_str": "digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
        "observed_variables": ["X", "Y", "M1", "Z"],
        "valid_frontdoor_sets": [],
        "invalid_frontdoor_sets": [{"Z"}, {"M1"}, {"M2"}, {"M1", "M2"}],
    },
}
50 changes: 50 additions & 0 deletions tests/causal_identifiers/test_frontdoor_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import pandas as pd
import pytest

from dowhy import CausalModel
from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment
from dowhy.causal_identifier.auto_identifier import identify_frontdoor
from dowhy.causal_identifier.identify_effect import EstimandType

from .base import IdentificationTestFrontdoorGraphSolution, example_frontdoor_graph_solution


class TestFrontdoorIdentification(object):
    """Check frontdoor identification against the example graph solutions."""

    @staticmethod
    def _assert_matches_solution(identified, solution):
        """Assert that ``identified`` agrees with the expected sets of ``solution``.

        Passes when nothing was identified and no valid set is expected, or when
        the identified variables form one of the valid sets and none of the
        invalid ones.
        """
        if len(identified) == 0 and len(solution.valid_frontdoor_sets) == 0:
            # No frontdoor set exists and that's expected.
            return
        identified_set = set(identified)
        assert identified_set in solution.valid_frontdoor_sets
        assert identified_set not in solution.invalid_frontdoor_sets

    def test_identify_frontdoor_functional_api(
        self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
    ):
        """Call ``identify_frontdoor`` directly on the example graph."""
        solution = example_frontdoor_graph_solution
        identified = identify_frontdoor(
            solution.graph,
            action_nodes=["X"],
            outcome_nodes=["Y"],
            observed_nodes=solution.observed_nodes,
        )
        self._assert_matches_solution(identified, solution)

    def test_identify_frontdoor_causal_model(
        self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
    ):
        """Run identification through ``CausalModel`` on random data for the observed nodes."""
        solution = example_frontdoor_graph_solution
        columns = solution.observed_nodes
        # Building the causal model: the data only needs one column per
        # observed node, so random values are used.
        num_samples = 10
        random_values = np.random.random((num_samples, len(columns)))
        data = pd.DataFrame(random_values, columns=columns)
        model = CausalModel(data=data, treatment="X", outcome="Y", graph=solution.graph)
        identified = model.identify_effect().frontdoor_variables
        self._assert_matches_solution(identified, solution)

0 comments on commit 048117c

Please sign in to comment.