Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for frontdoor identification and a new set of tests #1093

Merged
merged 3 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion dowhy/causal_estimators/two_stage_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(
)
self.logger.info("INFO: Using Two Stage Regression Estimator")
# Check if the treatment is one-dimensional
if len(self._target_estimand.treatment_variable) > 1:
error_msg = str(self.__class__) + "cannot handle more than one treatment variable"
raise Exception(error_msg)
modified_target_estimand = copy.deepcopy(self._target_estimand)
modified_target_estimand.identifier_method = "backdoor"
modified_target_estimand.backdoor_variables = self._target_estimand.mediation_first_stage_confounders
Expand Down Expand Up @@ -173,23 +176,32 @@ def fit(
if self._target_estimand.identifier_method == "frontdoor":
self.logger.debug("Front-door variable used:" + ",".join(self._target_estimand.get_frontdoor_variables()))
self._frontdoor_variables_names = self._target_estimand.get_frontdoor_variables()

if self._frontdoor_variables_names:
if len(self._frontdoor_variables_names) > 1:
raise ValueError(
"Only singleton frontdoor variables are supported for estimation using TwoStageRegression."
)
self._frontdoor_variables = data[self._frontdoor_variables_names]
else:
self._frontdoor_variables = None
error_msg = "No front-door variable present. Two stage regression is not applicable"
self.logger.error(error_msg)
raise ValueError(error_msg)
elif self._target_estimand.identifier_method == "mediation":
self.logger.debug("Mediators used:" + ",".join(self._target_estimand.get_mediator_variables()))
self._mediators_names = self._target_estimand.get_mediator_variables()

if self._mediators_names:
if len(self._mediators_names) > 1:
raise ValueError(
"Only singleton mediator variables are supported for estimation using TwoStageRegression."
)
self._mediators = data[self._mediators_names]
else:
self._mediators = None
error_msg = "No mediator variable present. Two stage regression is not applicable"
self.logger.error(error_msg)
raise ValueError(error_msg)
elif self._target_estimand.identifier_method == "iv":
self.logger.debug(
"Instrumental variables used:" + ",".join(self._target_estimand.get_instrumental_variables())
Expand Down
2 changes: 1 addition & 1 deletion dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
import pydot

P_list = pydot.graph_from_dot_data(graph)
self._graph = nx.drawing.nx_pydot.from_pydot(P_list[0])
self._graph = nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
except Exception as e:
self.logger.error("Error: Pydot cannot be loaded. " + str(e))
raise e
Expand Down
100 changes: 51 additions & 49 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def identify_ate_effect(

### 3. FRONTDOOR IDENTIFICATION
# Now checking if there is a valid frontdoor variable
frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes)
frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes, observed_nodes)
logger.info("Frontdoor variables for treatment and outcome:" + str(frontdoor_variables_names))
if len(frontdoor_variables_names) > 0:
frontdoor_estimand_expr = construct_frontdoor_estimand(
Expand Down Expand Up @@ -777,18 +777,18 @@ def build_backdoor_estimands_dict(


def identify_frontdoor(
graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str], dseparation_algo: str = "default"
graph: nx.DiGraph,
action_nodes: List[str],
outcome_nodes: List[str],
observed_nodes: List[str],
dseparation_algo: str = "default",
):
"""Find a valid frontdoor variable if it exists.

Currently only supports a single variable frontdoor set.
"""
"""Find a valid frontdoor variable set if it exists."""
frontdoor_var = None
frontdoor_paths = None
fdoor_graph = None
if dseparation_algo == "default":
cond1_graph = do_surgery(graph, action_nodes, remove_incoming_edges=True)
bdoor_graph1 = do_surgery(graph, action_nodes, remove_outgoing_edges=True)
elif dseparation_algo == "naive":
frontdoor_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes)
else:
Expand All @@ -797,48 +797,50 @@ def identify_frontdoor(
eligible_variables = (
get_descendants(graph, action_nodes) - set(outcome_nodes) - set(get_descendants(graph, outcome_nodes))
)
# For simplicity, assuming a one-variable frontdoor set
for candidate_var in eligible_variables:
# Cond 1: All directed paths intercepted by candidate_var
cond1 = check_valid_frontdoor_set(
graph,
action_nodes,
outcome_nodes,
parse_state(candidate_var),
frontdoor_paths=frontdoor_paths,
new_graph=cond1_graph,
dseparation_algo=dseparation_algo,
)
logger.debug("Candidate frontdoor set: {0}, is_dseparated: {1}".format(candidate_var, cond1))
if not cond1:
continue
# Cond 2: No confounding between treatment and candidate var
cond2 = check_valid_backdoor_set(
graph,
action_nodes,
parse_state(candidate_var),
set(),
backdoor_paths=None,
new_graph=bdoor_graph1,
dseparation_algo=dseparation_algo,
)
if not cond2:
continue
# Cond 3: treatment blocks all confounding between candidate_var and outcome
bdoor_graph2 = do_surgery(graph, candidate_var, remove_outgoing_edges=True)
cond3 = check_valid_backdoor_set(
graph,
parse_state(candidate_var),
outcome_nodes,
action_nodes,
backdoor_paths=None,
new_graph=bdoor_graph2,
dseparation_algo=dseparation_algo,
)
is_valid_frontdoor = cond1 and cond2 and cond3
if is_valid_frontdoor:
frontdoor_var = candidate_var
break
eligible_variables = eligible_variables.intersection(set(observed_nodes))
set_sizes = range(1, len(eligible_variables) + 1, 1)
for size_candidate_set in set_sizes:
for candidate_set in itertools.combinations(eligible_variables, size_candidate_set):
candidate_set = list(candidate_set)
# Cond 1: All directed paths intercepted by candidate_var
cond1 = check_valid_frontdoor_set(
graph,
action_nodes,
outcome_nodes,
candidate_set,
frontdoor_paths=frontdoor_paths,
new_graph=cond1_graph,
dseparation_algo=dseparation_algo,
)
logger.debug("Candidate frontdoor set: {0}, Cond1: is_dseparated: {1}".format(candidate_set, cond1))
if not cond1:
continue
# Cond 2: No confounding between treatment and candidate var
cond2 = check_valid_backdoor_set(
graph,
action_nodes,
candidate_set,
set(),
backdoor_paths=None,
dseparation_algo=dseparation_algo,
)["is_dseparated"]
if not cond2:
continue
# Cond 3: treatment blocks all confounding between candidate_var and outcome
bdoor_graph2 = do_surgery(graph, candidate_set, remove_outgoing_edges=True)
cond3 = check_valid_backdoor_set(
graph,
candidate_set,
outcome_nodes,
action_nodes,
backdoor_paths=None,
new_graph=bdoor_graph2,
dseparation_algo=dseparation_algo,
)["is_dseparated"]
is_valid_frontdoor = cond1 and cond2 and cond3
if is_valid_frontdoor:
frontdoor_var = candidate_set
break
return parse_state(frontdoor_var)


Expand Down
3 changes: 2 additions & 1 deletion dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def check_valid_backdoor_set(
if new_graph is None:
# Assume that nodes1 is the treatment
new_graph = do_surgery(graph, nodes1, remove_outgoing_edges=True)

dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(nodes3))
elif dseparation_algo == "naive":
# ignores new_graph parameter, always uses self._graph
Expand Down Expand Up @@ -417,7 +418,7 @@ def build_graph_from_str(graph_str: str) -> nx.DiGraph:
import pydot

P_list = pydot.graph_from_dot_data(graph_str)
return nx.drawing.nx_pydot.from_pydot(P_list[0])
return nx.DiGraph(nx.drawing.nx_pydot.from_pydot(P_list[0]))
except Exception as e:
_logger.error("Error: Pydot cannot be loaded. " + str(e))
raise e
Expand Down
34 changes: 34 additions & 0 deletions tests/causal_estimators/test_two_stage_regression_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,37 @@ def test_frontdoor_estimator(self):
# Estimate the effect with front-door
estimate = model.estimate_effect(identified_estimand=estimand, method_name="frontdoor.two_stage_regression")
assert estimate.value == pytest.approx(0.45, 0.025)

@mark.parametrize(
[
"Estimator",
"num_treatments",
"num_frontdoor_variables",
],
[
(
TwoStageRegressionEstimator,
[2, 1],
[1, 2],
)
],
)
def test_frontdoor_num_variables_error(self, Estimator, num_treatments, num_frontdoor_variables):
estimator_tester = TestEstimator(error_tolerance=0, Estimator=Estimator, identifier_method="frontdoor")
with pytest.raises((ValueError, Exception)):
estimator_tester.average_treatment_effect_testsuite(
num_common_causes=[1, 1],
num_instruments=[0, 0],
num_effect_modifiers=[0, 0],
num_treatments=num_treatments,
num_frontdoor_variables=num_frontdoor_variables,
treatment_is_binary=[True],
outcome_is_binary=[False],
confidence_intervals=[
True,
],
test_significance=[
False,
],
method_params={"num_simulations": 10, "num_null_simulations": 10},
)
23 changes: 22 additions & 1 deletion tests/causal_identifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dowhy.graph import build_graph_from_str

from .example_graphs import TEST_GRAPH_SOLUTIONS
from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS


class IdentificationTestGraphSolution(object):
Expand All @@ -27,6 +27,27 @@ def __init__(
)


class IdentificationTestFrontdoorGraphSolution(object):
def __init__(
self,
graph_str,
observed_variables,
valid_frontdoor_sets,
invalid_frontdoor_sets,
):
self.graph = build_graph_from_str(graph_str)
self.action_nodes = ["X"]
self.outcome_nodes = ["Y"]
self.observed_nodes = observed_variables
self.valid_frontdoor_sets = valid_frontdoor_sets
self.invalid_frontdoor_sets = invalid_frontdoor_sets


@pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
def example_graph_solution(request):
return IdentificationTestGraphSolution(**TEST_GRAPH_SOLUTIONS[request.param])


@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
def example_frontdoor_graph_solution(request):
return IdentificationTestFrontdoorGraphSolution(**TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param])
36 changes: 36 additions & 0 deletions tests/causal_identifiers/example_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,39 @@
direct_maximal_adjustment_sets=[{"W1", "M", "W2"}],
),
}


