Skip to content

Commit 055cda5

Browse files
committed
fixed bug in efficient backdoor notebook
Signed-off-by: Amit Sharma <amit_sharma@live.com>
1 parent 807650f commit 055cda5

File tree

1 file changed

+47
-34
lines changed

1 file changed

+47
-34
lines changed

docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
"outputs": [],
6363
"source": [
6464
"from dowhy.causal_graph import CausalGraph\n",
65-
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType"
65+
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType\n",
66+
"from dowhy.graph import build_graph_from_str\n",
67+
"from dowhy.utils.plotting import plot"
6668
]
6769
},
6870
{
@@ -135,9 +137,7 @@
135137
"]\n",
136138
"treatment_name = \"warm-up\"\n",
137139
"outcome_name = \"injury\"\n",
138-
"G = CausalGraph(\n",
139-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
140-
")"
140+
"G = build_graph_from_str(graph_str)"
141141
]
142142
},
143143
{
@@ -153,7 +153,7 @@
153153
"metadata": {},
154154
"outputs": [],
155155
"source": [
156-
"G.view_graph()"
156+
"plot(G)"
157157
]
158158
},
159159
{
@@ -184,7 +184,11 @@
184184
")\n",
185185
"print(\n",
186186
" ident_eff.identify_effect(\n",
187-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
187+
" graph=G, \n",
188+
" action_nodes=treatment_name, \n",
189+
" outcome_nodes=outcome_name,\n",
190+
" observed_nodes=observed_node_names,\n",
191+
" conditional_node_names=conditional_node_names\n",
188192
" )\n",
189193
")"
190194
]
@@ -215,7 +219,11 @@
215219
")\n",
216220
"print(\n",
217221
" ident_minimal_eff.identify_effect(\n",
218-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
222+
" graph=G, \n",
223+
" action_nodes=treatment_name, \n",
224+
" outcome_nodes=outcome_name, \n",
225+
" observed_nodes=observed_node_names,\n",
226+
" conditional_node_names=conditional_node_names\n",
219227
" )\n",
220228
")"
221229
]
@@ -239,7 +247,11 @@
239247
")\n",
240248
"print(\n",
241249
" ident_mincost_eff.identify_effect(\n",
242-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
250+
" graph=G, \n",
251+
" action_nodes=treatment_name, \n",
252+
" outcome_nodes=outcome_name,\n",
253+
" observed_nodes=observed_node_names,\n",
254+
" conditional_node_names=conditional_node_names\n",
243255
" )\n",
244256
")"
245257
]
@@ -294,9 +306,7 @@
294306
"observed_node_names = [\"X\", \"Y\", \"Z1\", \"Z2\"]\n",
295307
"treatment_name = \"X\"\n",
296308
"outcome_name = \"Y\"\n",
297-
"G = CausalGraph(\n",
298-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
299-
")"
309+
"G = build_graph_from_str(graph_str)"
300310
]
301311
},
302312
{
@@ -317,7 +327,10 @@
317327
" backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n",
318328
")\n",
319329
"try:\n",
320-
" results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n",
330+
" results_eff = ident_eff.identify_effect(graph=G, \n",
331+
" action_nodes=treatment_name, \n",
332+
" outcome_nodes=outcome_name,\n",
333+
" observed_nodes=observed_node_names)\n",
321334
"except ValueError as e:\n",
322335
" print(e)"
323336
]
@@ -335,8 +348,9 @@
335348
"print(\n",
336349
" ident_minimal_eff.identify_effect(\n",
337350
" graph=G,\n",
338-
" treatment_name=treatment_name,\n",
339-
" outcome_name=outcome_name,\n",
351+
" action_nodes=treatment_name,\n",
352+
" outcome_nodes=outcome_name,\n",
353+
" observed_nodes=observed_node_names\n",
340354
" )\n",
341355
")"
342356
]
@@ -354,8 +368,9 @@
354368
"print(\n",
355369
" ident_mincost_eff.identify_effect(\n",
356370
" graph=G,\n",
357-
" treatment_name=treatment_name,\n",
358-
" outcome_name=outcome_name,\n",
371+
" action_nodes=treatment_name,\n",
372+
" outcome_nodes=outcome_name,\n",
373+
" observed_nodes=observed_node_names\n",
359374
" )\n",
360375
")"
361376
]
@@ -391,9 +406,7 @@
391406
"observed_node_names = [\"X\", \"Y\"]\n",
392407
"treatment_name = \"X\"\n",
393408
"outcome_name = \"Y\"\n",
394-
"G = CausalGraph(\n",
395-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
396-
")"
409+
"G = build_graph_from_str(graph_str)"
397410
]
398411
},
399412
{
@@ -409,8 +422,9 @@
409422
"try:\n",
410423
" results_eff = ident_eff.identify_effect(\n",
411424
" graph=G,\n",
412-
" treatment_name=treatment_name,\n",
413-
" outcome_name=outcome_name,\n",
425+
" action_nodes=treatment_name,\n",
426+
" outcome_nodes=outcome_name,\n",
427+
" observed_nodes=observed_node_names\n",
414428
" )\n",
415429
"except ValueError as e:\n",
416430
" print(e)"
@@ -475,9 +489,7 @@
475489
" (\"R\", {\"cost\": 2}),\n",
476490
" (\"T\", {\"cost\": 1}),\n",
477491
"]\n",
478-
"G = CausalGraph(\n",
479-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
480-
")"
492+
"G = build_graph_from_str(graph_str)"
481493
]
482494
},
483495
{
@@ -504,7 +516,11 @@
504516
")\n",
505517
"print(\n",
506518
" ident_mincost_eff.identify_effect(\n",
507-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
519+
" graph=G, \n",
520+
" action_nodes=treatment_name, \n",
521+
" outcome_nodes=outcome_name, \n",
522+
" observed_nodes=observed_node_names,\n",
523+
" conditional_node_names=conditional_node_names\n",
508524
" )\n",
509525
")"
510526
]
@@ -528,22 +544,19 @@
528544
")\n",
529545
"print(\n",
530546
" ident_minimal_eff.identify_effect(\n",
531-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
547+
" graph=G, \n",
548+
" action_nodes=treatment_name,\n",
549+
" outcome_nodes=outcome_name, \n",
550+
" observed_nodes=observed_node_names,\n",
551+
" conditional_node_names=conditional_node_names\n",
532552
" )\n",
533553
")"
534554
]
535-
},
536-
{
537-
"cell_type": "code",
538-
"execution_count": null,
539-
"metadata": {},
540-
"outputs": [],
541-
"source": []
542555
}
543556
],
544557
"metadata": {
545558
"kernelspec": {
546-
"display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')",
559+
"display_name": "Python 3 (ipykernel)",
547560
"language": "python",
548561
"name": "python3"
549562
},

0 commit comments

Comments
 (0)