Skip to content

Commit 2a8e49a

Browse files
bloebpamit-sharma
andauthored
Proposal: Finalize functional API refactor - deprecate causal graph (#943)
* Deprecate CausalGraph The effect estimation API is now based on an functional API that expects a networkx graph as input. - The graph should now be defined via a networkx graph. Most identification methods now expect an additional "observed_nodes" parameter accordingly. - CausalModel and CausalGraph still exist and should be compatible with the old API. --------- Signed-off-by: Patrick Bloebaum <bloebp@amazon.com> Signed-off-by: Amit Sharma <amit_sharma@live.com> Co-authored-by: Amit Sharma <amit_sharma@live.com>
1 parent 4fd0a92 commit 2a8e49a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1179
-604
lines changed

docs/source/example_notebooks/do_sampler_demo.ipynb

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
"\n",
4646
"## Integration\n",
4747
"\n",
48-
"The do-sampler is built on top of the identification abstraction used throughout do-why. It uses a `dowhy.CausalModel` to perform identification, and builds any models it needs automatically using this identification.\n",
48+
"The do-sampler is built on top of the identification abstraction used throughout do-why. It automatically performs an identification, and builds any models it needs automatically using this identification.\n",
4949
"\n",
5050
"## Specifying Interventions\n",
5151
"\n",
@@ -128,7 +128,8 @@
128128
"model = CausalModel(df, \n",
129129
" causes,\n",
130130
" outcomes,\n",
131-
" common_causes=common_causes)"
131+
" common_causes=common_causes)\n",
132+
"nx_graph = model._graph._graph"
132133
]
133134
},
134135
{
@@ -162,8 +163,11 @@
162163
"source": [
163164
"from dowhy.do_samplers.weighting_sampler import WeightingSampler\n",
164165
"\n",
165-
"sampler = WeightingSampler(df,\n",
166-
" causal_model=model,\n",
166+
"sampler = WeightingSampler(graph=nx_graph,\n",
167+
" action_nodes=causes,\n",
168+
" outcome_nodes=outcomes,\n",
169+
" observed_nodes=df.columns.tolist(),\n",
170+
" data=df,\n",
167171
" keep_original_treatment=True,\n",
168172
" variable_types={'D': 'b', 'Z': 'c', 'Y': 'c'}\n",
169173
" )\n",
@@ -207,7 +211,7 @@
207211
],
208212
"metadata": {
209213
"kernelspec": {
210-
"display_name": "Python 3",
214+
"display_name": "Python 3 (ipykernel)",
211215
"language": "python",
212216
"name": "python3"
213217
},
@@ -221,7 +225,7 @@
221225
"name": "python",
222226
"nbconvert_exporter": "python",
223227
"pygments_lexer": "ipython3",
224-
"version": "3.8.5"
228+
"version": "3.8.10"
225229
},
226230
"toc": {
227231
"base_numbering": 1,

docs/source/example_notebooks/dowhy_causal_api.ipynb

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"source": [
1717
"import dowhy.datasets\n",
1818
"import dowhy.api\n",
19+
"from dowhy.graph import build_graph_from_str\n",
1920
"\n",
2021
"import numpy as np\n",
2122
"import pandas as pd\n",
@@ -36,7 +37,7 @@
3637
" treatment_is_binary=True)\n",
3738
"df = data['df']\n",
3839
"df['y'] = df['y'] + np.random.normal(size=len(df)) # Adding noise to data. Without noise, the variance in Y|X, Z is zero, and mcmc fails.\n",
39-
"#data['dot_graph'] = 'digraph { v ->y;X0-> v;X0-> y;}'\n",
40+
"nx_graph = build_graph_from_str(data[\"dot_graph\"])\n",
4041
"\n",
4142
"treatment= data[\"treatment_name\"][0]\n",
4243
"outcome = data[\"outcome_name\"][0]\n",
@@ -47,15 +48,17 @@
4748
{
4849
"cell_type": "code",
4950
"execution_count": null,
50-
"metadata": {},
51+
"metadata": {
52+
"scrolled": true
53+
},
5154
"outputs": [],
5255
"source": [
5356
"# data['df'] is just a regular pandas.DataFrame\n",
5457
"df.causal.do(x=treatment,\n",
55-
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n",
56-
" outcome=outcome,\n",
57-
" common_causes=[common_cause],\n",
58-
" proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')"
58+
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'},\n",
59+
" outcome=outcome,\n",
60+
" common_causes=[common_cause],\n",
61+
" ).groupby(treatment).mean().plot(y=outcome, kind='bar')"
5962
]
6063
},
6164
{
@@ -68,8 +71,8 @@
6871
" variable_types={treatment:'b', outcome: 'c', common_cause: 'c'}, \n",
6972
" outcome=outcome,\n",
7073
" method='weighting', \n",
71-
" common_causes=[common_cause],\n",
72-
" proceed_when_unidentifiable=True).groupby(treatment).mean().plot(y=outcome, kind='bar')"
74+
" common_causes=[common_cause]\n",
75+
" ).groupby(treatment).mean().plot(y=outcome, kind='bar')"
7376
]
7477
},
7578
{
@@ -81,14 +84,14 @@
8184
"cdf_1 = df.causal.do(x={treatment: 1}, \n",
8285
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n",
8386
" outcome=outcome, \n",
84-
" dot_graph=data['dot_graph'],\n",
85-
" proceed_when_unidentifiable=True)\n",
87+
" graph=nx_graph\n",
88+
" )\n",
8689
"\n",
8790
"cdf_0 = df.causal.do(x={treatment: 0}, \n",
8891
" variable_types={treatment: 'b', outcome: 'c', common_cause: 'c'}, \n",
8992
" outcome=outcome, \n",
90-
" dot_graph=data['dot_graph'],\n",
91-
" proceed_when_unidentifiable=True)\n"
93+
" graph=nx_graph\n",
94+
" )\n"
9295
]
9396
},
9497
{
@@ -158,7 +161,7 @@
158161
],
159162
"metadata": {
160163
"kernelspec": {
161-
"display_name": "Python 3",
164+
"display_name": "Python 3 (ipykernel)",
162165
"language": "python",
163166
"name": "python3"
164167
},
@@ -172,7 +175,7 @@
172175
"name": "python",
173176
"nbconvert_exporter": "python",
174177
"pygments_lexer": "ipython3",
175-
"version": "3.8.5"
178+
"version": "3.8.10"
176179
},
177180
"toc": {
178181
"base_numbering": 1,

docs/source/example_notebooks/dowhy_efficient_backdoor_example.ipynb

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
"outputs": [],
6363
"source": [
6464
"from dowhy.causal_graph import CausalGraph\n",
65-
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType"
65+
"from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType\n",
66+
"from dowhy.graph import build_graph_from_str\n",
67+
"from dowhy.utils.plotting import plot"
6668
]
6769
},
6870
{
@@ -135,9 +137,7 @@
135137
"]\n",
136138
"treatment_name = \"warm-up\"\n",
137139
"outcome_name = \"injury\"\n",
138-
"G = CausalGraph(\n",
139-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
140-
")"
140+
"G = build_graph_from_str(graph_str)"
141141
]
142142
},
143143
{
@@ -153,7 +153,7 @@
153153
"metadata": {},
154154
"outputs": [],
155155
"source": [
156-
"G.view_graph()"
156+
"plot(G)"
157157
]
158158
},
159159
{
@@ -184,7 +184,11 @@
184184
")\n",
185185
"print(\n",
186186
" ident_eff.identify_effect(\n",
187-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
187+
" graph=G, \n",
188+
" action_nodes=treatment_name, \n",
189+
" outcome_nodes=outcome_name,\n",
190+
" observed_nodes=observed_node_names,\n",
191+
" conditional_node_names=conditional_node_names\n",
188192
" )\n",
189193
")"
190194
]
@@ -215,7 +219,11 @@
215219
")\n",
216220
"print(\n",
217221
" ident_minimal_eff.identify_effect(\n",
218-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
222+
" graph=G, \n",
223+
" action_nodes=treatment_name, \n",
224+
" outcome_nodes=outcome_name, \n",
225+
" observed_nodes=observed_node_names,\n",
226+
" conditional_node_names=conditional_node_names\n",
219227
" )\n",
220228
")"
221229
]
@@ -239,7 +247,11 @@
239247
")\n",
240248
"print(\n",
241249
" ident_mincost_eff.identify_effect(\n",
242-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
250+
" graph=G, \n",
251+
" action_nodes=treatment_name, \n",
252+
" outcome_nodes=outcome_name,\n",
253+
" observed_nodes=observed_node_names,\n",
254+
" conditional_node_names=conditional_node_names\n",
243255
" )\n",
244256
")"
245257
]
@@ -294,9 +306,7 @@
294306
"observed_node_names = [\"X\", \"Y\", \"Z1\", \"Z2\"]\n",
295307
"treatment_name = \"X\"\n",
296308
"outcome_name = \"Y\"\n",
297-
"G = CausalGraph(\n",
298-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
299-
")"
309+
"G = build_graph_from_str(graph_str)"
300310
]
301311
},
302312
{
@@ -317,7 +327,10 @@
317327
" backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n",
318328
")\n",
319329
"try:\n",
320-
" results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n",
330+
" results_eff = ident_eff.identify_effect(graph=G, \n",
331+
" action_nodes=treatment_name, \n",
332+
" outcome_nodes=outcome_name,\n",
333+
" observed_nodes=observed_node_names)\n",
321334
"except ValueError as e:\n",
322335
" print(e)"
323336
]
@@ -335,8 +348,9 @@
335348
"print(\n",
336349
" ident_minimal_eff.identify_effect(\n",
337350
" graph=G,\n",
338-
" treatment_name=treatment_name,\n",
339-
" outcome_name=outcome_name,\n",
351+
" action_nodes=treatment_name,\n",
352+
" outcome_nodes=outcome_name,\n",
353+
" observed_nodes=observed_node_names\n",
340354
" )\n",
341355
")"
342356
]
@@ -354,8 +368,9 @@
354368
"print(\n",
355369
" ident_mincost_eff.identify_effect(\n",
356370
" graph=G,\n",
357-
" treatment_name=treatment_name,\n",
358-
" outcome_name=outcome_name,\n",
371+
" action_nodes=treatment_name,\n",
372+
" outcome_nodes=outcome_name,\n",
373+
" observed_nodes=observed_node_names\n",
359374
" )\n",
360375
")"
361376
]
@@ -391,9 +406,7 @@
391406
"observed_node_names = [\"X\", \"Y\"]\n",
392407
"treatment_name = \"X\"\n",
393408
"outcome_name = \"Y\"\n",
394-
"G = CausalGraph(\n",
395-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
396-
")"
409+
"G = build_graph_from_str(graph_str)"
397410
]
398411
},
399412
{
@@ -409,8 +422,9 @@
409422
"try:\n",
410423
" results_eff = ident_eff.identify_effect(\n",
411424
" graph=G,\n",
412-
" treatment_name=treatment_name,\n",
413-
" outcome_name=outcome_name,\n",
425+
" action_nodes=treatment_name,\n",
426+
" outcome_nodes=outcome_name,\n",
427+
" observed_nodes=observed_node_names\n",
414428
" )\n",
415429
"except ValueError as e:\n",
416430
" print(e)"
@@ -475,9 +489,7 @@
475489
" (\"R\", {\"cost\": 2}),\n",
476490
" (\"T\", {\"cost\": 1}),\n",
477491
"]\n",
478-
"G = CausalGraph(\n",
479-
" graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n",
480-
")"
492+
"G = build_graph_from_str(graph_str)"
481493
]
482494
},
483495
{
@@ -504,7 +516,11 @@
504516
")\n",
505517
"print(\n",
506518
" ident_mincost_eff.identify_effect(\n",
507-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
519+
" graph=G, \n",
520+
" action_nodes=treatment_name, \n",
521+
" outcome_nodes=outcome_name, \n",
522+
" observed_nodes=observed_node_names,\n",
523+
" conditional_node_names=conditional_node_names\n",
508524
" )\n",
509525
")"
510526
]
@@ -528,22 +544,19 @@
528544
")\n",
529545
"print(\n",
530546
" ident_minimal_eff.identify_effect(\n",
531-
" graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n",
547+
" graph=G, \n",
548+
" action_nodes=treatment_name,\n",
549+
" outcome_nodes=outcome_name, \n",
550+
" observed_nodes=observed_node_names,\n",
551+
" conditional_node_names=conditional_node_names\n",
532552
" )\n",
533553
")"
534554
]
535-
},
536-
{
537-
"cell_type": "code",
538-
"execution_count": null,
539-
"metadata": {},
540-
"outputs": [],
541-
"source": []
542555
}
543556
],
544557
"metadata": {
545558
"kernelspec": {
546-
"display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')",
559+
"display_name": "Python 3 (ipykernel)",
547560
"language": "python",
548561
"name": "python3"
549562
},

0 commit comments

Comments
 (0)