TEST_FRONTDOOR_GRAPH_SOLUTIONS = {
"valid_singleton": dict(
graph_str="digraph {X; Y; Z; M1; X->M1; Z->X; Z->Y; M1->Y;}",
observed_variables=["X", "Y", "M1"],
valid_frontdoor_sets=[{"M1"}],
invalid_frontdoor_sets=[{"Z"}],
),
"valid_doubleton": dict(
graph_str="digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
observed_variables=["X", "Y", "M1", "M2"],
valid_frontdoor_sets=[
{"M1", "M2"},
],
invalid_frontdoor_sets=[{"Z"}, {"M1"}, {"M2"}],
),
"no_frontdoor": dict(
graph_str="digraph{E;X;R;M;Y; E->X;E->R;X->M;R->M;M->Y}",
observed_variables=["X", "R", "M", "Y"],
valid_frontdoor_sets=[],
invalid_frontdoor_sets=[{"M"}, {"M", "E"}, {"M", "R"}],
),
"no_frontdoor_simple": dict(
graph_str="digraph{X;Y;D;B;X->B;D->B;B->Y;D->Y}",
observed_variables=["X", "B", "Y"],
valid_frontdoor_sets=[],
invalid_frontdoor_sets=[{"B"}, {"D"}],
),
"no_frontdoor_in_obs": dict(
graph_str="digraph {X; Y; Z; M1; M2; X->M1; X->M2; Z->X; Z->Y; M1->Y; M2->Y}",
observed_variables=["X", "Y", "M1", "Z"],
valid_frontdoor_sets=[],
invalid_frontdoor_sets=[{"Z"}, {"M1"}, {"M2"}, {"M1", "M2"}],
),
}
50 changes: 50 additions & 0 deletions tests/causal_identifiers/test_frontdoor_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import pandas as pd
import pytest

from dowhy import CausalModel
from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment
from dowhy.causal_identifier.auto_identifier import identify_frontdoor
from dowhy.causal_identifier.identify_effect import EstimandType

from .base import IdentificationTestFrontdoorGraphSolution, example_frontdoor_graph_solution


class TestFrontdoorIdentification(object):
def test_identify_frontdoor_functional_api(
self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
):
graph = example_frontdoor_graph_solution.graph
expected_sets = example_frontdoor_graph_solution.valid_frontdoor_sets
invalid_sets = example_frontdoor_graph_solution.invalid_frontdoor_sets
frontdoor_set = identify_frontdoor(
graph,
observed_nodes=example_frontdoor_graph_solution.observed_nodes,
action_nodes=["X"],
outcome_nodes=["Y"],
)

assert (
(len(frontdoor_set) == 0) and (len(expected_sets) == 0)
) or ( # No adjustments exist and that's expected.
set(frontdoor_set) in expected_sets and set(frontdoor_set) not in invalid_sets
)

def test_identify_frontdoor_causal_model(
self, example_frontdoor_graph_solution: IdentificationTestFrontdoorGraphSolution
):
graph = example_frontdoor_graph_solution.graph
expected_sets = example_frontdoor_graph_solution.valid_frontdoor_sets
invalid_sets = example_frontdoor_graph_solution.invalid_frontdoor_sets
observed_nodes = example_frontdoor_graph_solution.observed_nodes
# Building the causal model
num_samples = 10
df = pd.DataFrame(np.random.random((num_samples, len(observed_nodes))), columns=observed_nodes)
model = CausalModel(data=df, treatment="X", outcome="Y", graph=graph)
estimand = model.identify_effect()
frontdoor_set = estimand.frontdoor_variables
assert (
(len(frontdoor_set) == 0) and (len(expected_sets) == 0)
) or ( # No adjustments exist and that's expected.
(set(frontdoor_set) in expected_sets) and (set(frontdoor_set) not in invalid_sets)
)