diff --git a/docs/source/example_notebooks/do_sampler_demo.ipynb b/docs/source/example_notebooks/do_sampler_demo.ipynb index 73161be470..f6190021f6 100644 --- a/docs/source/example_notebooks/do_sampler_demo.ipynb +++ b/docs/source/example_notebooks/do_sampler_demo.ipynb @@ -45,7 +45,7 @@ "\n", "## Integration\n", "\n", - "The do-sampler is built on top of the identification abstraction used throughout do-why. It uses a `dowhy.CausalModel` to perform identification, and builds any models it needs automatically using this identification.\n", + "The do-sampler is built on top of the identification abstraction used throughout do-why. It automatically performs an identification, and builds any models it needs automatically using this identification.\n", "\n", "## Specifying Interventions\n", "\n", @@ -128,7 +128,8 @@ "model = CausalModel(df, \n", " causes,\n", " outcomes,\n", - " common_causes=common_causes)" + " common_causes=common_causes)\n", + "nx_graph = model._graph._graph" ] }, { @@ -162,8 +163,11 @@ "source": [ "from dowhy.do_samplers.weighting_sampler import WeightingSampler\n", "\n", - "sampler = WeightingSampler(df,\n", - " causal_model=model,\n", + "sampler = WeightingSampler(graph=nx_graph,\n", + " action_nodes=causes,\n", + " outcome_nodes=outcomes,\n", + " observed_nodes=df.columns.tolist(),\n", + " data=df,\n", " keep_original_treatment=True,\n", " variable_types={'D': 'b', 'Z': 'c', 'Y': 'c'}\n", " )\n", @@ -207,7 +211,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -221,7 +225,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/docs/source/example_notebooks/dowhy_causal_api.ipynb b/docs/source/example_notebooks/dowhy_causal_api.ipynb index decae0f41e..b4efa43590 100644 --- a/docs/source/example_notebooks/dowhy_causal_api.ipynb +++ b/docs/source/example_notebooks/dowhy_causal_api.ipynb @@ -16,6 +16,7 @@ "source": [ "import dowhy.datasets\n", "import dowhy.api\n", + "from dowhy.graph import build_graph_from_str\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -36,7 +37,7 @@ " treatment_is_binary=True)\n", "df = data['df']\n", "df['y'] = df['y'] + np.random.normal(size=len(df)) # Adding noise to data. Without noise, the variance in Y|X, Z is zero, and mcmc fails.\n", - "#data['dot_graph'] = 'digraph { v ->y;X0-> v;X0-> y;}'\n", + "nx_graph = build_graph_from_str(data[\"dot_graph\"])\n", "\n", "treatment= data[\"treatment_name\"][0]\n", "outcome = data[\"outcome_name\"][0]\n", @@ -47,15 +48,17 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "# data['df'] is just a regular pandas.DataFrame\n", "df.causal.do(x=treatment,\n", - " variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n", - " outcome=outcome,\n", - " common_causes=[common_cause],\n", - " proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')" + " variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n", + " outcome=outcome,\n", + " common_causes=[common_cause],\n", + " ).groupby(treatment).mean().plot(y=outcome, kind='bar')" ] }, { @@ -68,8 +71,8 @@ " variable_types={treatment:'b', outcome: 'c', common_cause: 'c'}, \n", " outcome=outcome,\n", " method='weighting', \n", - " common_causes=[common_cause],\n", - " proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')" + " common_causes=[common_cause]\n", + " ).groupby(treatment).mean().plot(y=outcome, kind='bar')" ] }, { @@ -81,14 +84,14 @@ "cdf_1 = df.causal.do(x={treatment: 1}, \n", " variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n", " outcome=outcome, \n", - " dot_graph=data['dot_graph'],\n", - " proceed_when_unidentifiable=True)\n", + " graph=nx_graph\n", + " )\n", "\n", "cdf_0 = df.causal.do(x={treatment: 0}, \n", " variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n", " outcome=outcome, \n", - " dot_graph=data['dot_graph'],\n", - " proceed_when_unidentifiable=True)\n" + " graph=nx_graph\n", + " )\n" ] }, { @@ -158,7 +161,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -172,7 +175,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb b/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb index 4e96972930..4c6b7deb02 100644 --- a/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb +++ b/docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb @@ -62,7 +62,9 @@ "outputs": [], "source": [ "from dowhy.causal_graph import CausalGraph\n", - "from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType" + "from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType\n", + "from dowhy.graph import build_graph_from_str\n", + "from dowhy.utils.plotting import plot" ] }, { @@ -135,9 +137,7 @@ "]\n", "treatment_name = \"warm-up\"\n", "outcome_name = \"injury\"\n", - "G = CausalGraph(\n", - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", - ")" + "G = build_graph_from_str(graph_str)" ] }, { @@ -153,7 +153,7 @@ "metadata": {}, "outputs": [], "source": [ - "G.view_graph()" + "plot(G)" ] }, { @@ -184,7 +184,11 @@ ")\n", "print(\n", " ident_eff.identify_effect(\n", - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", + " graph=G, \n", + " action_nodes=treatment_name, \n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names,\n", + " conditional_node_names=conditional_node_names\n", " )\n", ")" ] @@ -215,7 +219,11 @@ ")\n", "print(\n", " ident_minimal_eff.identify_effect(\n", - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", + " graph=G, \n", + " action_nodes=treatment_name, \n", + " outcome_nodes=outcome_name, \n", + " observed_nodes=observed_node_names,\n", + " conditional_node_names=conditional_node_names\n", " )\n", ")" ] @@ -239,7 +247,11 @@ ")\n", "print(\n", " ident_mincost_eff.identify_effect(\n", - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", + " graph=G, \n", + " action_nodes=treatment_name, \n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names,\n", + " conditional_node_names=conditional_node_names\n", " )\n", ")" ] @@ -294,9 +306,7 @@ "observed_node_names = [\"X\", \"Y\", \"Z1\", \"Z2\"]\n", "treatment_name = \"X\"\n", "outcome_name = \"Y\"\n", - "G = CausalGraph(\n", - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", - ")" + "G = build_graph_from_str(graph_str)" ] }, { @@ -317,7 +327,10 @@ " backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n", ")\n", "try:\n", - " results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n", + " results_eff = ident_eff.identify_effect(graph=G, \n", + " action_nodes=treatment_name, \n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names)\n", "except ValueError as e:\n", " print(e)" ] @@ -335,8 +348,9 @@ "print(\n", " ident_minimal_eff.identify_effect(\n", " graph=G,\n", - " treatment_name=treatment_name,\n", - " outcome_name=outcome_name,\n", + " action_nodes=treatment_name,\n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names\n", " )\n", ")" ] @@ -354,8 +368,9 @@ "print(\n", " ident_mincost_eff.identify_effect(\n", " graph=G,\n", - " treatment_name=treatment_name,\n", - " outcome_name=outcome_name,\n", + " action_nodes=treatment_name,\n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names\n", " )\n", ")" ] @@ -391,9 +406,7 @@ "observed_node_names = [\"X\", \"Y\"]\n", "treatment_name = \"X\"\n", "outcome_name = \"Y\"\n", - "G = CausalGraph(\n", - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", - ")" + "G = build_graph_from_str(graph_str)" ] }, { @@ -409,8 +422,9 @@ "try:\n", " results_eff = ident_eff.identify_effect(\n", " graph=G,\n", - " treatment_name=treatment_name,\n", - " outcome_name=outcome_name,\n", + " action_nodes=treatment_name,\n", + " outcome_nodes=outcome_name,\n", + " observed_nodes=observed_node_names\n", " )\n", "except ValueError as e:\n", " print(e)" @@ -475,9 +489,7 @@ " (\"R\", {\"cost\": 2}),\n", " (\"T\", {\"cost\": 1}),\n", "]\n", - "G = CausalGraph(\n", - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", - ")" + "G = build_graph_from_str(graph_str)" ] }, { @@ -504,7 +516,11 @@ ")\n", "print(\n", " ident_mincost_eff.identify_effect(\n", - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", + " graph=G, \n", + " action_nodes=treatment_name, \n", + " outcome_nodes=outcome_name, \n", + " observed_nodes=observed_node_names,\n", + " conditional_node_names=conditional_node_names\n", " )\n", ")" ] @@ -528,22 +544,19 @@ ")\n", "print(\n", " ident_minimal_eff.identify_effect(\n", - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", + " graph=G, \n", + " action_nodes=treatment_name,\n", + " outcome_nodes=outcome_name, \n", + " observed_nodes=observed_node_names,\n", + " conditional_node_names=conditional_node_names\n", " )\n", ")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/docs/source/example_notebooks/dowhy_functional_api.ipynb b/docs/source/example_notebooks/dowhy_functional_api.ipynb index 84a998fd9e..06ebe5fcb0 100644 --- a/docs/source/example_notebooks/dowhy_functional_api.ipynb +++ b/docs/source/example_notebooks/dowhy_functional_api.ipynb @@ -48,6 +48,8 @@ "from dowhy.causal_estimators.econml import Econml\n", "from dowhy.causal_estimators.propensity_score_matching_estimator import PropensityScoreMatchingEstimator\n", "from dowhy.causal_graph import CausalGraph\n", + "from dowhy.graph import build_graph\n", + "\n", "# Functional API imports\n", "from dowhy.causal_identifier import (\n", " BackdoorAdjustment,\n", @@ -132,14 +134,13 @@ "outcome_name = data[\"outcome_name\"]\n", "print(outcome_name)\n", "\n", - "graph = CausalGraph(\n", - " treatment_name=treatment_name,\n", - " outcome_name=outcome_name,\n", - " graph=data[\"gml_graph\"],\n", - " effect_modifier_names=data[\"effect_modifier_names\"],\n", - " common_cause_names=data[\"common_causes_names\"],\n", - " observed_node_names=data[\"df\"].columns.tolist(),\n", - ")" + "graph = build_graph(\n", + " action_nodes=treatment_name,\n", + " outcome_nodes=outcome_name,\n", + " effect_modifier_nodes=data[\"effect_modifier_names\"],\n", + " common_cause_nodes=data[\"common_causes_names\"],\n", + ")\n", + "observed_nodes = data[\"df\"].columns.tolist()" ] }, { @@ -156,13 +157,14 @@ "outputs": [], "source": [ "# Default identify_effect call example:\n", - "identified_estimand = identify_effect(graph, treatment_name, outcome_name)\n", + "identified_estimand = identify_effect(graph, treatment_name, outcome_name, observed_nodes)\n", "\n", "# auto_identify_effect example with extra parameters:\n", "identified_estimand_auto = identify_effect_auto(\n", " graph,\n", " treatment_name,\n", " outcome_name,\n", + " observed_nodes,\n", " estimand_type=EstimandType.NONPARAMETRIC_ATE,\n", " backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n", ")\n", @@ -196,7 +198,7 @@ " confidence_intervals=False,\n", ").fit(\n", " data=data[\"df\"],\n", - " effect_modifier_names=graph.get_effect_modifiers(treatment_name, outcome_name),\n", + " effect_modifier_names=data[\"effect_modifier_names\"]\n", ")\n", "\n", "estimate = estimator.estimate_effect(\n", @@ -242,7 +244,7 @@ " ),\n", ").fit(\n", " data=data[\"df\"],\n", - " effect_modifier_names=graph.get_effect_modifiers(treatment_name, outcome_name),\n", + " effect_modifier_names=data[\"effect_modifier_names\"],\n", ")\n", "\n", "estimate_econml = estimator.estimate_effect(\n", @@ -377,7 +379,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -391,7 +393,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10 (default, Jun 22 2022, 20:18:18) \n[GCC 9.4.0]" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/docs/source/example_notebooks/lalonde_pandas_api.ipynb b/docs/source/example_notebooks/lalonde_pandas_api.ipynb index 1b2a44cc29..07ca3293fe 100644 --- a/docs/source/example_notebooks/lalonde_pandas_api.ipynb +++ b/docs/source/example_notebooks/lalonde_pandas_api.ipynb @@ -93,8 +93,8 @@ " outcome='re78',\n", " common_causes=['nodegr', 'black', 'hisp', 'age', 'educ', 'married'],\n", " variable_types={'age': 'c', 'educ':'c', 'black': 'd', 'hisp': 'd', \n", - " 'married': 'd', 'nodegr': 'd','re78': 'c', 'treat': 'b'},\n", - " proceed_when_unidentifiable=True)" + " 'married': 'd', 'nodegr': 'd','re78': 'c', 'treat': 'b'}\n", + " )" ] }, { @@ -265,8 +265,8 @@ " outcome='re78',\n", " common_causes=['nodegr', 'black', 'hisp', 'age', 'educ', 'married'],\n", " variable_types={'age': 'c', 'educ':'c', 'black': 'd', 'hisp': 'd', \n", - " 'married': 'd', 'nodegr': 'd','re78': 'c', 'treat': 'b'},\n", - " proceed_when_unidentifiable=True)" + " 'married': 'd', 'nodegr': 'd','re78': 'c', 'treat': 'b'}\n", + " )" ] }, { @@ -304,7 +304,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -318,7 +318,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/docs/source/example_notebooks/tutorial-causalinference-machinelearning-using-dowhy-econml.ipynb b/docs/source/example_notebooks/tutorial-causalinference-machinelearning-using-dowhy-econml.ipynb index b18c6cb084..20738f78c0 100644 --- a/docs/source/example_notebooks/tutorial-causalinference-machinelearning-using-dowhy-econml.ipynb +++ b/docs/source/example_notebooks/tutorial-causalinference-machinelearning-using-dowhy-econml.ipynb @@ -566,7 +566,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.4 64-bit", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -580,7 +580,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.4" + "version": "3.8.10" }, "toc": { "base_numbering": 1, diff --git a/dowhy/api/causal_data_frame.py b/dowhy/api/causal_data_frame.py index 30d68137b7..67f627f9f3 100755 --- a/dowhy/api/causal_data_frame.py +++ b/dowhy/api/causal_data_frame.py @@ -1,7 +1,9 @@ +import networkx as nx import pandas as pd import dowhy.do_samplers as do_samplers -from dowhy.causal_model import CausalModel +from dowhy import EstimandType +from dowhy.graph import build_graph from dowhy.utils.api import parse_state @@ -14,7 +16,7 @@ def __init__(self, pandas_obj): :param pandas_obj: """ self._obj = pandas_obj - self._causal_model = None + self._graph = None self._sampler = None self._identified_estimand = None self._method = None @@ -25,7 +27,7 @@ def reset(self): :return: """ - self._causal_model = None + self._graph = None self._identified_estimand = None self._sampler = None self._method = None @@ -38,10 +40,9 @@ def do( variable_types={}, outcome=None, params=None, - dot_graph=None, + graph: nx.DiGraph = None, common_causes=None, - estimand_type="nonparametric-ate", - proceed_when_unidentifiable=False, + estimand_type=EstimandType.NONPARAMETRIC_ATE, stateful=False, ): """ @@ -92,18 +93,16 @@ def do( outcome = parse_state(outcome) if not stateful or method != self._method: self.reset() - if not self._causal_model: - self._causal_model = CausalModel( - self._obj, - [xi for xi in x.keys()], - outcome, - graph=dot_graph, - common_causes=common_causes, - instruments=None, - estimand_type=estimand_type, - proceed_when_unidentifiable=proceed_when_unidentifiable, + + if graph is None: + graph = build_graph( + action_nodes=[xi for xi in x.keys()], + outcome_nodes=outcome, + common_cause_nodes=common_causes, + effect_modifier_nodes=None, + instrument_nodes=None, + mediator_nodes=None, ) - # self._identified_estimand = self._causal_model.identify_effect() if not bool(variable_types): # check if the variables dictionary is empty variable_types = dict(self._obj.dtypes) # Convert the series containing data types to a dictionary @@ -125,15 +124,16 @@ def do( self._method = method do_sampler_class = do_samplers.get_class_object(method + "_sampler") self._sampler = do_sampler_class( - self._obj, - # self._identified_estimand, - # self._causal_model._treatment, - # self._causal_model._outcome, + graph, + observed_nodes=list(graph.nodes()), + action_nodes=[xi for xi in x.keys()], + outcome_nodes=outcome, + data=self._obj, params=params, variable_types=variable_types, num_cores=num_cores, - causal_model=self._causal_model, keep_original_treatment=keep_original_treatment, + estimand_type=estimand_type, ) result = self._sampler.do_sample(x) if not stateful: diff --git a/dowhy/causal_estimator.py b/dowhy/causal_estimator.py index fba234310b..03f8b70280 100755 --- a/dowhy/causal_estimator.py +++ b/dowhy/causal_estimator.py @@ -1,8 +1,7 @@ import copy import logging from collections import namedtuple -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import numpy as np import pandas as pd @@ -711,7 +710,7 @@ def estimate_effect( elif identified_estimand.estimands[identifier_name] is None: logger.error("No valid identified estimand available.") return CausalEstimate( - None, None, None, None, None, control_value=control_value, treatment_value=treatment_value + None, None, None, None, None, None, control_value=control_value, treatment_value=treatment_value ) if fit_estimator: diff --git a/dowhy/causal_graph.py b/dowhy/causal_graph.py index 764e726635..64e780baae 100755 --- a/dowhy/causal_graph.py +++ b/dowhy/causal_graph.py @@ -54,6 +54,8 @@ def __init__( if graph is None: self._graph = nx.DiGraph() self._graph = self.build_graph(common_cause_names, instrument_names, effect_modifier_names, mediator_names) + elif isinstance(graph, nx.DiGraph): + self._graph = graph elif re.match(r".*\.dot", graph): # load dot file try: @@ -90,7 +92,9 @@ def __init__( elif re.match(".*graph\s*\[.*\]\s*", graph): self._graph = nx.DiGraph(nx.parse_gml(graph)) else: - self.logger.error("Error: Please provide graph (as string or text file) in dot or gml format.") + self.logger.error( + "Error: Please provide networkx graph or graph (as string or text file) in dot or gml format." + ) self.logger.error("Error: Incorrect graph format") raise ValueError if missing_nodes_as_confounders: diff --git a/dowhy/causal_identifier/auto_identifier.py b/dowhy/causal_identifier/auto_identifier.py index 77555c1643..ca3fb259a3 100644 --- a/dowhy/causal_identifier/auto_identifier.py +++ b/dowhy/causal_identifier/auto_identifier.py @@ -3,12 +3,24 @@ from enum import Enum from typing import Dict, List, Optional, Union +import networkx as nx import sympy as sp import sympy.stats as spstats -from dowhy.causal_graph import CausalGraph from dowhy.causal_identifier.efficient_backdoor import EfficientBackdoor from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand +from dowhy.graph import ( + check_dseparation, + check_valid_backdoor_set, + check_valid_frontdoor_set, + check_valid_mediation_set, + do_surgery, + get_all_directed_paths, + get_backdoor_paths, + get_descendants, + get_instruments, + has_directed_path, +) from dowhy.utils.api import parse_state logger = logging.getLogger(__name__) @@ -26,7 +38,6 @@ class EstimandType(Enum): class BackdoorAdjustment(Enum): - # Backdoor method names BACKDOOR_DEFAULT = "default" BACKDOOR_EXHAUSTIVE = "exhaustive-search" @@ -70,36 +81,33 @@ def __init__( self, estimand_type: EstimandType, backdoor_adjustment: BackdoorAdjustment = BackdoorAdjustment.BACKDOOR_DEFAULT, - proceed_when_unidentifiable: bool = False, optimize_backdoor: bool = False, costs: Optional[List] = None, ): self.estimand_type = estimand_type self.backdoor_adjustment = backdoor_adjustment - self._proceed_when_unidentifiable = proceed_when_unidentifiable self.optimize_backdoor = optimize_backdoor self.costs = costs self.logger = logging.getLogger(__name__) def identify_effect( self, - graph: CausalGraph, - treatment_name: Union[str, List[str]], - outcome_name: Union[str, List[str]], + graph: nx.DiGraph, + action_nodes: Union[str, List[str]], + outcome_nodes: Union[str, List[str]], + observed_nodes: Union[str, List[str]], conditional_node_names: List[str] = None, - **kwargs, ): estimand = identify_effect_auto( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, self.estimand_type, conditional_node_names, self.backdoor_adjustment, - self._proceed_when_unidentifiable, self.optimize_backdoor, self.costs, - **kwargs, ) estimand.identifier = self @@ -108,17 +116,19 @@ def identify_effect( def identify_backdoor( self, - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], include_unobserved: bool = False, dseparation_algo: str = "default", direct_effect: bool = False, ): return identify_backdoor( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, self.backdoor_adjustment, include_unobserved, dseparation_algo, @@ -127,16 +137,15 @@ def identify_backdoor( def identify_effect_auto( - graph: CausalGraph, - treatment_name: Union[str, List[str]], - outcome_name: Union[str, List[str]], + graph: nx.DiGraph, + action_nodes: Union[str, List[str]], + outcome_nodes: Union[str, List[str]], + observed_nodes: Union[str, List[str]], estimand_type: EstimandType, conditional_node_names: List[str] = None, backdoor_adjustment: BackdoorAdjustment = BackdoorAdjustment.BACKDOOR_DEFAULT, - proceed_when_unidentifiable: bool = False, optimize_backdoor: bool = False, costs: Optional[List] = None, - **kwargs, ) -> IdentifiedEstimand: """Main method that returns an identified estimand (if one exists). @@ -152,41 +161,42 @@ def identify_effect_auto( :returns: target estimand, an instance of the IdentifiedEstimand class """ - treatment_name = parse_state(treatment_name) - outcome_name = parse_state(outcome_name) + observed_nodes = parse_state(observed_nodes) + action_nodes = parse_state(action_nodes) + outcome_nodes = parse_state(outcome_nodes) # First, check if there is a directed path from action to outcome - if not graph.has_directed_path(treatment_name, outcome_name): + if not has_directed_path(graph, action_nodes, outcome_nodes): logger.warn("No directed path from treatment to outcome. Causal Effect is zero.") return IdentifiedEstimand( None, - treatment_variable=treatment_name, - outcome_variable=outcome_name, + treatment_variable=action_nodes, + outcome_variable=outcome_nodes, no_directed_path=True, ) if estimand_type == EstimandType.NONPARAMETRIC_ATE: return identify_ate_effect( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, backdoor_adjustment, optimize_backdoor, estimand_type, costs, conditional_node_names, - proceed_when_unidentifiable, ) elif estimand_type == EstimandType.NONPARAMETRIC_NDE: return identify_nde_effect( - graph, treatment_name, outcome_name, backdoor_adjustment, estimand_type, proceed_when_unidentifiable + graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment, estimand_type ) elif estimand_type == EstimandType.NONPARAMETRIC_NIE: return identify_nie_effect( - graph, treatment_name, outcome_name, backdoor_adjustment, estimand_type, proceed_when_unidentifiable + graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment, estimand_type ) elif estimand_type == EstimandType.NONPARAMETRIC_CDE: return identify_cde_effect( - graph, treatment_name, outcome_name, backdoor_adjustment, estimand_type, proceed_when_unidentifiable + graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment, estimand_type ) else: raise ValueError( @@ -200,15 +210,15 @@ def identify_effect_auto( def identify_ate_effect( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, optimize_backdoor: bool, estimand_type: EstimandType, costs: List, conditional_node_names: List[str] = None, - proceed_when_unidentifiable: bool = False, ): estimands_dict = {} mediation_first_stage_confounders = None @@ -218,21 +228,27 @@ def identify_ate_effect( if backdoor_adjustment not in EFFICIENT_METHODS: # First, checking if there are any valid backdoor adjustment sets if optimize_backdoor == False: - backdoor_sets = identify_backdoor(graph, treatment_name, outcome_name, backdoor_adjustment) + backdoor_sets = identify_backdoor(graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment) else: from dowhy.causal_identifier.backdoor import Backdoor - path = Backdoor(graph._graph, treatment_name, outcome_name) + path = Backdoor(graph, action_nodes, outcome_nodes) backdoor_sets = path.get_backdoor_vars() elif backdoor_adjustment in EFFICIENT_METHODS: backdoor_sets = identify_efficient_backdoor( - graph, backdoor_adjustment, costs, conditional_node_names=conditional_node_names + graph, + action_nodes, + outcome_nodes, + observed_nodes, + backdoor_adjustment, + costs, + conditional_node_names=conditional_node_names, ) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, treatment_name, outcome_name, backdoor_sets, estimands_dict + action_nodes, outcome_nodes, observed_nodes, backdoor_sets, estimands_dict ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) if len(backdoor_variables_dict) > 0: estimands_dict["backdoor"] = estimands_dict.get(str(default_backdoor_id), None) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) @@ -240,12 +256,12 @@ def identify_ate_effect( estimands_dict["backdoor"] = None ### 2. INSTRUMENTAL VARIABLE IDENTIFICATION # Now checking if there is also a valid iv estimand - instrument_names = graph.get_instruments(treatment_name, outcome_name) + instrument_names = get_instruments(graph, action_nodes, outcome_nodes) logger.info("Instrumental variables for treatment and outcome:" + str(instrument_names)) if len(instrument_names) > 0: iv_estimand_expr = construct_iv_estimand( - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, instrument_names, ) logger.debug("Identified expression = " + str(iv_estimand_expr)) @@ -255,21 +271,21 @@ def identify_ate_effect( ### 3. FRONTDOOR IDENTIFICATION # Now checking if there is a valid frontdoor variable - frontdoor_variables_names = identify_frontdoor(graph, treatment_name, outcome_name) + frontdoor_variables_names = identify_frontdoor(graph, action_nodes, outcome_nodes) logger.info("Frontdoor variables for treatment and outcome:" + str(frontdoor_variables_names)) if len(frontdoor_variables_names) > 0: frontdoor_estimand_expr = construct_frontdoor_estimand( - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, frontdoor_variables_names, ) logger.debug("Identified expression = " + str(frontdoor_estimand_expr)) estimands_dict["frontdoor"] = frontdoor_estimand_expr mediation_first_stage_confounders = identify_mediation_first_stage_confounders( - graph, treatment_name, outcome_name, frontdoor_variables_names, backdoor_adjustment + graph, action_nodes, outcome_nodes, frontdoor_variables_names, observed_nodes, backdoor_adjustment ) mediation_second_stage_confounders = identify_mediation_second_stage_confounders( - graph, treatment_name, frontdoor_variables_names, outcome_name, backdoor_adjustment + graph, action_nodes, frontdoor_variables_names, outcome_nodes, observed_nodes, backdoor_adjustment ) else: estimands_dict["frontdoor"] = None @@ -277,8 +293,8 @@ def identify_ate_effect( # Finally returning the estimand object estimand = IdentifiedEstimand( None, - treatment_variable=treatment_name, - outcome_variable=outcome_name, + treatment_variable=action_nodes, + outcome_variable=outcome_nodes, estimand_type=estimand_type, estimands=estimands_dict, backdoor_variables=backdoor_variables_dict, @@ -292,12 +308,12 @@ def identify_ate_effect( def identify_cde_effect( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, estimand_type: EstimandType, - proceed_when_unidentifiable: bool = False, ): """Identify controlled direct effect. For a definition, see Vanderwheele (2011). Controlled direct and mediated effects: definition, identification and bounds. @@ -310,12 +326,14 @@ def identify_cde_effect( """ estimands_dict = {} # Pick algorithm to compute backdoor sets according to method chosen - backdoor_sets = identify_backdoor(graph, treatment_name, outcome_name, backdoor_adjustment, direct_effect=True) + backdoor_sets = identify_backdoor( + graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment, direct_effect=True + ) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, treatment_name, outcome_name, backdoor_sets, estimands_dict + action_nodes, outcome_nodes, observed_nodes, backdoor_sets, estimands_dict ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) if len(backdoor_variables_dict) > 0: estimands_dict["backdoor"] = estimands_dict.get(str(default_backdoor_id), None) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) @@ -325,8 +343,8 @@ def identify_cde_effect( # Finally returning the estimand object estimand = IdentifiedEstimand( None, - treatment_variable=treatment_name, - outcome_variable=outcome_name, + treatment_variable=action_nodes, + outcome_variable=outcome_nodes, estimand_type=estimand_type, estimands=estimands_dict, backdoor_variables=backdoor_variables_dict, @@ -340,22 +358,22 @@ def identify_cde_effect( def identify_nie_effect( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, estimand_type: EstimandType, - proceed_when_unidentifiable: bool = False, ): estimands_dict = {} ### 1. FIRST DOING BACKDOOR IDENTIFICATION # First, checking if there are any valid backdoor adjustment sets - backdoor_sets = identify_backdoor(graph, treatment_name, outcome_name, backdoor_adjustment) + backdoor_sets = identify_backdoor(graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, treatment_name, outcome_name, backdoor_sets, estimands_dict + action_nodes, outcome_nodes, observed_nodes, backdoor_sets, estimands_dict ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) ### 2. SECOND, CHECKING FOR MEDIATORS @@ -363,30 +381,30 @@ def identify_nie_effect( estimands_dict = {} # Need to reinitialize this dictionary to avoid including the backdoor sets mediation_first_stage_confounders = None mediation_second_stage_confounders = None - mediators_names = identify_mediation(graph, treatment_name, outcome_name) + mediators_names = identify_mediation(graph, action_nodes, outcome_nodes) logger.info("Mediators for treatment and outcome:" + str(mediators_names)) if len(mediators_names) > 0: mediation_estimand_expr = construct_mediation_estimand( estimand_type, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, mediators_names, ) logger.debug("Identified expression = " + str(mediation_estimand_expr)) estimands_dict["mediation"] = mediation_estimand_expr mediation_first_stage_confounders = identify_mediation_first_stage_confounders( - graph, treatment_name, outcome_name, mediators_names, backdoor_adjustment + graph, action_nodes, outcome_nodes, mediators_names, observed_nodes, backdoor_adjustment ) mediation_second_stage_confounders = identify_mediation_second_stage_confounders( - graph, treatment_name, mediators_names, outcome_name, backdoor_adjustment + graph, action_nodes, mediators_names, outcome_nodes, observed_nodes, backdoor_adjustment ) else: estimands_dict["mediation"] = None # Finally returning the estimand object estimand = IdentifiedEstimand( None, - treatment_variable=treatment_name, - outcome_variable=outcome_name, + treatment_variable=action_nodes, + outcome_variable=outcome_nodes, estimand_type=estimand_type, estimands=estimands_dict, backdoor_variables=backdoor_variables_dict, @@ -401,22 +419,22 @@ def identify_nie_effect( def identify_nde_effect( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, estimand_type: EstimandType, - proceed_when_unidentifiable: bool = False, ): estimands_dict = {} ### 1. FIRST DOING BACKDOOR IDENTIFICATION # First, checking if there are any valid backdoor adjustment sets - backdoor_sets = identify_backdoor(graph, treatment_name, outcome_name, backdoor_adjustment) + backdoor_sets = identify_backdoor(graph, action_nodes, outcome_nodes, observed_nodes, backdoor_adjustment) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, treatment_name, outcome_name, backdoor_sets, estimands_dict + action_nodes, outcome_nodes, observed_nodes, backdoor_sets, estimands_dict ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) ### 2. SECOND, CHECKING FOR MEDIATORS @@ -424,30 +442,30 @@ def identify_nde_effect( estimands_dict = {} mediation_first_stage_confounders = None mediation_second_stage_confounders = None - mediators_names = identify_mediation(graph, treatment_name, outcome_name) + mediators_names = identify_mediation(graph, action_nodes, outcome_nodes) logger.info("Mediators for treatment and outcome:" + str(mediators_names)) if len(mediators_names) > 0: mediation_estimand_expr = construct_mediation_estimand( estimand_type, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, mediators_names, ) logger.debug("Identified expression = " + str(mediation_estimand_expr)) estimands_dict["mediation"] = mediation_estimand_expr mediation_first_stage_confounders = identify_mediation_first_stage_confounders( - graph, treatment_name, outcome_name, mediators_names, backdoor_adjustment + graph, action_nodes, outcome_nodes, mediators_names, observed_nodes, backdoor_adjustment ) mediation_second_stage_confounders = identify_mediation_second_stage_confounders( - graph, treatment_name, mediators_names, outcome_name, backdoor_adjustment + graph, action_nodes, mediators_names, outcome_nodes, observed_nodes, backdoor_adjustment ) else: estimands_dict["mediation"] = None # Finally returning the estimand object estimand = IdentifiedEstimand( None, - treatment_variable=treatment_name, - outcome_variable=outcome_name, + treatment_variable=action_nodes, + outcome_variable=outcome_nodes, estimand_type=estimand_type, estimands=estimands_dict, backdoor_variables=backdoor_variables_dict, @@ -462,9 +480,10 @@ def identify_nde_effect( def identify_backdoor( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: str, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, include_unobserved: bool = False, dseparation_algo: str = "default", @@ -473,12 +492,14 @@ def identify_backdoor( backdoor_sets = [] backdoor_paths = None bdoor_graph = None + observed_nodes = set(observed_nodes) if dseparation_algo == "naive": - backdoor_paths = graph.get_backdoor_paths(treatment_name, outcome_name) + backdoor_paths = get_backdoor_paths(graph, action_nodes, outcome_nodes) elif dseparation_algo == "default": - bdoor_graph = graph.do_surgery( - treatment_name, - target_node_names=outcome_name, + bdoor_graph = do_surgery( + graph, + action_nodes, + target_node_names=outcome_nodes, remove_outgoing_edges=True, remove_only_direct_edges_to_target=direct_effect, ) @@ -490,9 +511,10 @@ def identify_backdoor( # First, checking if empty set is a valid backdoor set empty_set = set() - check = graph.check_valid_backdoor_set( - treatment_name, - outcome_name, + check = check_valid_backdoor_set( + graph, + action_nodes, + outcome_nodes, empty_set, backdoor_paths=backdoor_paths, new_graph=bdoor_graph, @@ -506,28 +528,32 @@ def identify_backdoor( # Second, checking for all other sets of variables. If include_unobserved is false, then only observed variables are eligible. eligible_variables = ( - graph.get_all_nodes(include_unobserved=include_unobserved) - set(treatment_name) - set(outcome_name) + set([node for node in graph.nodes if include_unobserved or node in observed_nodes]) + - set(action_nodes) + - set(outcome_nodes) ) + if direct_effect: # only remove descendants of Y # also allow any causes of Y that are not caused by T (for lower variance) - eligible_variables -= graph.get_descendants(outcome_name) + eligible_variables -= get_descendants(graph, outcome_nodes) else: # remove descendants of T (mediators) and descendants of Y - eligible_variables -= graph.get_descendants(treatment_name) + eligible_variables -= get_descendants(graph, action_nodes) # If var is d-separated from both treatment or outcome, it cannot # be a part of the backdoor set filt_eligible_variables = set() for var in eligible_variables: - dsep_treat_var = graph.check_dseparation(treatment_name, parse_state(var), set()) - dsep_outcome_var = graph.check_dseparation(outcome_name, parse_state(var), set()) + dsep_treat_var = check_dseparation(graph, action_nodes, parse_state(var), set()) + dsep_outcome_var = check_dseparation(graph, outcome_nodes, parse_state(var), set()) if not dsep_outcome_var or not dsep_treat_var: filt_eligible_variables.add(var) if backdoor_adjustment in METHOD_NAMES: backdoor_sets, found_valid_adjustment_set = find_valid_adjustment_sets( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, backdoor_paths, bdoor_graph, dseparation_algo, @@ -540,8 +566,9 @@ def identify_backdoor( # repeat the above search with BACKDOOR_MIN backdoor_sets, _ = find_valid_adjustment_sets( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, backdoor_paths, bdoor_graph, dseparation_algo, @@ -558,7 +585,10 @@ def identify_backdoor( def identify_efficient_backdoor( - graph: CausalGraph, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, costs: List, conditional_node_names: List[str] = None, @@ -606,6 +636,9 @@ def identify_efficient_backdoor( logger.warning("No costs were passed, so they will be assumed to be constant and equal to 1.") efficient_bd = EfficientBackdoor( graph=graph, + action_nodes=action_nodes, + outcome_nodes=outcome_nodes, + observed_nodes=observed_nodes, conditional_node_names=conditional_node_names, costs=costs, ) @@ -622,11 +655,12 @@ def identify_efficient_backdoor( def find_valid_adjustment_sets( - graph: CausalGraph, - treatment_name: List, - outcome_name: List, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_paths: List, - bdoor_graph: CausalGraph, + bdoor_graph: nx.DiGraph, dseparation_algo: str, backdoor_sets: List, filt_eligible_variables: List, @@ -635,7 +669,7 @@ def find_valid_adjustment_sets( ): num_iterations = 0 found_valid_adjustment_set = False - all_nodes_observed = graph.all_observed(graph.get_all_nodes()) + is_all_observed = set(graph.nodes) == set(observed_nodes) # If `minimal-adjustment` method is specified, start the search from the set with minimum size. Otherwise, start from the largest. set_sizes = ( range(1, len(filt_eligible_variables) + 1, 1) @@ -644,9 +678,10 @@ def find_valid_adjustment_sets( ) for size_candidate_set in set_sizes: for candidate_set in itertools.combinations(filt_eligible_variables, size_candidate_set): - check = graph.check_valid_backdoor_set( - treatment_name, - outcome_name, + check = check_valid_backdoor_set( + graph, + action_nodes, + outcome_nodes, candidate_set, backdoor_paths=backdoor_paths, new_graph=bdoor_graph, @@ -677,7 +712,7 @@ def find_valid_adjustment_sets( # does not satisfy backdoor, then none of its subsets will. if ( backdoor_adjustment in {BackdoorAdjustment.BACKDOOR_DEFAULT, BackdoorAdjustment.BACKDOOR_MAX} - and all_nodes_observed + and is_all_observed ): break if num_iterations > max_iterations: @@ -687,14 +722,14 @@ def find_valid_adjustment_sets( def get_default_backdoor_set_id( - graph: CausalGraph, treatment_name: List[str], outcome_name: List[str], backdoor_sets_dict: Dict + graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str], backdoor_sets_dict: Dict ): # Adding a None estimand if no backdoor set found if len(backdoor_sets_dict) == 0: return None # Default set contains minimum possible number of instrumental variables, to prevent lowering variance in the treatment variable. - instrument_names = set(graph.get_instruments(treatment_name, outcome_name)) + instrument_names = set(get_instruments(graph, action_nodes, outcome_nodes)) iv_count_dict = { key: len(set(bdoor_set).intersection(instrument_names)) for key, bdoor_set in backdoor_sets_dict.items() } @@ -713,27 +748,28 @@ def get_default_backdoor_set_id( def build_backdoor_estimands_dict( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: List[str], + treatment_names: List[str], + outcome_names: List[str], + observed_nodes: List[str], backdoor_sets: List[str], estimands_dict: Dict, ): """Build the final dict for backdoor sets by filtering unobserved variables if needed.""" backdoor_variables_dict = {} - is_identified = [graph.all_observed(bset["backdoor_set"]) for bset in backdoor_sets] + observed_nodes = set(observed_nodes) + is_identified = [set(bset["backdoor_set"]).issubset(observed_nodes) for bset in backdoor_sets] if any(is_identified): logger.info("Causal effect can be identified.") backdoor_sets_arr = [ - list(bset["backdoor_set"]) for bset in backdoor_sets if graph.all_observed(bset["backdoor_set"]) + list(bset["backdoor_set"]) for bset in backdoor_sets if set(bset["backdoor_set"]).issubset(observed_nodes) ] else: # there is unobserved confounding logger.warning("Backdoor identification failed.") backdoor_sets_arr = [] for i in range(len(backdoor_sets_arr)): - backdoor_estimand_expr = construct_backdoor_estimand(treatment_name, outcome_name, backdoor_sets_arr[i]) + backdoor_estimand_expr = construct_backdoor_estimand(treatment_names, outcome_names, backdoor_sets_arr[i]) logger.debug("Identified expression = " + str(backdoor_estimand_expr)) estimands_dict["backdoor" + str(i + 1)] = backdoor_estimand_expr backdoor_variables_dict["backdoor" + str(i + 1)] = backdoor_sets_arr[i] @@ -741,7 +777,7 @@ def build_backdoor_estimands_dict( def identify_frontdoor( - graph: CausalGraph, treatment_name: List[str], outcome_name: List[str], dseparation_algo: str = "default" + graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str], dseparation_algo: str = "default" ): """Find a valid frontdoor variable if it exists. @@ -751,22 +787,23 @@ def identify_frontdoor( frontdoor_paths = None fdoor_graph = None if dseparation_algo == "default": - cond1_graph = graph.do_surgery(treatment_name, remove_incoming_edges=True) - bdoor_graph1 = graph.do_surgery(treatment_name, remove_outgoing_edges=True) + cond1_graph = do_surgery(graph, action_nodes, remove_incoming_edges=True) + bdoor_graph1 = do_surgery(graph, action_nodes, remove_outgoing_edges=True) elif dseparation_algo == "naive": - frontdoor_paths = graph.get_all_directed_paths(treatment_name, outcome_name) + frontdoor_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes) else: raise ValueError(f"d-separation algorithm {dseparation_algo} is not supported") eligible_variables = ( - graph.get_descendants(treatment_name) - set(outcome_name) - set(graph.get_descendants(outcome_name)) + get_descendants(graph, action_nodes) - set(outcome_nodes) - set(get_descendants(graph, outcome_nodes)) ) # For simplicity, assuming a one-variable frontdoor set for candidate_var in eligible_variables: # Cond 1: All directed paths intercepted by candidate_var - cond1 = graph.check_valid_frontdoor_set( - treatment_name, - outcome_name, + cond1 = check_valid_frontdoor_set( + graph, + action_nodes, + outcome_nodes, parse_state(candidate_var), frontdoor_paths=frontdoor_paths, new_graph=cond1_graph, @@ -776,8 +813,9 @@ def identify_frontdoor( if not cond1: continue # Cond 2: No confounding between treatment and candidate var - cond2 = graph.check_valid_backdoor_set( - treatment_name, + cond2 = check_valid_backdoor_set( + graph, + action_nodes, parse_state(candidate_var), set(), backdoor_paths=None, @@ -787,11 +825,12 @@ def identify_frontdoor( if not cond2: continue # Cond 3: treatment blocks all confounding between candidate_var and outcome - bdoor_graph2 = graph.do_surgery(candidate_var, remove_outgoing_edges=True) - cond3 = graph.check_valid_backdoor_set( + bdoor_graph2 = do_surgery(graph, candidate_var, remove_outgoing_edges=True) + cond3 = check_valid_backdoor_set( + graph, parse_state(candidate_var), - outcome_name, - treatment_name, + outcome_nodes, + action_nodes, backdoor_paths=None, new_graph=bdoor_graph2, dseparation_algo=dseparation_algo, @@ -803,19 +842,20 @@ def identify_frontdoor( return parse_state(frontdoor_var) -def identify_mediation(graph: CausalGraph, treatment_name: List[str], outcome_name: List[str]): +def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str]): """Find a valid mediator if it exists. Currently only supports a single variable mediator set. """ mediation_var = None - mediation_paths = graph.get_all_directed_paths(treatment_name, outcome_name) - eligible_variables = graph.get_descendants(treatment_name) - set(outcome_name) + mediation_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes) + eligible_variables = get_descendants(graph, action_nodes) - set(outcome_nodes) # For simplicity, assuming a one-variable mediation set for candidate_var in eligible_variables: - is_valid_mediation = graph.check_valid_mediation_set( - treatment_name, - outcome_name, + is_valid_mediation = check_valid_mediation_set( + graph, + action_nodes, + outcome_nodes, parse_state(candidate_var), mediation_paths=mediation_paths, ) @@ -827,48 +867,50 @@ def identify_mediation(graph: CausalGraph, treatment_name: List[str], outcome_na def identify_mediation_first_stage_confounders( - graph: CausalGraph, - treatment_name: List[str], - outcome_name: List[str], - mediators_names: List[str], + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + mediator_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, ): # Create estimands dict as per the API for backdoor, but do not return it estimands_dict = {} - backdoor_sets = identify_backdoor(graph, treatment_name, mediators_names, backdoor_adjustment) + backdoor_sets = identify_backdoor(graph, action_nodes, mediator_nodes, observed_nodes, backdoor_adjustment) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, - treatment_name, - mediators_names, + action_nodes, + mediator_nodes, + observed_nodes, backdoor_sets, estimands_dict, ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) estimands_dict["backdoor"] = estimands_dict.get(str(default_backdoor_id), None) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) return backdoor_variables_dict def identify_mediation_second_stage_confounders( - graph: CausalGraph, - treatment_name: List[str], - mediators_names: List[str], - outcome_name: List[str], + graph: nx.DiGraph, + action_nodes: List[str], + mediator_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], backdoor_adjustment: BackdoorAdjustment, ): # Create estimands dict as per the API for backdoor, but do not return it estimands_dict = {} - backdoor_sets = identify_backdoor(graph, mediators_names, outcome_name, backdoor_adjustment) + backdoor_sets = identify_backdoor(graph, mediator_nodes, outcome_nodes, observed_nodes, backdoor_adjustment) estimands_dict, backdoor_variables_dict = build_backdoor_estimands_dict( - graph, - mediators_names, - outcome_name, + mediator_nodes, + outcome_nodes, + observed_nodes, backdoor_sets, estimands_dict, ) # Setting default "backdoor" identification adjustment set - default_backdoor_id = get_default_backdoor_set_id(graph, treatment_name, outcome_name, backdoor_variables_dict) + default_backdoor_id = get_default_backdoor_set_id(graph, action_nodes, outcome_nodes, backdoor_variables_dict) estimands_dict["backdoor"] = estimands_dict.get(str(default_backdoor_id), None) backdoor_variables_dict["backdoor"] = backdoor_variables_dict.get(str(default_backdoor_id), None) return backdoor_variables_dict @@ -966,7 +1008,7 @@ def construct_frontdoor_estimand( def construct_mediation_estimand( - estimand_type: EstimandType, treatment_name: List[str], outcome_name: List[str], mediators_names: List[str] + estimand_type: EstimandType, action_nodes: List[str], outcome_nodes: List[str], mediator_nodes: List[str] ): # TODO: support multivariate treatments better. expr = None @@ -974,18 +1016,18 @@ def construct_mediation_estimand( EstimandType.NONPARAMETRIC_NDE, EstimandType.NONPARAMETRIC_NIE, ): - outcome_name = outcome_name[0] - sym_outcome = spstats.Normal(outcome_name, 0, 1) - sym_treatment_symbols = [spstats.Normal(t, 0, 1) for t in treatment_name] + outcome_nodes = outcome_nodes[0] + sym_outcome = spstats.Normal(outcome_nodes, 0, 1) + sym_treatment_symbols = [spstats.Normal(t, 0, 1) for t in action_nodes] sym_treatment = sp.Array(sym_treatment_symbols) - sym_mediators_symbols = [sp.Symbol(inst) for inst in mediators_names] + sym_mediators_symbols = [sp.Symbol(inst) for inst in mediator_nodes] sym_mediators = sp.Array(sym_mediators_symbols) sym_outcome_derivative = sp.Derivative(sym_outcome, sym_mediators) sym_treatment_derivative = sp.Derivative(sym_mediators, sym_treatment) # For direct effect - num_expr_str = outcome_name - if len(mediators_names) > 0: - num_expr_str += "|" + ",".join(mediators_names) + num_expr_str = outcome_nodes + if len(mediator_nodes) > 0: + num_expr_str += "|" + ",".join(mediator_nodes) sym_mu = sp.Symbol("mu") sym_sigma = sp.Symbol("sigma", positive=True) sym_conditional_outcome = spstats.Normal(num_expr_str, sym_mu, sym_sigma) @@ -998,17 +1040,17 @@ def construct_mediation_estimand( "Mediation": ( "{2} intercepts (blocks) all directed paths from {0} to {1} except the path {{{0}}}\N{RIGHTWARDS ARROW}{{{1}}}." ).format( - ",".join(treatment_name), - ",".join(outcome_name), - ",".join(mediators_names), + ",".join(action_nodes), + ",".join(outcome_nodes), + ",".join(mediator_nodes), ), "First-stage-unconfoundedness": ( "If U\N{RIGHTWARDS ARROW}{{{0}}} and U\N{RIGHTWARDS ARROW}{{{1}}}" " then P({1}|{0},U) = P({1}|{0})" - ).format(",".join(treatment_name), ",".join(mediators_names)), + ).format(",".join(action_nodes), ",".join(mediator_nodes)), "Second-stage-unconfoundedness": ( "If U\N{RIGHTWARDS ARROW}{{{2}}} and U\N{RIGHTWARDS ARROW}{1}" " then P({1}|{2}, {0}, U) = P({1}|{2}, {0})" - ).format(",".join(treatment_name), outcome_name, ",".join(mediators_names)), + ).format(",".join(action_nodes), outcome_nodes, ",".join(mediator_nodes)), } else: raise ValueError( diff --git a/dowhy/causal_identifier/efficient_backdoor.py b/dowhy/causal_identifier/efficient_backdoor.py index b9deffebdb..9742413f06 100644 --- a/dowhy/causal_identifier/efficient_backdoor.py +++ b/dowhy/causal_identifier/efficient_backdoor.py @@ -1,3 +1,5 @@ +from typing import List + import networkx as nx import numpy as np @@ -10,9 +12,17 @@ class EfficientBackdoor: Implements methods for finding optimal (efficient) backdoor sets. """ - def __init__(self, graph, conditional_node_names=None, costs=None): + def __init__( + self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], + conditional_node_names=None, + costs=None, + ): """ - :param graph: CausalGraph + :param graph: nx.DiGraph A causal graph. :param costs: list A list with non-negative costs associated with variables in the graph. Only used @@ -25,20 +35,18 @@ def __init__(self, graph, conditional_node_names=None, costs=None): provided, it is assumed that the intervention sets the treatment to a constant. """ assert ( - len(graph.treatment_name) == 1 + len(action_nodes) == 1 ), "The methods for computing efficient backdoor sets are only valid for one dimensional treatments" assert ( - len(graph.outcome_name) == 1 + len(outcome_nodes) == 1 ), "The methods for computing efficient backdoor sets are only valid for one dimensional outcomes" self.graph = graph if costs is None: # If no costs are passed, use uniform costs - costs = [(node, {"cost": 1}) for node in self.graph._graph.nodes] + costs = [(node, {"cost": 1}) for node in self.graph.nodes] assert all([tup["cost"] > 0 for _, tup in costs]), "All costs must be positive" - self.graph._graph.add_nodes_from(costs) - self.observed_nodes = set( - [node for node in self.graph._graph.nodes if self.graph._graph.nodes[node]["observed"] == "yes"] - ) + self.graph.add_nodes_from(costs) + self.observed_nodes = set([node for node in self.graph.nodes if node in set(observed_nodes)]) if conditional_node_names is None: conditional_node_names = [] assert set(conditional_node_names).issubset( @@ -46,6 +54,9 @@ def __init__(self, graph, conditional_node_names=None, costs=None): ), "Some conditional variables are not marked as observed" self.conditional_node_names = conditional_node_names + self.treatment_name = action_nodes[0] + self.outcome_name = outcome_nodes[0] + def ancestors_all(self, nodes): """Method to compute the set of all ancestors of a set of nodes. A node is always an ancestor of itself. @@ -59,7 +70,7 @@ def ancestors_all(self, nodes): ancestors = set() for node in nodes: - ancestors_node = nx.ancestors(self.graph._graph, node) + ancestors_node = nx.ancestors(self.graph, node) ancestors = ancestors.union(ancestors_node) ancestors = ancestors.union(set(nodes)) @@ -77,7 +88,7 @@ def backdoor_graph(self, G): """ Gbd = G.copy() - for path in nx.all_simple_edge_paths(G, self.graph.treatment_name[0], self.graph.outcome_name[0]): + for path in nx.all_simple_edge_paths(G, self.treatment_name, self.outcome_name): first_edge = path[0] Gbd.remove_edge(first_edge[0], first_edge[1]) @@ -94,16 +105,16 @@ def causal_vertices(self): causal_vertices = set() causal_paths = list( nx.all_simple_paths( - self.graph._graph, - source=self.graph.treatment_name[0], - target=self.graph.outcome_name[0], + self.graph, + source=self.treatment_name, + target=self.outcome_name, ) ) for path in causal_paths: causal_vertices = causal_vertices.union(set(path)) - causal_vertices.remove(self.graph.treatment_name[0]) + causal_vertices.remove(self.treatment_name) return causal_vertices @@ -117,9 +128,9 @@ def forbidden(self): forbidden = set() for node in self.causal_vertices(): - forbidden = forbidden.union(nx.descendants(self.graph._graph, node).union({node})) + forbidden = forbidden.union(nx.descendants(self.graph, node).union({node})) - return forbidden.union({self.graph.treatment_name[0]}) + return forbidden.union({self.treatment_name}) def ignore(self): """Method to compute the set of ignorable vertices with respect to @@ -131,13 +142,11 @@ def ignore(self): :returns ignore: set The set of ignorable vertices. """ - set1 = set( - self.ancestors_all(self.conditional_node_names + [self.graph.treatment_name[0], self.graph.outcome_name[0]]) - ) - set1.remove(self.graph.treatment_name[0]) - set1.remove(self.graph.outcome_name[0]) + set1 = set(self.ancestors_all(self.conditional_node_names + [self.treatment_name, self.outcome_name])) + set1.remove(self.treatment_name) + set1.remove(self.outcome_name) - set2 = set(self.graph._graph.nodes()) - self.observed_nodes + set2 = set(self.graph.nodes()) - self.observed_nodes set2 = set2.union(self.forbidden()) ignore = set1.intersection(set2) @@ -159,7 +168,7 @@ def unblocked(self, H, Z): G2 = H.subgraph(H.nodes() - set(Z)) - B = nx.node_connected_component(G2, self.graph.treatment_name[0]) + B = nx.node_connected_component(G2, self.treatment_name) unblocked = set(nx.node_boundary(H, B)) return unblocked @@ -173,10 +182,8 @@ def build_H0(self): The H0 graph. """ # restriction to ancestors - anc = self.ancestors_all( - self.conditional_node_names + [self.graph.treatment_name[0], self.graph.outcome_name[0]] - ) - G2 = self.graph._graph.subgraph(anc) + anc = self.ancestors_all(self.conditional_node_names + [self.treatment_name, self.outcome_name]) + G2 = self.graph.subgraph(anc) # back-door graph G3 = self.backdoor_graph(G2) @@ -210,8 +217,8 @@ def build_H1(self): break for node in self.conditional_node_names: - H1.add_edge(self.graph.treatment_name[0], node) - H1.add_edge(node, self.graph.outcome_name[0]) + H1.add_edge(self.treatment_name, node) + H1.add_edge(node, self.outcome_name) return H1 @@ -251,10 +258,10 @@ def compute_smallest_mincut(self): D = self.build_D() _, flow_dict = nx.algorithms.flow.maximum_flow( flowG=D, - _s=self.graph.outcome_name[0] + "''", - _t=self.graph.treatment_name[0] + "'", + _s=self.outcome_name + "''", + _t=self.treatment_name + "'", ) - queu = [self.graph.outcome_name[0] + "''"] + queu = [self.outcome_name + "''"] S_c = set() visited = set() while len(queu) > 0: @@ -290,7 +297,7 @@ def h_operator(self, S): The set obtained from applying the h operator to S. """ Z = set() - for node in self.graph._graph.nodes: + for node in self.graph.nodes: nodep = node + "'" nodepp = node + "''" condition = nodep in S and nodepp not in S @@ -310,12 +317,12 @@ def optimal_adj_set(self): The optimal adjustment set. """ H1 = self.build_H1() - if self.graph.treatment_name[0] in H1.neighbors(self.graph.outcome_name[0]): + if self.treatment_name in H1.neighbors(self.outcome_name): raise ValueError(EXCEPTION_NO_ADJ) - elif self.observed_nodes == set(self.graph._graph.nodes()) or self.observed_nodes.issubset( - self.ancestors_all(self.conditional_node_names + [self.graph.treatment_name[0], self.graph.outcome_name[0]]) + elif self.observed_nodes == set(self.graph.nodes()) or self.observed_nodes.issubset( + self.ancestors_all(self.conditional_node_names + [self.treatment_name, self.outcome_name]) ): - optimal = nx.node_boundary(H1, {self.graph.outcome_name[0]}) + optimal = nx.node_boundary(H1, {self.outcome_name}) return optimal else: raise ValueError(EXCEPTION_COND_NO_OPT) @@ -330,10 +337,10 @@ def optimal_minimal_adj_set(self): H1 = self.build_H1() - if self.graph.treatment_name[0] in H1.neighbors(self.graph.outcome_name[0]): + if self.treatment_name in H1.neighbors(self.outcome_name): raise ValueError(EXCEPTION_NO_ADJ) else: - optimal_minimal = self.unblocked(H1, nx.node_boundary(H1, [self.graph.outcome_name[0]])) + optimal_minimal = self.unblocked(H1, nx.node_boundary(H1, [self.outcome_name])) return optimal_minimal def optimal_mincost_adj_set(self): @@ -347,7 +354,7 @@ def optimal_mincost_adj_set(self): The optimal minimum cost adjustment set. """ H1 = self.build_H1() - if self.graph.treatment_name[0] in H1.neighbors(self.graph.outcome_name[0]): + if self.treatment_name in H1.neighbors(self.outcome_name): raise ValueError(EXCEPTION_NO_ADJ) else: S_c = self.compute_smallest_mincut() diff --git a/dowhy/causal_identifier/id_identifier.py b/dowhy/causal_identifier/id_identifier.py index 2b1d1bfa9e..db8af1bff2 100644 --- a/dowhy/causal_identifier/id_identifier.py +++ b/dowhy/causal_identifier/id_identifier.py @@ -3,7 +3,7 @@ import networkx as nx import numpy as np -from dowhy.causal_graph import CausalGraph +from dowhy.graph import get_adjacency_matrix from dowhy.utils.api import parse_state from dowhy.utils.graph_operations import find_ancestor, find_c_components, induced_graph from dowhy.utils.ordered_set import OrderedSet @@ -96,21 +96,18 @@ class IDIdentifier: def identify_effect( self, - graph: CausalGraph, - treatment_name: Union[str, List[str]], - outcome_name: Union[str, List[str]], - node_names: Optional[Union[str, List[str]]] = None, - **kwargs, + graph: nx.DiGraph, + action_nodes: Union[str, List[str]], + outcome_nodes: Union[str, List[str]], + observed_nodes: Union[str, List[str]], ): - return identify_effect_id(graph, treatment_name, outcome_name, node_names, **kwargs) + return identify_effect_id(graph, action_nodes, outcome_nodes) def identify_effect_id( - graph: CausalGraph, - treatment_name: Union[str, List[str]], - outcome_name: Union[str, List[str]], - node_names: Optional[Union[str, List[str]]] = None, - **kwargs, + graph: nx.DiGraph, + action_nodes: Union[str, List[str]], + outcome_nodes: Union[str, List[str]], ) -> IDExpression: """ Implementation of the ID algorithm. @@ -119,25 +116,22 @@ def identify_effect_id( :param treatment_names: OrderedSet comprising names of treatment variables. :param outcome_names:OrderedSet comprising names of outcome variables. - :param node_names: OrderedSet comprising names of all nodes in the graph :returns: target estimand, an instance of the IDExpression class. """ + node_names = OrderedSet(graph.nodes) - if node_names is None: - node_names = OrderedSet(graph._graph.nodes) - - adjacency_matrix = graph.get_adjacency_matrix() + adjacency_matrix = np.asmatrix(get_adjacency_matrix(graph)) try: - tsort_node_names = OrderedSet(list(nx.topological_sort(graph._graph))) # topological sorting of graph nodes + tsort_node_names = OrderedSet(list(nx.topological_sort(graph))) # topological sorting of graph nodes except: raise Exception("The graph must be a directed acyclic graph (DAG).") return __adjacency_matrix_identify_effect( adjacency_matrix, - OrderedSet(parse_state(treatment_name)), - OrderedSet(parse_state(outcome_name)), + OrderedSet(parse_state(action_nodes)), + OrderedSet(parse_state(outcome_nodes)), tsort_node_names, node_names, ) diff --git a/dowhy/causal_identifier/identify_effect.py b/dowhy/causal_identifier/identify_effect.py index d4bb9b8580..5152435f5f 100755 --- a/dowhy/causal_identifier/identify_effect.py +++ b/dowhy/causal_identifier/identify_effect.py @@ -1,6 +1,7 @@ from typing import List, Protocol, Union -from dowhy.causal_graph import CausalGraph +import networkx as nx + from dowhy.causal_identifier.auto_identifier import BackdoorAdjustment, EstimandType, identify_effect_auto from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand @@ -15,12 +16,12 @@ class CausalIdentifier(Protocol): """ def identify_effect( - self, graph: CausalGraph, treatment_name: Union[str, List[str]], outcome_name: Union[str, List[str]], **kwargs + self, graph: nx.DiGraph, action_nodes: Union[str, List[str]], outcome_nodes: Union[str, List[str]], **kwargs ): - """Identify the causal effect to be estimated based on a CausalGraph - :param graph: CausalGraph to be analyzed - :param treatment_name: name of the treatment - :param outcome_name: name of the outcome + """Identify the causal effect to be estimated based on a causal graph + :param graph: Causal graph to be analyzed + :param action_nodes: name of the treatment + :param outcome_nodes: name of the outcome :param **kwargs: Additional parameters required by the identify_effect of a specific CausalIdentifier for example: conditional_node_names in AutoIdentifier or node_names in IDIdentifier :returns: a probability expression (estimand) for the causal effect if identified, else NULL @@ -29,23 +30,24 @@ def identify_effect( def identify_effect( - graph: CausalGraph, - treatment_name: Union[str, List[str]], - outcome_name: Union[str, List[str]], + graph: nx.DiGraph, + action_nodes: Union[str, List[str]], + outcome_nodes: Union[str, List[str]], + observed_nodes: Union[str, List[str]], ) -> IdentifiedEstimand: - """Identify the causal effect to be estimated based on a CausalGraph + """Identify the causal effect to be estimated based on a causal graph - :param graph: CausalGraph to be analyzed + :param graph: Causal graph to be analyzed :param treatment: name of the treatment :param outcome: name of the outcome :returns: a probability expression (estimand) for the causal effect if identified, else NULL """ return identify_effect_auto( graph, - treatment_name, - outcome_name, + action_nodes, + outcome_nodes, + observed_nodes, EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_DEFAULT, - proceed_when_unidentifiable=True, optimize_backdoor=False, ) diff --git a/dowhy/causal_model.py b/dowhy/causal_model.py index b5a22c7c06..277cd5fec7 100755 --- a/dowhy/causal_model.py +++ b/dowhy/causal_model.py @@ -11,7 +11,6 @@ import dowhy.causal_estimators as causal_estimators import dowhy.causal_refuters as causal_refuters import dowhy.graph_learners as graph_learners -import dowhy.utils.cli_helpers as cli from dowhy.causal_estimator import CausalEstimate, estimate_effect from dowhy.causal_graph import CausalGraph from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, IDIdentifier @@ -216,12 +215,14 @@ def identify_effect( identifier = AutoIdentifier( estimand_type=estimand_type, backdoor_adjustment=BackdoorAdjustment(method_name), - proceed_when_unidentifiable=proceed_when_unidentifiable, optimize_backdoor=optimize_backdoor, ) identified_estimand = identifier.identify_effect( - graph=self._graph, treatment_name=self._treatment, outcome_name=self._outcome + graph=self._graph._graph, + action_nodes=self._treatment, + outcome_nodes=self._outcome, + observed_nodes=list(self._graph.get_all_nodes(include_unobserved=False)), ) self.identifier = identifier diff --git a/dowhy/causal_prediction/algorithms/cacm.py b/dowhy/causal_prediction/algorithms/cacm.py index d985bd3158..f08c26e1b7 100644 --- a/dowhy/causal_prediction/algorithms/cacm.py +++ b/dowhy/causal_prediction/algorithms/cacm.py @@ -1,5 +1,4 @@ import torch -from torch import nn from torch.nn import functional as F from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm diff --git a/dowhy/causal_prediction/algorithms/erm.py b/dowhy/causal_prediction/algorithms/erm.py index aa394f0ba5..78188068ed 100644 --- a/dowhy/causal_prediction/algorithms/erm.py +++ b/dowhy/causal_prediction/algorithms/erm.py @@ -1,5 +1,4 @@ import torch -from torch import nn from torch.nn import functional as F from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm diff --git a/dowhy/causal_prediction/dataloaders/misc.py b/dowhy/causal_prediction/dataloaders/misc.py index caf08b5046..62679d8d74 100644 --- a/dowhy/causal_prediction/dataloaders/misc.py +++ b/dowhy/causal_prediction/dataloaders/misc.py @@ -4,7 +4,7 @@ """ import hashlib -from collections import Counter, OrderedDict +from collections import Counter import numpy as np import torch diff --git a/dowhy/causal_prediction/datasets/mnist.py b/dowhy/causal_prediction/datasets/mnist.py index 3fe99b8346..2ee67b2ba3 100644 --- a/dowhy/causal_prediction/datasets/mnist.py +++ b/dowhy/causal_prediction/datasets/mnist.py @@ -3,9 +3,9 @@ import torch import torchvision from PIL import Image -from torch.utils.data import Subset, TensorDataset +from torch.utils.data import TensorDataset from torchvision import transforms -from torchvision.datasets import MNIST, ImageFolder +from torchvision.datasets import MNIST from torchvision.transforms.functional import rotate from dowhy.causal_prediction.datasets.base_dataset import MultipleDomainDataset diff --git a/dowhy/causal_refuters/bootstrap_refuter.py b/dowhy/causal_refuters/bootstrap_refuter.py index d678053b05..eb23bfeba2 100644 --- a/dowhy/causal_refuters/bootstrap_refuter.py +++ b/dowhy/causal_refuters/bootstrap_refuter.py @@ -8,7 +8,7 @@ from sklearn.utils import resample from tqdm.auto import tqdm -from dowhy.causal_estimator import CausalEstimate, CausalEstimator +from dowhy.causal_estimator import CausalEstimate from dowhy.causal_estimators.econml import Econml from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand from dowhy.causal_refuter import CausalRefutation, CausalRefuter, choose_variables, test_significance diff --git a/dowhy/causal_refuters/data_subset_refuter.py b/dowhy/causal_refuters/data_subset_refuter.py index fa15da5c87..759c767546 100755 --- a/dowhy/causal_refuters/data_subset_refuter.py +++ b/dowhy/causal_refuters/data_subset_refuter.py @@ -6,7 +6,7 @@ from joblib import Parallel, delayed from tqdm.auto import tqdm -from dowhy.causal_estimator import CausalEstimate, CausalEstimator +from dowhy.causal_estimator import CausalEstimate from dowhy.causal_estimators.econml import Econml from dowhy.causal_identifier import IdentifiedEstimand from dowhy.causal_refuter import CausalRefutation, CausalRefuter, test_significance diff --git a/dowhy/causal_refuters/dummy_outcome_refuter.py b/dowhy/causal_refuters/dummy_outcome_refuter.py index a4a529b2ed..70c6b23fbd 100644 --- a/dowhy/causal_refuters/dummy_outcome_refuter.py +++ b/dowhy/causal_refuters/dummy_outcome_refuter.py @@ -12,7 +12,7 @@ from sklearn.svm import SVR from tqdm.auto import tqdm -from dowhy.causal_estimator import CausalEstimate, CausalEstimator +from dowhy.causal_estimator import CausalEstimate from dowhy.causal_estimators.econml import Econml from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand from dowhy.causal_refuter import CausalRefutation, CausalRefuter, choose_variables, test_significance diff --git a/dowhy/causal_refuters/evalue_sensitivity_analyzer.py b/dowhy/causal_refuters/evalue_sensitivity_analyzer.py index 9b39b092cd..110303d53a 100644 --- a/dowhy/causal_refuters/evalue_sensitivity_analyzer.py +++ b/dowhy/causal_refuters/evalue_sensitivity_analyzer.py @@ -6,7 +6,7 @@ import pandas as pd import statsmodels.api as sm -from dowhy.causal_estimator import CausalEstimate, CausalEstimator +from dowhy.causal_estimator import CausalEstimate from dowhy.causal_estimators.econml import Econml from dowhy.causal_estimators.generalized_linear_model_estimator import GeneralizedLinearModelEstimator from dowhy.causal_estimators.linear_regression_estimator import LinearRegressionEstimator diff --git a/dowhy/causal_refuters/linear_sensitivity_analyzer.py b/dowhy/causal_refuters/linear_sensitivity_analyzer.py index d5c95365de..b465eec630 100644 --- a/dowhy/causal_refuters/linear_sensitivity_analyzer.py +++ b/dowhy/causal_refuters/linear_sensitivity_analyzer.py @@ -1,5 +1,4 @@ import logging -import sys import matplotlib.pyplot as plt import numpy as np diff --git a/dowhy/causal_refuters/random_common_cause.py b/dowhy/causal_refuters/random_common_cause.py index dc8b88c715..f506772645 100755 --- a/dowhy/causal_refuters/random_common_cause.py +++ b/dowhy/causal_refuters/random_common_cause.py @@ -7,7 +7,7 @@ from joblib import Parallel, delayed from tqdm.auto import tqdm -from dowhy.causal_estimator import CausalEstimate, CausalEstimator +from dowhy.causal_estimator import CausalEstimate from dowhy.causal_estimators.econml import Econml from dowhy.causal_identifier.identified_estimand import IdentifiedEstimand from dowhy.causal_refuter import CausalRefutation, CausalRefuter, test_significance diff --git a/dowhy/causal_refuters/refute_estimate.py b/dowhy/causal_refuters/refute_estimate.py index fc9ac71b30..9db8427227 100644 --- a/dowhy/causal_refuters/refute_estimate.py +++ b/dowhy/causal_refuters/refute_estimate.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional, Union +from typing import Callable, List, Optional, Union import pandas as pd diff --git a/dowhy/do_sampler.py b/dowhy/do_sampler.py index b82a4fc668..1de49f1f52 100644 --- a/dowhy/do_sampler.py +++ b/dowhy/do_sampler.py @@ -1,8 +1,11 @@ import logging +from typing import List +import networkx as nx import numpy as np import pandas as pd +from dowhy import EstimandType, identify_effect_auto from dowhy.utils.api import parse_state @@ -10,7 +13,17 @@ class DoSampler: """Base class for a sampler from the interventional distribution.""" def __init__( - self, data, params=None, variable_types=None, num_cores=1, causal_model=None, keep_original_treatment=False + self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], + data, + params=None, + variable_types=None, + num_cores=1, + keep_original_treatment=False, + estimand_type=EstimandType.NONPARAMETRIC_ATE, ): """ Initializes a do sampler with data and names of relevant variables. @@ -50,11 +63,12 @@ def __init__( """ self._data = data.copy() - self._causal_model = causal_model - self._target_estimand = self._causal_model.identify_effect() + self._target_estimand = identify_effect_auto( + graph, action_nodes, outcome_nodes, observed_nodes, estimand_type=estimand_type + ) self._target_estimand.set_identifier_method("backdoor") - self._treatment_names = parse_state(self._causal_model._treatment) - self._outcome_names = parse_state(self._causal_model._outcome) + self._treatment_names = parse_state(action_nodes) + self._outcome_names = parse_state(outcome_nodes) self._estimate = None self._variable_types = variable_types self.num_cores = num_cores @@ -71,6 +85,7 @@ def __init__( if not self._variable_types: self._infer_variable_types() self.dep_type = [self._variable_types[var] for var in self._outcome_names] + self.indep_type = [ self._variable_types[var] for var in self._treatment_names + self._target_estimand.get_backdoor_variables() ] diff --git a/dowhy/do_samplers/kernel_density_sampler.py b/dowhy/do_samplers/kernel_density_sampler.py index 09793184b9..ad2f2eef48 100644 --- a/dowhy/do_samplers/kernel_density_sampler.py +++ b/dowhy/do_samplers/kernel_density_sampler.py @@ -1,6 +1,6 @@ import numpy as np -from scipy.interpolate import LinearNDInterpolator, interp1d -from statsmodels.nonparametric.kernel_density import EstimatorSettings, KDEMultivariate, KDEMultivariateConditional +from scipy.interpolate import interp1d +from statsmodels.nonparametric.kernel_density import EstimatorSettings, KDEMultivariateConditional from dowhy.do_sampler import DoSampler diff --git a/dowhy/do_samplers/mcmc_sampler.py b/dowhy/do_samplers/mcmc_sampler.py index bd7c5e8057..1dfb77b4e6 100644 --- a/dowhy/do_samplers/mcmc_sampler.py +++ b/dowhy/do_samplers/mcmc_sampler.py @@ -1,40 +1,49 @@ +from typing import List + import networkx as nx import numpy as np import pymc3 as pm +from dowhy import EstimandType from dowhy.do_sampler import DoSampler class McmcSampler(DoSampler): def __init__( self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], data, - *args, params=None, variable_types=None, num_cores=1, keep_original_treatment=False, - causal_model=None, - **kwargs, + estimand_type=EstimandType.NONPARAMETRIC_ATE, ): """ g, df, data_types """ super().__init__( - data, + graph=graph, + action_nodes=action_nodes, + outcome_nodes=outcome_nodes, + observed_nodes=observed_nodes, + data=data, params=params, variable_types=variable_types, - causal_model=causal_model, num_cores=num_cores, keep_original_treatment=keep_original_treatment, + estimand_type=estimand_type, ) self.logger.info("Using McmcSampler for do sampling.") self.point_sampler = False self.sampler = self._construct_sampler() - self.g = causal_model._graph.get_unconfounded_observed_subgraph() + self.g = graph.subgraph(observed_nodes) g_fit = nx.DiGraph(self.g) _, self.fit_trace = self.fit_causal_model(g_fit, self._data, self._variable_types) diff --git a/dowhy/do_samplers/multivariate_weighting_sampler.py b/dowhy/do_samplers/multivariate_weighting_sampler.py index 65ca429fc0..56a41cef0f 100644 --- a/dowhy/do_samplers/multivariate_weighting_sampler.py +++ b/dowhy/do_samplers/multivariate_weighting_sampler.py @@ -1,3 +1,8 @@ +from typing import List + +import networkx as nx + +from dowhy import EstimandType from dowhy.do_sampler import DoSampler from dowhy.utils.propensity_score import state_propensity_score @@ -5,26 +10,32 @@ class MultivariateWeightingSampler(DoSampler): def __init__( self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], data, - *args, params=None, variable_types=None, num_cores=1, keep_original_treatment=False, - causal_model=None, - **kwargs, + estimand_type=EstimandType.NONPARAMETRIC_ATE, ): """ g, df, data_types """ super().__init__( - data, + graph=graph, + action_nodes=action_nodes, + outcome_nodes=outcome_nodes, + observed_nodes=observed_nodes, + data=data, params=params, variable_types=variable_types, num_cores=num_cores, keep_original_treatment=keep_original_treatment, - causal_model=causal_model, + estimand_type=estimand_type, ) self.logger.info("Using MultivariateWeightingSampler for do sampling.") diff --git a/dowhy/do_samplers/weighting_sampler.py b/dowhy/do_samplers/weighting_sampler.py index 8e24267063..626efa10c5 100644 --- a/dowhy/do_samplers/weighting_sampler.py +++ b/dowhy/do_samplers/weighting_sampler.py @@ -1,32 +1,41 @@ -import numpy as np +from typing import List +import networkx as nx + +from dowhy import EstimandType from dowhy.do_sampler import DoSampler -from dowhy.utils.propensity_score import propensity_of_treatment_score, state_propensity_score +from dowhy.utils.propensity_score import state_propensity_score class WeightingSampler(DoSampler): def __init__( self, + graph: nx.DiGraph, + action_nodes: List[str], + outcome_nodes: List[str], + observed_nodes: List[str], data, - *args, params=None, variable_types=None, num_cores=1, keep_original_treatment=False, - causal_model=None, - **kwargs, + estimand_type=EstimandType.NONPARAMETRIC_ATE, ): """ g, df, data_types """ super().__init__( - data, + graph=graph, + action_nodes=action_nodes, + outcome_nodes=outcome_nodes, + observed_nodes=observed_nodes, + data=data, params=params, variable_types=variable_types, num_cores=num_cores, keep_original_treatment=keep_original_treatment, - causal_model=causal_model, + estimand_type=estimand_type, ) self.logger.info("Using WeightingSampler for do sampling.") diff --git a/dowhy/gcm/causal_models.py b/dowhy/gcm/causal_models.py index 2e0d366544..0e93130025 100644 --- a/dowhy/gcm/causal_models.py +++ b/dowhy/gcm/causal_models.py @@ -103,21 +103,6 @@ def causal_mechanism(self, node: Any) -> Union[StochasticModel, InvertibleFuncti return super().causal_mechanism(node) -def validate_causal_dag(causal_graph: DirectedGraph) -> None: - validate_acyclic(causal_graph) - validate_causal_graph(causal_graph) - - -def validate_causal_graph(causal_graph: DirectedGraph) -> None: - for node in causal_graph.nodes: - validate_node(causal_graph, node) - - -def validate_node(causal_graph: DirectedGraph, node: Any) -> None: - validate_causal_model_assignment(causal_graph, node) - validate_local_structure(causal_graph, node) - - def validate_causal_model_assignment(causal_graph: DirectedGraph, target_node: Any) -> None: validate_node_has_causal_model(causal_graph, target_node) @@ -136,6 +121,28 @@ def validate_causal_model_assignment(causal_graph: DirectedGraph, target_node: A ) +def validate_node_has_causal_model(causal_graph: HasNodes, node: Any) -> None: + validate_node_in_graph(causal_graph, node) + + if CAUSAL_MECHANISM not in causal_graph.nodes[node]: + raise ValueError("Node %s has no assigned causal mechanism!" % node) + + +def validate_causal_dag(causal_graph: DirectedGraph) -> None: + validate_acyclic(causal_graph) + validate_causal_graph(causal_graph) + + +def validate_causal_graph(causal_graph: DirectedGraph) -> None: + for node in causal_graph.nodes: + validate_node(causal_graph, node) + + +def validate_node(causal_graph: DirectedGraph, node: Any) -> None: + validate_causal_model_assignment(causal_graph, node) + validate_local_structure(causal_graph, node) + + def validate_local_structure(causal_graph: DirectedGraph, node: Any) -> None: if PARENTS_DURING_FIT not in causal_graph.nodes[node] or causal_graph.nodes[node][ PARENTS_DURING_FIT @@ -147,13 +154,6 @@ def validate_local_structure(causal_graph: DirectedGraph, node: Any) -> None: ) -def validate_node_has_causal_model(causal_graph: HasNodes, node: Any) -> None: - validate_node_in_graph(causal_graph, node) - - if CAUSAL_MECHANISM not in causal_graph.nodes[node]: - raise ValueError("Node %s has no assigned causal mechanism!" % node) - - def clone_causal_models(source: HasNodes, destination: HasNodes): for node in destination.nodes: if CAUSAL_MECHANISM in source.nodes[node]: diff --git a/dowhy/graph.py b/dowhy/graph.py index 7f505ea683..f739ffc069 100644 --- a/dowhy/graph.py +++ b/dowhy/graph.py @@ -1,15 +1,18 @@ -"""This module defines the fundamental interfaces and functions related to causal graphs.. - -Classes and functions in this module should be considered experimental, meaning there might be breaking API changes in -the future. -""" - +"""This module defines the fundamental interfaces and functions related to causal graphs.""" +import itertools +import logging +import re from abc import abstractmethod from typing import Any, List, Protocol import networkx as nx from networkx.algorithms.dag import has_cycle +from dowhy.utils.api import parse_state +from dowhy.utils.graph_operations import daggity_to_dot + +_logger = logging.getLogger(__name__) + class HasNodes(Protocol): """This protocol defines a trait for classes having nodes.""" @@ -72,3 +75,355 @@ def validate_acyclic(causal_graph: DirectedGraph) -> None: def validate_node_in_graph(causal_graph: HasNodes, node: Any) -> None: if node not in causal_graph.nodes: raise ValueError("Node %s can not be found in the given graph!" % node) + + +def check_valid_backdoor_set( + graph: nx.DiGraph, + nodes1, + nodes2, + nodes3, + backdoor_paths=None, + new_graph: nx.DiGraph = None, + dseparation_algo="default", +): + """Assume that the first parameter (nodes1) is the treatment, + the second is the outcome, and the third is the candidate backdoor set + """ + # also return the number of backdoor paths blocked by observed nodes + if dseparation_algo == "default": + if new_graph is None: + # Assume that nodes1 is the treatment + new_graph = do_surgery(graph, nodes1, remove_outgoing_edges=True) + dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(nodes3)) + elif dseparation_algo == "naive": + # ignores new_graph parameter, always uses self._graph + if backdoor_paths is None: + backdoor_paths = get_backdoor_paths(graph, nodes1, nodes2) + dseparated = all([is_blocked(graph, path, nodes3) for path in backdoor_paths]) + else: + raise ValueError(f"{dseparation_algo} method for d-separation not supported.") + return {"is_dseparated": dseparated} + + +def do_surgery( + graph: nx.DiGraph, + node_names, + remove_outgoing_edges=False, + remove_incoming_edges=False, + target_node_names=None, + remove_only_direct_edges_to_target=False, +): + """Method to create a new graph based on the concept of do-surgery. + + :param node_names: focal nodes for the surgery + :param remove_outgoing_edges: whether to remove outgoing edges from the focal nodes + :param remove_incoming_edges: whether to remove incoming edges to the focal nodes + :param target_node_names: target nodes (optional) for the surgery, only used when remove_only_direct_edges_to_target is True + :param remove_only_direct_edges_to_target: whether to remove only the direct edges from focal nodes to the target nodes + + :returns: a new networkx graph after the specified removal of edges + """ + + node_names = parse_state(node_names) + new_graph = graph.copy() + for node_name in node_names: + if remove_outgoing_edges: + if remove_only_direct_edges_to_target: + new_graph.remove_edges_from([(node_name, v) for v in target_node_names]) + else: + children = new_graph.successors(node_name) + edges_bunch = [(node_name, child) for child in children] + new_graph.remove_edges_from(edges_bunch) + if remove_incoming_edges: + # removal of only direct edges wrt a target is not implemented for incoming edges + parents = new_graph.predecessors(node_name) + edges_bunch = [(parent, node_name) for parent in parents] + new_graph.remove_edges_from(edges_bunch) + return new_graph + + +def get_backdoor_paths(graph: nx.DiGraph, nodes1, nodes2): + paths = [] + undirected_graph = graph.to_undirected() + nodes12 = set(nodes1).union(nodes2) + for node1 in nodes1: + for node2 in nodes2: + backdoor_paths = [ + pth + for pth in nx.all_simple_paths(undirected_graph, source=node1, target=node2) + if graph.has_edge(pth[1], pth[0]) + ] + # remove paths that have nodes1\node1 or nodes2\node2 as intermediate nodes + filtered_backdoor_paths = [pth for pth in backdoor_paths if len(nodes12.intersection(pth[1:-1])) == 0] + paths.extend(filtered_backdoor_paths) + _logger.debug("Backdoor paths: " + str(paths)) + return paths + + +def is_blocked(graph: nx.DiGraph, path, conditioned_nodes): + """Uses d-separation criteria to decide if conditioned_nodes block given path.""" + + blocked_by_conditioning = False + has_unconditioned_collider = False + + for i in range(len(path) - 2): + if graph.has_edge(path[i], path[i + 1]) and graph.has_edge(path[i + 2], path[i + 1]): # collider + collider_descendants = nx.descendants(graph, path[i + 1]) + if path[i + 1] not in conditioned_nodes and all( + cdesc not in conditioned_nodes for cdesc in collider_descendants + ): + has_unconditioned_collider = True + else: # chain or fork + if path[i + 1] in conditioned_nodes: + blocked_by_conditioning = True + break + if blocked_by_conditioning: + return True + elif has_unconditioned_collider: + return True + else: + return False + + +def get_descendants(graph: nx.DiGraph, nodes): + descendants = set() + for node_name in nodes: + descendants = descendants.union(set(nx.descendants(graph, node_name))) + return descendants + + +def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"): + if dseparation_algo == "default": + if new_graph is None: + new_graph = graph + dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(nodes3)) + else: + raise ValueError(f"{dseparation_algo} method for d-separation not supported.") + return dseparated + + +def get_all_nodes(graph: nx.DiGraph, observed_nodes: List[Any], include_unobserved_nodes: bool) -> List[Any]: + observed_nodes = set(observed_nodes) + return [node for node in graph.nodes if include_unobserved_nodes or node in observed_nodes] + + +def get_instruments(graph: nx.DiGraph, treatment_nodes, outcome_nodes): + treatment_nodes = parse_state(treatment_nodes) + outcome_nodes = parse_state(outcome_nodes) + parents_treatment = set() + for node in treatment_nodes: + parents_treatment = parents_treatment.union(set(graph.predecessors(node))) + g_no_parents_treatment = do_surgery(graph, treatment_nodes, remove_incoming_edges=True) + ancestors_outcome = set() + for node in outcome_nodes: + ancestors_outcome = ancestors_outcome.union(nx.ancestors(g_no_parents_treatment, node)) + # [TODO: double check these work with multivariate implementation:] + # Exclusion + candidate_instruments = parents_treatment.difference(ancestors_outcome) + _logger.debug("Candidate instruments after satisfying exclusion: %s", candidate_instruments) + # As-if-random setup + children_causes_outcome = [nx.descendants(g_no_parents_treatment, v) for v in ancestors_outcome] + children_causes_outcome = set([item for sublist in children_causes_outcome for item in sublist]) + + # As-if-random + instruments = candidate_instruments.difference(children_causes_outcome) + _logger.debug("Candidate instruments after satisfying exclusion and as-if-random: %s", instruments) + return list(instruments) + + +def check_valid_frontdoor_set( + graph: nx.DiGraph, + nodes1, + nodes2, + candidate_nodes, + frontdoor_paths=None, + new_graph: nx.DiGraph = None, + dseparation_algo="default", +): + """Check if valid the frontdoor variables for set of treatments, nodes1 to set of outcomes, nodes2.""" + # Condition 1: node 1 ---> node 2 is intercepted by candidate_nodes + if dseparation_algo == "default": + if new_graph is None: + new_graph = graph + dseparated = nx.algorithms.d_separated(new_graph, set(nodes1), set(nodes2), set(candidate_nodes)) + elif dseparation_algo == "naive": + if frontdoor_paths is None: + frontdoor_paths = get_all_directed_paths(graph, nodes1, nodes2) + + dseparated = all([is_blocked(graph, path, candidate_nodes) for path in frontdoor_paths]) + else: + raise ValueError(f"{dseparation_algo} method for d-separation not supported.") + return dseparated + + +def get_all_directed_paths(graph: nx.DiGraph, nodes1, nodes2): + """Get all directed paths between sets of nodes. + + Currently only supports singleton sets. + """ + if len(nodes1) > 1 or len(nodes2) > 1: + raise ValueError( + "The list of action and outcome nodes can only contain one element, i.e., needs to be univariate!" + ) + return [p for p in nx.all_simple_paths(graph, source=nodes1[0], target=nodes2[0])] + + +def has_directed_path(graph: nx.DiGraph, nodes1, nodes2): + """Checks if there is any directed path between two sets of nodes. + + Currently only supports singleton sets. + """ + # dpaths = self.get_all_directed_paths(nodes1, nodes2) + # return len(dpaths) > 0 + return nx.has_path(graph, nodes1[0], nodes2[0]) + + +def check_valid_mediation_set(graph: nx.DiGraph, nodes1, nodes2, candidate_nodes, mediation_paths=None): + """Check if candidate nodes are valid mediators for set of treatments, nodes1 to set of outcomes, nodes2.""" + if mediation_paths is None: + mediation_paths = get_all_directed_paths(graph, nodes1, nodes2) + + is_mediator = any([is_blocked(graph, path, candidate_nodes) for path in mediation_paths]) + return is_mediator + + +def get_adjacency_matrix(graph: nx.DiGraph, *args, **kwargs): + """ + Get adjacency matrix from the networkx graph + + """ + return nx.convert_matrix.to_numpy_array(graph, *args, **kwargs) + + +def build_graph( + action_nodes: List[str], + outcome_nodes: List[str], + common_cause_nodes: List[str] = None, + instrument_nodes=None, + effect_modifier_nodes=None, + mediator_nodes=None, +): + """Creates nodes and edges based on variable names and their semantics. + + Currently only considers the graphical representation of "direct" effect modifiers. Thus, all effect modifiers are assumed to be "direct" unless otherwise expressed using a graph. Based on the taxonomy of effect modifiers by VanderWheele and Robins: "Four types of effect modification: A classification based on directed acyclic graphs. Epidemiology. 2007." + """ + graph = nx.DiGraph() + + action_nodes = parse_state(action_nodes) + outcome_nodes = parse_state(outcome_nodes) + common_cause_nodes = parse_state(common_cause_nodes) + instrument_nodes = parse_state(instrument_nodes) + effect_modifier_nodes = parse_state(effect_modifier_nodes) + + for treatment in action_nodes: + graph.add_node(treatment) + for outcome in outcome_nodes: + graph.add_node(outcome) + for treatment, outcome in itertools.product(action_nodes, outcome_nodes): + graph.add_edge(treatment, outcome) + + # Adding common causes + if common_cause_nodes: + for node_name in common_cause_nodes: + for treatment, outcome in itertools.product(action_nodes, outcome_nodes): + graph.add_node(node_name) + graph.add_edge(node_name, treatment) + graph.add_edge(node_name, outcome) + + # Adding instruments + if instrument_nodes: + if type(instrument_nodes[0]) != tuple: + if len(action_nodes) > 1: + _logger.info("Assuming Instrument points to all treatments! Use tuples for more granularity.") + for instrument, treatment in itertools.product(instrument_nodes, action_nodes): + graph.add_node(instrument) + graph.add_edge(instrument, treatment) + else: + for instrument, treatment in itertools.product(instrument_nodes): + graph.add_node(instrument) + graph.add_edge(instrument, treatment) + + # Adding effect modifiers + if effect_modifier_nodes: + for node_name in effect_modifier_nodes: + if node_name not in common_cause_nodes: + for outcome in outcome_nodes: + graph.add_node(node_name) + # Assuming the simple form of effect modifier + # that directly causes the outcome. + graph.add_edge(node_name, outcome) + # self._graph.add_edge(node_name, outcome, style = "dotted", headport="s", tailport="n") + # self._graph.add_edge(outcome, node_name, style = "dotted", headport="n", tailport="s") # TODO make the ports more general so that they apply not just to top-bottom node configurations + if mediator_nodes: + for node_name in mediator_nodes: + for treatment, outcome in itertools.product(action_nodes, outcome_nodes): + graph.add_node(node_name) + graph.add_edge(treatment, node_name) + graph.add_edge(node_name, outcome) + + return graph + + +def build_graph_from_str(graph_str: str) -> nx.DiGraph: + """ + User-friendly function that returns a networkx graph based on the graph string. + + Formats supported: dot, gml, daggity + + The `graph_str` parameter can refer to the path of a text file containing the encoded graph or contain the actual encoded graph as a string. + + :param graph_str: a string containing the filepath or the encoded graph + :type graph_str: str + + :returns: a networkx directed graph object + """ + # some preprocessing steps + if re.match(r".*\.txt", graph_str): + text_file = open(graph_str, "r") + graph_str = text_file.read() + text_file.close() + if re.match(r"^dag", graph_str): # Convert daggity output to dot format + graph_str = daggity_to_dot(graph_str) + if isinstance(graph_str, str): + graph_str = graph_str.replace("\n", " ") + + # parsing the correct graph based on input graph format + if re.match(r".*\.dot", graph_str): + # load dot file + try: + import pygraphviz as pgv + + return nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph_str)) + except Exception as e: + _logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...") + try: + import pydot + + return nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph_str)) + except Exception as e: + _logger.error("Error: Pydot cannot be loaded. " + str(e)) + raise e + elif re.match(r".*\.gml", graph_str): + return nx.DiGraph(nx.read_gml(graph_str)) + elif re.match(r".*graph\s*\{.*\}\s*", graph_str): + try: + import pygraphviz as pgv + + graph = pgv.AGraph(graph_str, strict=True, directed=True) + return nx.drawing.nx_agraph.from_agraph(graph) + except Exception as e: + _logger.error("Error: Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot ...") + try: + import pydot + + P_list = pydot.graph_from_dot_data(graph_str) + return nx.drawing.nx_pydot.from_pydot(P_list[0]) + except Exception as e: + _logger.error("Error: Pydot cannot be loaded. " + str(e)) + raise e + elif re.match(".*graph\s*\[.*\]\s*", graph_str): + return nx.DiGraph(nx.parse_gml(graph_str)) + else: + _logger.error("Error: Please provide graph (as string or text file) in dot or gml format.") + _logger.error("Error: Incorrect graph format") + raise ValueError diff --git a/dowhy/graph_learners/__init__.py b/dowhy/graph_learners/__init__.py index 5b2aa51f43..e06ffb1640 100644 --- a/dowhy/graph_learners/__init__.py +++ b/dowhy/graph_learners/__init__.py @@ -1,4 +1,3 @@ -import string from importlib import import_module from dowhy.graph_learner import GraphLearner diff --git a/dowhy/graph_learners/cdt.py b/dowhy/graph_learners/cdt.py index 33ed452aaf..8c96a42f7f 100644 --- a/dowhy/graph_learners/cdt.py +++ b/dowhy/graph_learners/cdt.py @@ -1,6 +1,3 @@ -import networkx as nx -import numpy as np - from dowhy.graph_learners import GraphLearner from dowhy.utils.graph_operations import * diff --git a/dowhy/graph_learners/ges.py b/dowhy/graph_learners/ges.py index f6666e20e1..043b91ccfe 100644 --- a/dowhy/graph_learners/ges.py +++ b/dowhy/graph_learners/ges.py @@ -1,8 +1,5 @@ from importlib import import_module -import networkx as nx -import numpy as np - from dowhy.graph_learners import GraphLearner from dowhy.utils.graph_operations import * diff --git a/dowhy/graph_learners/lingam.py b/dowhy/graph_learners/lingam.py index 5833938844..a6c4cb8713 100644 --- a/dowhy/graph_learners/lingam.py +++ b/dowhy/graph_learners/lingam.py @@ -1,6 +1,3 @@ -import networkx as nx -import numpy as np - from dowhy.graph_learners import GraphLearner from dowhy.utils.graph_operations import * diff --git a/dowhy/interpreters/propensity_balance_interpreter.py b/dowhy/interpreters/propensity_balance_interpreter.py index 6b21150e88..2c33791836 100644 --- a/dowhy/interpreters/propensity_balance_interpreter.py +++ b/dowhy/interpreters/propensity_balance_interpreter.py @@ -2,7 +2,6 @@ import pandas as pd from dowhy.causal_estimator import CausalEstimate -from dowhy.causal_estimators.propensity_score_estimator import PropensityScoreEstimator from dowhy.causal_estimators.propensity_score_stratification_estimator import PropensityScoreStratificationEstimator from dowhy.interpreters.visual_interpreter import VisualInterpreter diff --git a/dowhy/utils/cit.py b/dowhy/utils/cit.py index 1b762053bf..5c60b3b997 100644 --- a/dowhy/utils/cit.py +++ b/dowhy/utils/cit.py @@ -1,8 +1,7 @@ from collections import defaultdict -from math import exp, log +from math import log import numpy as np -import pandas as pd from scipy.stats import norm, t diff --git a/dowhy/utils/dgp.py b/dowhy/utils/dgp.py index 645345ae8d..0c7eb4c3c6 100644 --- a/dowhy/utils/dgp.py +++ b/dowhy/utils/dgp.py @@ -1,5 +1,4 @@ import numpy as np -import pandas as pd class DataGeneratingProcess: diff --git a/dowhy/utils/graph_operations.py b/dowhy/utils/graph_operations.py index b3118648d0..577f1d1265 100644 --- a/dowhy/utils/graph_operations.py +++ b/dowhy/utils/graph_operations.py @@ -1,5 +1,4 @@ import re -from collections import deque from queue import LifoQueue import networkx as nx diff --git a/dowhy/utils/propensity_score.py b/dowhy/utils/propensity_score.py index 98af528334..df49d58bbc 100644 --- a/dowhy/utils/propensity_score.py +++ b/dowhy/utils/propensity_score.py @@ -1,15 +1,10 @@ -import logging - import numpy as np import pandas as pd from pandas import get_dummies -from sklearn.ensemble import GradientBoostingClassifier from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import LabelEncoder from statsmodels.nonparametric.kernel_density import EstimatorSettings, KDEMultivariateConditional -import dowhy.utils.api as api - def propensity_of_treatment_score(data, covariates, treatment, model="logistic", variable_types=None): if model == "logistic": diff --git a/tests/causal_estimators/base.py b/tests/causal_estimators/base.py index c620377c71..526e062062 100755 --- a/tests/causal_estimators/base.py +++ b/tests/causal_estimators/base.py @@ -1,7 +1,8 @@ import itertools import dowhy.datasets -from dowhy import CausalModel +from dowhy import EstimandType, identify_effect_auto +from dowhy.graph import build_graph_from_str class TestEstimator(object): @@ -52,15 +53,13 @@ def average_treatment_effect_test( else: raise ValueError("Dataset type not supported.") - model = CausalModel( - data=data["df"], - treatment=data["treatment_name"], - outcome=data["outcome_name"], - graph=data["gml_graph"], - proceed_when_unidentifiable=True, - test_significance=test_significance, + target_estimand = identify_effect_auto( + build_graph_from_str(data["gml_graph"]), + observed_nodes=list(data["df"].columns), + action_nodes=data["treatment_name"], + outcome_nodes=data["outcome_name"], + estimand_type=EstimandType.NONPARAMETRIC_ATE, ) - target_estimand = model.identify_effect() target_estimand.set_identifier_method(self._identifier_method) estimator_ate = self._Estimator( identified_estimand=target_estimand, @@ -158,15 +157,13 @@ def average_treatment_effect_testsuite( self.average_treatment_effect_test(**cfg) def custom_data_average_treatment_effect_test(self, data): - model = CausalModel( - data=data["df"], - treatment=data["treatment_name"], - outcome=data["outcome_name"], - graph=data["gml_graph"], - proceed_when_unidentifiable=True, - test_significance=None, + target_estimand = identify_effect_auto( + build_graph_from_str(data["gml_graph"]), + observed_nodes=list(data["df"].columns), + action_nodes=data["treatment_name"], + outcome_nodes=data["outcome_name"], + estimand_type=EstimandType.NONPARAMETRIC_ATE, ) - target_estimand = model.identify_effect() estimator_ate = self._Estimator( identified_estimand=target_estimand, test_significance=None, diff --git a/tests/causal_identifiers/base.py b/tests/causal_identifiers/base.py index cdbd8eaefb..3cdbb110ca 100644 --- a/tests/causal_identifiers/base.py +++ b/tests/causal_identifiers/base.py @@ -1,6 +1,6 @@ import pytest -from dowhy.causal_graph import CausalGraph +from dowhy.graph import build_graph_from_str from .example_graphs import TEST_GRAPH_SOLUTIONS @@ -15,9 +15,10 @@ def __init__( maximal_adjustment_sets, direct_maximal_adjustment_sets=None, ): - self.graph = CausalGraph("X", "Y", graph_str, observed_node_names=observed_variables) - self.graph_str = graph_str - self.observed_variables = observed_variables + self.graph = build_graph_from_str(graph_str) + self.action_nodes = ["X"] + self.outcome_nodes = ["Y"] + self.observed_nodes = observed_variables self.biased_sets = biased_sets self.minimal_adjustment_sets = minimal_adjustment_sets self.maximal_adjustment_sets = maximal_adjustment_sets diff --git a/tests/causal_identifiers/test_backdoor_identifier.py b/tests/causal_identifiers/test_backdoor_identifier.py index ff33616809..dcc4c884f5 100644 --- a/tests/causal_identifiers/test_backdoor_identifier.py +++ b/tests/causal_identifiers/test_backdoor_identifier.py @@ -1,6 +1,5 @@ import pytest -from dowhy.causal_graph import CausalGraph from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment from dowhy.causal_identifier.identify_effect import EstimandType @@ -16,7 +15,13 @@ def test_identify_backdoor_no_biased_sets(self, example_graph_solution: Identifi backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EXHAUSTIVE, ) - backdoor_results = identifier.identify_backdoor(graph, "X", "Y", include_unobserved=False) + backdoor_results = identifier.identify_backdoor( + graph, + observed_nodes=example_graph_solution.observed_nodes, + action_nodes=["X"], + outcome_nodes=["Y"], + include_unobserved=False, + ) backdoor_sets = [ set(backdoor_result_dict["backdoor_set"]) for backdoor_result_dict in backdoor_results @@ -31,13 +36,19 @@ def test_identify_backdoor_unobserved_not_in_backdoor_set( self, example_graph_solution: IdentificationTestGraphSolution ): graph = example_graph_solution.graph - observed_variables = example_graph_solution.observed_variables + observed_variables = example_graph_solution.observed_nodes identifier = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EXHAUSTIVE, ) - backdoor_results = identifier.identify_backdoor(graph, "X", "Y", include_unobserved=False) + backdoor_results = identifier.identify_backdoor( + graph, + observed_nodes=example_graph_solution.observed_nodes, + action_nodes=["X"], + outcome_nodes=["Y"], + include_unobserved=False, + ) backdoor_sets = [ set(backdoor_result_dict["backdoor_set"]) for backdoor_result_dict in backdoor_results @@ -52,12 +63,16 @@ def test_identify_backdoor_minimal_adjustment(self, example_graph_solution: Iden graph = example_graph_solution.graph expected_sets = example_graph_solution.minimal_adjustment_sets identifier = AutoIdentifier( - estimand_type=EstimandType.NONPARAMETRIC_ATE, - backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MIN, - proceed_when_unidentifiable=False, + estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MIN ) - backdoor_results = identifier.identify_backdoor(graph, "X", "Y", include_unobserved=False) + backdoor_results = identifier.identify_backdoor( + graph, + observed_nodes=example_graph_solution.observed_nodes, + action_nodes=["X"], + outcome_nodes=["Y"], + include_unobserved=False, + ) backdoor_sets = [set(backdoor_result_dict["backdoor_set"]) for backdoor_result_dict in backdoor_results] assert ( @@ -72,13 +87,17 @@ def test_identify_backdoor_maximal_adjustment(self, example_graph_solution: Iden identifier = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MAX, - proceed_when_unidentifiable=False, ) - backdoor_results = identifier.identify_backdoor(graph, "X", "Y", include_unobserved=False) + backdoor_results = identifier.identify_backdoor( + graph, + observed_nodes=example_graph_solution.observed_nodes, + action_nodes=["X"], + outcome_nodes=["Y"], + include_unobserved=False, + ) backdoor_sets = [set(backdoor_result_dict["backdoor_set"]) for backdoor_result_dict in backdoor_results] - print(backdoor_sets, expected_sets, example_graph_solution.graph_str) assert ( (len(backdoor_sets) == 0) and (len(expected_sets) == 0) ) or all( # No adjustments exist and that's expected. @@ -91,13 +110,17 @@ def test_identify_backdoor_maximal_direct_effect(self, example_graph_solution: I identifier = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_CDE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MAX, - proceed_when_unidentifiable=False, ) - backdoor_results = identifier.identify_backdoor(graph, "X", "Y", direct_effect=True) + backdoor_results = identifier.identify_backdoor( + graph, + observed_nodes=example_graph_solution.observed_nodes, + action_nodes=["X"], + outcome_nodes=["Y"], + direct_effect=True, + ) backdoor_sets = [set(backdoor_result_dict["backdoor_set"]) for backdoor_result_dict in backdoor_results] - print(backdoor_sets, expected_sets, example_graph_solution.graph_str) assert ( (len(backdoor_sets) == 0) and (len(expected_sets) == 0) ) or all( # No adjustments exist and that's expected. diff --git a/tests/causal_identifiers/test_efficient_backdoor_identifier.py b/tests/causal_identifiers/test_efficient_backdoor_identifier.py index 131d4e2d85..b5cf986139 100644 --- a/tests/causal_identifiers/test_efficient_backdoor_identifier.py +++ b/tests/causal_identifiers/test_efficient_backdoor_identifier.py @@ -2,20 +2,14 @@ import pytest -from dowhy.causal_graph import CausalGraph from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType from dowhy.causal_identifier.auto_identifier import EFFICIENT_METHODS +from dowhy.graph import build_graph_from_str from tests.causal_identifiers.example_graphs_efficient import TEST_EFFICIENT_BD_SOLUTIONS def test_identify_efficient_backdoor_algorithms(): for example in TEST_EFFICIENT_BD_SOLUTIONS.values(): - G = CausalGraph( - graph=example["graph_str"], - treatment_name="X", - outcome_name="Y", - observed_node_names=example["observed_node_names"], - ) for method_name in EFFICIENT_METHODS: ident_eff = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, @@ -26,16 +20,18 @@ def test_identify_efficient_backdoor_algorithms(): if example[method_name_results] is None: with pytest.raises(ValueError): ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X"], + outcome_nodes=["Y"], conditional_node_names=example["conditional_node_names"], ) else: results_eff = ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X"], + outcome_nodes=["Y"], conditional_node_names=example["conditional_node_names"], ) assert set(results_eff.get_backdoor_variables()) == example[method_name_results] @@ -43,12 +39,6 @@ def test_identify_efficient_backdoor_algorithms(): def test_fail_negative_costs_efficient_backdoor_algorithms(): example = TEST_EFFICIENT_BD_SOLUTIONS["sr22_fig2_example_graph"] - G = CausalGraph( - graph=example["graph_str"], - treatment_name="X", - outcome_name="Y", - observed_node_names=example["observed_node_names"], - ) mod_costs = copy.deepcopy(example["costs"]) mod_costs[0][1]["cost"] = 0 ident_eff = AutoIdentifier( @@ -59,21 +49,16 @@ def test_fail_negative_costs_efficient_backdoor_algorithms(): with pytest.raises(Exception): ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X"], + outcome_nodes=["Y"], conditional_node_names=example["conditional_node_names"], ) def test_fail_unobserved_cond_vars_efficient_backdoor_algorithms(): example = TEST_EFFICIENT_BD_SOLUTIONS["sr22_fig2_example_graph"] - G = CausalGraph( - graph=example["graph_str"], - treatment_name="X", - outcome_name="Y", - observed_node_names=example["observed_node_names"], - ) ident_eff = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MINCOST_EFFICIENT, @@ -83,21 +68,16 @@ def test_fail_unobserved_cond_vars_efficient_backdoor_algorithms(): mod_cond_names.append("U") with pytest.raises(Exception): ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X"], + outcome_nodes=["Y"], conditional_node_names=mod_cond_names, ) def test_fail_multivar_treat_efficient_backdoor_algorithms(): example = TEST_EFFICIENT_BD_SOLUTIONS["sr22_fig2_example_graph"] - G = CausalGraph( - graph=example["graph_str"], - treatment_name=["X", "K"], - outcome_name="Y", - observed_node_names=example["observed_node_names"], - ) ident_eff = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MINCOST_EFFICIENT, @@ -105,21 +85,16 @@ def test_fail_multivar_treat_efficient_backdoor_algorithms(): ) with pytest.raises(Exception): ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X", "K"], + outcome_nodes=["Y"], conditional_node_names=example["conditional_node_names"], ) def test_fail_multivar_outcome_efficient_backdoor_algorithms(): example = TEST_EFFICIENT_BD_SOLUTIONS["sr22_fig2_example_graph"] - G = CausalGraph( - graph=example["graph_str"], - treatment_name="X", - outcome_name=["Y", "R"], - observed_node_names=example["observed_node_names"], - ) ident_eff = AutoIdentifier( estimand_type=EstimandType.NONPARAMETRIC_ATE, backdoor_adjustment=BackdoorAdjustment.BACKDOOR_MINCOST_EFFICIENT, @@ -127,8 +102,9 @@ def test_fail_multivar_outcome_efficient_backdoor_algorithms(): ) with pytest.raises(Exception): ident_eff.identify_effect( - G, - G.treatment_name, - G.outcome_name, + build_graph_from_str(example["graph_str"]), + observed_nodes=example["observed_node_names"], + action_nodes=["X"], + outcome_nodes=["Y", "R"], conditional_node_names=example["conditional_node_names"], ) diff --git a/tests/causal_identifiers/test_id_identifier.py b/tests/causal_identifiers/test_id_identifier.py index 335604970e..52be382cbd 100644 --- a/tests/causal_identifiers/test_id_identifier.py +++ b/tests/causal_identifiers/test_id_identifier.py @@ -1,22 +1,14 @@ -import numpy as np -import pandas as pd import pytest -from numpy.core.fromnumeric import var -from dowhy import CausalModel +from dowhy import identify_effect_id +from dowhy.graph import build_graph_from_str class TestIDIdentifier(object): def test_1(self): - treatment = "T" - outcome = "Y" - causal_graph = "digraph{T->Y;}" - columns = list(treatment) + list(outcome) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T->Y;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) # Only P(Y|T) should be present for test to succeed. identified_str = identified_estimand.__str__() @@ -27,30 +19,16 @@ def test_2(self): """ Test undirected edge between treatment and outcome. """ - treatment = "T" - outcome = "Y" - causal_graph = "digraph{T->Y; Y->T;}" - columns = list(treatment) + list(outcome) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - # Since undirected graph, identify effect must throw an error. with pytest.raises(Exception): - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T->Y; Y->T;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) def test_3(self): - treatment = "T" - outcome = "Y" - variables = ["X1"] - causal_graph = "digraph{T->X1;X1->Y;}" - columns = list(treatment) + list(outcome) + list(variables) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T->X1;X1->Y;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) # Compare with ground truth identified_str = identified_estimand.__str__() @@ -58,16 +36,9 @@ def test_3(self): assert identified_str == gt_str def test_4(self): - treatment = "T" - outcome = "Y" - variables = ["X1"] - causal_graph = "digraph{T->Y;T->X1;X1->Y;}" - columns = list(treatment) + list(outcome) + list(variables) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T->Y;T->X1;X1->Y;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) # Compare with ground truth identified_str = identified_estimand.__str__() @@ -75,16 +46,9 @@ def test_4(self): assert identified_str == gt_str def test_5(self): - treatment = "T" - outcome = "Y" - variables = ["X1", "X2"] - causal_graph = "digraph{T->Y;X1->T;X1->Y;X2->T;}" - columns = list(treatment) + list(outcome) + list(variables) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T->Y;X1->T;X1->Y;X2->T;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) # Compare with ground truth set_a = set(identified_estimand._product[0]._product[0]._product[0]["outcome_vars"]._set) @@ -98,16 +62,9 @@ def test_5(self): assert len(set_d) == 0 def test_6(self): - treatment = "T" - outcome = "Y" - variables = ["X1"] - causal_graph = "digraph{T;X1->Y;}" - columns = list(treatment) + list(outcome) + list(variables) - df = pd.DataFrame(columns=columns) - - # Calculate causal effect twice: once for unit (t=1, c=0), once for specific increase (t=100, c=50) - causal_model = CausalModel(df, treatment, outcome, graph=causal_graph) - identified_estimand = causal_model.identify_effect(method_name="id-algorithm") + identified_estimand = identify_effect_id( + build_graph_from_str("digraph{T;X1->Y;}"), action_nodes=["T"], outcome_nodes=["Y"] + ) # Compare with ground truth identified_str = identified_estimand.__str__() diff --git a/tests/do_sampler/test_pandas_do_api.py b/tests/do_sampler/test_pandas_do_api.py index 00fa101e71..260c502925 100755 --- a/tests/do_sampler/test_pandas_do_api.py +++ b/tests/do_sampler/test_pandas_do_api.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd -import pytest from pytest import mark from sklearn.linear_model import LinearRegression @@ -39,7 +38,6 @@ def test_pandas_api_discrete_cause_continuous_confounder(self, N, error_toleranc outcome=outcome, method=method, common_causes=common_causes, - proceed_when_unidentifiable=True, ) ate = (causal_df[causal_df.v == 1].mean() - causal_df[causal_df.v == 0].mean())["y"] @@ -81,7 +79,6 @@ def test_pandas_api_discrete_cause_discrete_confounder(self, N, error_tolerance) outcome=outcome, method=method, common_causes=common_causes, - proceed_when_unidentifiable=True, ) ate = (causal_df[causal_df.v == 1].mean() - causal_df[causal_df.v == 0].mean())["y"] @@ -124,7 +121,6 @@ def test_pandas_api_continuous_cause_discrete_confounder(self, N, error_toleranc outcome=outcome, method=method, common_causes=common_causes, - proceed_when_unidentifiable=True, ) ate = LinearRegression().fit(causal_df[["v"]], causal_df["y"]).coef_[0] @@ -167,7 +163,6 @@ def test_pandas_api_continuous_cause_continuous_confounder(self, N, error_tolera outcome=outcome, method=method, common_causes=common_causes, - proceed_when_unidentifiable=True, ) ate = LinearRegression().fit(causal_df[["v"]], causal_df["y"]).coef_[0] @@ -198,9 +193,9 @@ def test_pandas_api_with_full_specification_of_type(self, N, variable_types): beta=5, num_common_causes=1, num_instruments=0, num_samples=1000, treatment_is_binary=True ) - data["df"].causal.do( - x="v0", variable_types=variable_types, outcome="y", proceed_when_unidentifiable=True, common_causes=["W0"] - ).groupby("v0").mean() + data["df"].causal.do(x="v0", variable_types=variable_types, outcome="y", common_causes=["W0"]).groupby( + "v0" + ).mean() assert True @mark.parametrize( @@ -214,9 +209,9 @@ def test_pandas_api_with_partial_specification_of_type(self, N, variable_types): beta=5, num_common_causes=1, num_instruments=0, num_samples=1000, treatment_is_binary=True ) - data["df"].causal.do( - x="v0", variable_types=variable_types, outcome="y", proceed_when_unidentifiable=True, common_causes=["W0"] - ).groupby("v0").mean() + data["df"].causal.do(x="v0", variable_types=variable_types, outcome="y", common_causes=["W0"]).groupby( + "v0" + ).mean() assert True @mark.parametrize( @@ -230,9 +225,9 @@ def test_pandas_api_with_no_specification_of_type(self, N, variable_types): beta=5, num_common_causes=1, num_instruments=0, num_samples=1000, treatment_is_binary=True ) - data["df"].causal.do( - x="v0", variable_types=variable_types, outcome="y", proceed_when_unidentifiable=True, common_causes=["W0"] - ).groupby("v0").mean() + data["df"].causal.do(x="v0", variable_types=variable_types, outcome="y", common_causes=["W0"]).groupby( + "v0" + ).mean() assert True @mark.parametrize( @@ -247,7 +242,6 @@ def test_pandas_api_with_dummy_data(self, N, variable_types): x=["x"], outcome="y", common_causes=["a", "b"], - proceed_when_unidentifiable=True, variable_types=dict(x="c", y="c", a="c", b="c"), ) print(dd) diff --git a/tests/sample_dag.txt b/tests/sample_dag.txt index 71f9547de8..400e91328a 100644 --- a/tests/sample_dag.txt +++ b/tests/sample_dag.txt @@ -1,5 +1,6 @@ dag { "Unobserved Confounders" [pos="0.491,-1.056"] +W0 X0 [pos="-2.109,0.057"] X1 [adjusted, pos="-0.453,-1.562"] X2 [pos="-2.268,-1.210"] @@ -8,6 +9,8 @@ v0 [pos="-1.525,-1.293"] y [outcome, pos="-1.164,-0.116"] "Unobserved Confounders" -> v0 "Unobserved Confounders" -> y +W0 -> v0 +W0 -> y X0 -> v0 X0 -> y X1 -> v0 @@ -16,4 +19,4 @@ X2 -> v0 X2 -> y Z0 -> v0 v0 -> y -} \ No newline at end of file +} diff --git a/tests/test_causal_graph.py b/tests/test_causal_graph.py new file mode 100644 index 0000000000..ccc41709ac --- /dev/null +++ b/tests/test_causal_graph.py @@ -0,0 +1,121 @@ +import networkx as nx +import pandas as pd +import pytest +from flaky import flaky +from pytest import mark + +import dowhy +import dowhy.datasets +from dowhy import CausalModel +from dowhy.graph import * +from dowhy.utils.graph_operations import daggity_to_dot + + +class TestCausalGraph(object): + @pytest.fixture(autouse=True) + def _init_graph(self): + self.daggity_file = "tests/sample_dag.txt" + data = dowhy.datasets.linear_dataset( + beta=10, + num_common_causes=1, + num_instruments=1, + # num_frontdoor_variables=1, + num_effect_modifiers=3, + num_samples=100, + num_treatments=1, + treatment_is_binary=True, + ) + model = CausalModel( + data=data["df"], + treatment=data["treatment_name"], + outcome=data["outcome_name"], + graph=self.daggity_file, + proceed_when_unidentifiable=True, + test_significance=None, + missing_nodes_as_confounders=False, + ) + self.graph_obj = model._graph + + # creating nx graph instance + with open(self.daggity_file, "r") as text_file: + graph_str = text_file.read() + graph_str = daggity_to_dot(graph_str) + # to be used later for a test. Does not include the replace operation + self.graph_str = graph_str + graph_str = graph_str.replace("\n", " ") + + import pygraphviz as pgv + + nx_graph = pgv.AGraph(graph_str, strict=True, directed=True) + nx_graph = nx.drawing.nx_agraph.from_agraph(nx_graph) + self.nx_graph = nx_graph + self.action_node = data["treatment_name"] + self.outcome_node = data["outcome_name"] + self.observed_nodes = list(nx_graph.nodes) + self.observed_nodes.remove("Unobserved Confounders") + + def test_check_valid_backdoor_set(self): + res1 = self.graph_obj.check_valid_backdoor_set(self.action_node, self.outcome_node, ["X1", "X2"]) + res2 = check_valid_backdoor_set(self.nx_graph, self.action_node, self.outcome_node, ["X1", "X2"]) + assert res1 == res2 + + def test_do_surgery(self): + res1 = self.graph_obj.do_surgery(self.action_node) + res2 = do_surgery(self.nx_graph, self.action_node) + assert list(res1.nodes) == list(res2.nodes) + assert res1.edges == res2.edges + + def test_get_backdoor_paths(self): + res1 = self.graph_obj.get_backdoor_paths(self.action_node, self.outcome_node) + res2 = get_backdoor_paths(self.nx_graph, self.action_node, self.outcome_node) + assert res1 == res2 + + def test_check_dseparation(self): + res1 = self.graph_obj.check_dseparation(self.action_node, self.outcome_node, ["X1", "X2"]) + res2 = check_dseparation(self.nx_graph, self.action_node, self.outcome_node, ["X1", "X2"]) + assert res1 == res2 + + def test_get_instruments(self): + res1 = self.graph_obj.get_instruments(self.action_node, self.outcome_node) + res2 = get_instruments(self.nx_graph, self.action_node, self.outcome_node) + assert res1 == res2 + + def test_get_all_nodes(self): + for flag in [True, False]: + print(list(self.graph_obj._graph.nodes)) + print(list(self.nx_graph.nodes)) + res1 = self.graph_obj.get_all_nodes(include_unobserved=flag) + res2 = get_all_nodes(self.nx_graph, self.observed_nodes, include_unobserved_nodes=flag) + assert set(res1) == set(res2) + + def test_valid_frontdoor_set(self): + res1 = self.graph_obj.check_valid_frontdoor_set(self.action_node, self.outcome_node, ["X0"]) + res2 = check_valid_frontdoor_set(self.nx_graph, self.action_node, self.outcome_node, ["X0"]) + assert res1 == res2 + + def test_valid_mediation_set(self): + res1 = self.graph_obj.check_valid_mediation_set(self.action_node, self.outcome_node, ["X0"]) + res2 = check_valid_mediation_set(self.nx_graph, self.action_node, self.outcome_node, ["X0"]) + assert res1 == res2 + + def test_build_graph(self): + data = dowhy.datasets.linear_dataset(beta=10, num_common_causes=1, num_instruments=1, num_samples=100) + res1 = CausalModel( + data=data["df"], + treatment=data["treatment_name"], + outcome=data["outcome_name"], + common_causes=["W0"], + instruments=["Z0"], + missing_nodes_as_confounders=False, + )._graph._graph + res2 = build_graph( + action_nodes=data["treatment_name"], + outcome_nodes=data["outcome_name"], + common_cause_nodes=["W0"], + instrument_nodes=["Z0"], + ) + assert res1.edges == res2.edges + + def test_build_graph_from_str(self): + build_graph_from_str(self.daggity_file) + build_graph_from_str(self.graph_str) diff --git a/tests/test_causal_model.py b/tests/test_causal_model.py index 85bc29ce25..1fc1a96b2a 100644 --- a/tests/test_causal_model.py +++ b/tests/test_causal_model.py @@ -1,3 +1,4 @@ +import networkx as nx import pandas as pd import pytest from flaky import flaky @@ -7,6 +8,7 @@ import dowhy import dowhy.datasets from dowhy import CausalModel +from dowhy.utils.graph_operations import daggity_to_dot class TestCausalModel(object): @@ -311,6 +313,59 @@ def test_graph_input4(self, beta, num_instruments, num_samples, num_treatments): all_nodes = model._graph.get_all_nodes(include_unobserved=False) assert "Unobserved Confounders" not in all_nodes + @mark.parametrize( + ["beta", "num_instruments", "num_samples", "num_treatments"], + [ + (10, 1, 100, 1), + ], + ) + def test_graph_input_nx(self, beta, num_instruments, num_samples, num_treatments): + num_common_causes = 5 + data = dowhy.datasets.linear_dataset( + beta=beta, + num_common_causes=num_common_causes, + num_instruments=num_instruments, + num_samples=num_samples, + num_treatments=num_treatments, + treatment_is_binary=True, + ) + nx_graph = nx.DiGraph(nx.parse_gml(data["gml_graph"])) + model = CausalModel( + data=data["df"], + treatment=data["treatment_name"], + outcome=data["outcome_name"], + graph=nx_graph, + proceed_when_unidentifiable=True, + test_significance=None, + ) + # removing two common causes + daggity_file = "tests/sample_dag.txt" + with open(daggity_file, "r") as text_file: + graph_str = text_file.read() + graph_str = daggity_to_dot(graph_str) + graph_str = graph_str.replace("\n", " ") + import pygraphviz as pgv + + nx_graph2 = pgv.AGraph(graph_str, strict=True, directed=True) + nx_graph2 = nx.drawing.nx_agraph.from_agraph(nx_graph2) + model = CausalModel( + data=data["df"], + treatment=data["treatment_name"], + outcome=data["outcome_name"], + graph=nx_graph2, + proceed_when_unidentifiable=True, + test_significance=None, + missing_nodes_as_confounders=True, + ) + common_causes = model.get_common_causes() + assert all(node_name in common_causes for node_name in ["X1", "X2"]) + all_nodes = model._graph.get_all_nodes(include_unobserved=True) + assert all( + node_name in all_nodes for node_name in ["Unobserved Confounders", "X0", "X1", "X2", "Z0", "v0", "y"] + ) + all_nodes = model._graph.get_all_nodes(include_unobserved=False) + assert "Unobserved Confounders" not in all_nodes + @mark.parametrize( ["num_variables", "num_samples"], [