Skip to content

Commit

Permalink
Refactor for dowhy_confounder_example
Browse files Browse the repository at this point in the history
  • Loading branch information
rahulbshrestha committed Jul 2, 2024
1 parent 7ebe273 commit 6be4c37
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 9 deletions.
86 changes: 77 additions & 9 deletions docs/source/example_notebooks/dowhy_confounder_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,46 @@
"Identify the causal effect using properties of the causal graph."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n",
"print(identified_estimand)"
"print(identified_estimand)\n",
" \n",
"\n",
"from dowhy.causal_graph import CausalGraph\n",
"\n",
"graph = CausalGraph(\n",
" data_dict[\"treatment_name\"],\n",
" data_dict[\"outcome_name\"],\n",
" common_cause_names=data_dict[\"common_causes_names\"],\n",
" observed_node_names=df.columns.tolist(),\n",
")\n",
"\n",
"nx_graph = graph._graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.causal_identifier import identify_effect\n",
"\n",
"identified_estimand = identify_effect(nx_graph, action_nodes=data_dict[\"treatment_name\"], outcome_nodes=data_dict[\"outcome_name\"], observed_nodes=list(graph.get_all_nodes(include_unobserved=False)))"
]
},
{
Expand Down Expand Up @@ -197,7 +229,21 @@
"metadata": {},
"outputs": [],
"source": [
"res_random=model.refute_estimate(identified_estimand, estimate, method_name=\"random_common_cause\")\n",
"from dowhy.causal_refuters import refute_estimate\n",
"from typing import List\n",
"\n",
"refuter_class = dowhy.causal_refuters.get_class_object(\"random_common_cause\")\n",
"\n",
"refuters = list()\n",
"refuters.append(refuter_class)\n",
"\n",
"res_random = refute_estimate(data=df, \n",
" treatment_name=data_dict[\"treatment_name\"], \n",
" target_estimand=identified_estimand,\n",
" outcome_name=data_dict[\"outcome_name\"], \n",
" identified_estimand=identified_estimand, \n",
" estimate=estimate,\n",
" refuters=refuters)\n",
"print(res_random)"
]
},
Expand All @@ -214,9 +260,20 @@
"metadata": {},
"outputs": [],
"source": [
"res_placebo=model.refute_estimate(identified_estimand, estimate,\n",
" method_name=\"placebo_treatment_refuter\", placebo_type=\"permute\")\n",
"print(res_placebo)"
"refuter_class = dowhy.causal_refuters.get_class_object(\"placebo_treatment_refuter\")\n",
"\n",
"refuters = list()\n",
"refuters.append(refuter_class)\n",
"\n",
"res_placebo = refute_estimate(data=df, \n",
" treatment_name=data_dict[\"treatment_name\"], \n",
" target_estimand=identified_estimand,\n",
" outcome_name=data_dict[\"outcome_name\"], \n",
" identified_estimand=identified_estimand, \n",
" estimate=estimate,\n",
" refuters=refuters,\n",
" placebo_type=\"permute\")\n",
"print(res_placebo)\n"
]
},
{
Expand All @@ -232,9 +289,20 @@
"metadata": {},
"outputs": [],
"source": [
"res_subset=model.refute_estimate(identified_estimand, estimate,\n",
" method_name=\"data_subset_refuter\", subset_fraction=0.9)\n",
"print(res_subset)\n"
"refuter_class = dowhy.causal_refuters.get_class_object(\"data_subset_refuter\")\n",
"\n",
"refuters = list()\n",
"refuters.append(refuter_class)\n",
"\n",
"res_subset = refute_estimate(data=df, \n",
" treatment_name=data_dict[\"treatment_name\"], \n",
" target_estimand=identified_estimand,\n",
" outcome_name=data_dict[\"outcome_name\"], \n",
" identified_estimand=identified_estimand, \n",
" estimate=estimate,\n",
" refuters=refuters,\n",
" subset_fraction=0.9)\n",
"print(res_subset)"
]
},
{
Expand All @@ -261,7 +329,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.11.8"
},
"toc": {
"base_numbering": 1,
Expand Down
2 changes: 2 additions & 0 deletions dowhy/causal_refuters/refute_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def refute_estimate(

results = []
for refuter in refuters:
print("REFUTER: ", refuter)
refute = refuter(**refuter_kwargs)
print("REFUTE: ", refute)
if isinstance(refute, list):
results.extend(refute)
else:
Expand Down

0 comments on commit 6be4c37

Please sign in to comment.