From 6be4c37080a1872fff18997f5e8a26b60da6f455 Mon Sep 17 00:00:00 2001 From: rahulbshrestha Date: Tue, 2 Jul 2024 17:15:18 +0200 Subject: [PATCH] Refactor for dowhy_confounder_example --- .../dowhy_confounder_example.ipynb | 86 +++++++++++++++++-- dowhy/causal_refuters/refute_estimate.py | 2 + 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/docs/source/example_notebooks/dowhy_confounder_example.ipynb b/docs/source/example_notebooks/dowhy_confounder_example.ipynb index c4c90594fb..58d921bdc8 100644 --- a/docs/source/example_notebooks/dowhy_confounder_example.ipynb +++ b/docs/source/example_notebooks/dowhy_confounder_example.ipynb @@ -123,6 +123,15 @@ "Identify the causal effect using properties of the causal graph." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df" + ] + }, { "cell_type": "code", "execution_count": null, @@ -130,7 +139,30 @@ "outputs": [], "source": [ "identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", - "print(identified_estimand)" + "print(identified_estimand)\n", + " \n", + "\n", + "from dowhy.causal_graph import CausalGraph\n", + "\n", + "graph = CausalGraph(\n", + " data_dict[\"treatment_name\"],\n", + " data_dict[\"outcome_name\"],\n", + " common_cause_names=data_dict[\"common_causes_names\"],\n", + " observed_node_names=df.columns.tolist(),\n", + ")\n", + "\n", + "nx_graph = graph._graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dowhy.causal_identifier import identify_effect\n", + "\n", + "identified_estimand = identify_effect(nx_graph, action_nodes=data_dict[\"treatment_name\"], outcome_nodes=data_dict[\"outcome_name\"], observed_nodes=list(graph.get_all_nodes(include_unobserved=False)))" ] }, { @@ -197,7 +229,21 @@ "metadata": {}, "outputs": [], "source": [ - "res_random=model.refute_estimate(identified_estimand, estimate, method_name=\"random_common_cause\")\n", + "from dowhy.causal_refuters import refute_estimate\n", + "from typing import List\n", + "\n", + "refuter_class = dowhy.causal_refuters.get_class_object(\"random_common_cause\")\n", + "\n", + "refuters = list()\n", + "refuters.append(refuter_class)\n", + "\n", + "res_random = refute_estimate(data=df, \n", + " treatment_name=data_dict[\"treatment_name\"], \n", + " target_estimand=identified_estimand,\n", + " outcome_name=data_dict[\"outcome_name\"], \n", + " identified_estimand=identified_estimand, \n", + " estimate=estimate,\n", + " refuters=refuters)\n", "print(res_random)" ] }, @@ -214,9 +260,20 @@ "metadata": {}, "outputs": [], "source": [ - "res_placebo=model.refute_estimate(identified_estimand, estimate,\n", - " method_name=\"placebo_treatment_refuter\", placebo_type=\"permute\")\n", - "print(res_placebo)" + "refuter_class = dowhy.causal_refuters.get_class_object(\"placebo_treatment_refuter\")\n", + "\n", + "refuters = list()\n", + "refuters.append(refuter_class)\n", + "\n", + "res_placebo = refute_estimate(data=df, \n", + " treatment_name=data_dict[\"treatment_name\"], \n", + " target_estimand=identified_estimand,\n", + " outcome_name=data_dict[\"outcome_name\"], \n", + " identified_estimand=identified_estimand, \n", + " estimate=estimate,\n", + " refuters=refuters,\n", + " placebo_type=\"permute\")\n", + "print(res_placebo)\n" ] }, { @@ -232,9 +289,20 @@ "metadata": {}, "outputs": [], "source": [ - "res_subset=model.refute_estimate(identified_estimand, estimate,\n", - " method_name=\"data_subset_refuter\", subset_fraction=0.9)\n", - "print(res_subset)\n" + "refuter_class = dowhy.causal_refuters.get_class_object(\"data_subset_refuter\")\n", + "\n", + "refuters = list()\n", + "refuters.append(refuter_class)\n", + "\n", + "res_subset = refute_estimate(data=df, \n", + " treatment_name=data_dict[\"treatment_name\"], \n", + " target_estimand=identified_estimand,\n", + " outcome_name=data_dict[\"outcome_name\"], \n", + " identified_estimand=identified_estimand, \n", + " estimate=estimate,\n", + " refuters=refuters,\n", + " subset_fraction=0.9)\n", + "print(res_subset)" ] }, { @@ -261,7 +329,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.11.8" }, "toc": { "base_numbering": 1, diff --git a/dowhy/causal_refuters/refute_estimate.py b/dowhy/causal_refuters/refute_estimate.py index 9db8427227..508a917896 100644 --- a/dowhy/causal_refuters/refute_estimate.py +++ b/dowhy/causal_refuters/refute_estimate.py @@ -55,7 +55,9 @@ def refute_estimate( results = [] for refuter in refuters: + print("REFUTER: ", refuter) refute = refuter(**refuter_kwargs) + print("REFUTE: ", refute) if isinstance(refute, list): results.extend(refute) else: