|
62 | 62 | "outputs": [],
|
63 | 63 | "source": [
|
64 | 64 | "from dowhy.causal_graph import CausalGraph\n",
|
65 |
| - "from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType" |
| 65 | + "from dowhy.causal_identifier import AutoIdentifier, BackdoorAdjustment, EstimandType\n", |
| 66 | + "from dowhy.graph import build_graph_from_str\n", |
| 67 | + "from dowhy.utils.plotting import plot" |
66 | 68 | ]
|
67 | 69 | },
|
68 | 70 | {
|
|
135 | 137 | "]\n",
|
136 | 138 | "treatment_name = \"warm-up\"\n",
|
137 | 139 | "outcome_name = \"injury\"\n",
|
138 |
| - "G = CausalGraph(\n", |
139 |
| - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", |
140 |
| - ")" |
| 140 | + "G = build_graph_from_str(graph_str)" |
141 | 141 | ]
|
142 | 142 | },
|
143 | 143 | {
|
|
153 | 153 | "metadata": {},
|
154 | 154 | "outputs": [],
|
155 | 155 | "source": [
|
156 |
| - "G.view_graph()" |
| 156 | + "plot(G)" |
157 | 157 | ]
|
158 | 158 | },
|
159 | 159 | {
|
|
184 | 184 | ")\n",
|
185 | 185 | "print(\n",
|
186 | 186 | " ident_eff.identify_effect(\n",
|
187 |
| - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", |
| 187 | + " graph=G, \n", |
| 188 | + " action_nodes=treatment_name, \n", |
| 189 | + " outcome_nodes=outcome_name,\n", |
| 190 | + " observed_nodes=observed_node_names,\n", |
| 191 | + " conditional_node_names=conditional_node_names\n", |
188 | 192 | " )\n",
|
189 | 193 | ")"
|
190 | 194 | ]
|
|
215 | 219 | ")\n",
|
216 | 220 | "print(\n",
|
217 | 221 | " ident_minimal_eff.identify_effect(\n",
|
218 |
| - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", |
| 222 | + " graph=G, \n", |
| 223 | + " action_nodes=treatment_name, \n", |
| 224 | + " outcome_nodes=outcome_name, \n", |
| 225 | + " observed_nodes=observed_node_names,\n", |
| 226 | + " conditional_node_names=conditional_node_names\n", |
219 | 227 | " )\n",
|
220 | 228 | ")"
|
221 | 229 | ]
|
|
239 | 247 | ")\n",
|
240 | 248 | "print(\n",
|
241 | 249 | " ident_mincost_eff.identify_effect(\n",
|
242 |
| - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", |
| 250 | + " graph=G, \n", |
| 251 | + " action_nodes=treatment_name, \n", |
| 252 | + " outcome_nodes=outcome_name,\n", |
| 253 | + " observed_nodes=observed_node_names,\n", |
| 254 | + " conditional_node_names=conditional_node_names\n", |
243 | 255 | " )\n",
|
244 | 256 | ")"
|
245 | 257 | ]
|
|
294 | 306 | "observed_node_names = [\"X\", \"Y\", \"Z1\", \"Z2\"]\n",
|
295 | 307 | "treatment_name = \"X\"\n",
|
296 | 308 | "outcome_name = \"Y\"\n",
|
297 |
| - "G = CausalGraph(\n", |
298 |
| - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", |
299 |
| - ")" |
| 309 | + "G = build_graph_from_str(graph_str)" |
300 | 310 | ]
|
301 | 311 | },
|
302 | 312 | {
|
|
317 | 327 | " backdoor_adjustment=BackdoorAdjustment.BACKDOOR_EFFICIENT,\n",
|
318 | 328 | ")\n",
|
319 | 329 | "try:\n",
|
320 |
| - " results_eff = ident_eff.identify_effect(graph=G, treatment_name=treatment_name, outcome_name=outcome_name)\n", |
| 330 | + " results_eff = ident_eff.identify_effect(graph=G, \n", |
| 331 | + " action_nodes=treatment_name, \n", |
| 332 | + " outcome_nodes=outcome_name,\n", |
| 333 | + " observed_nodes=observed_node_names)\n", |
321 | 334 | "except ValueError as e:\n",
|
322 | 335 | " print(e)"
|
323 | 336 | ]
|
|
335 | 348 | "print(\n",
|
336 | 349 | " ident_minimal_eff.identify_effect(\n",
|
337 | 350 | " graph=G,\n",
|
338 |
| - " treatment_name=treatment_name,\n", |
339 |
| - " outcome_name=outcome_name,\n", |
| 351 | + " action_nodes=treatment_name,\n", |
| 352 | + " outcome_nodes=outcome_name,\n", |
| 353 | + " observed_nodes=observed_node_names\n", |
340 | 354 | " )\n",
|
341 | 355 | ")"
|
342 | 356 | ]
|
|
354 | 368 | "print(\n",
|
355 | 369 | " ident_mincost_eff.identify_effect(\n",
|
356 | 370 | " graph=G,\n",
|
357 |
| - " treatment_name=treatment_name,\n", |
358 |
| - " outcome_name=outcome_name,\n", |
| 371 | + " action_nodes=treatment_name,\n", |
| 372 | + " outcome_nodes=outcome_name,\n", |
| 373 | + " observed_nodes=observed_node_names\n", |
359 | 374 | " )\n",
|
360 | 375 | ")"
|
361 | 376 | ]
|
|
391 | 406 | "observed_node_names = [\"X\", \"Y\"]\n",
|
392 | 407 | "treatment_name = \"X\"\n",
|
393 | 408 | "outcome_name = \"Y\"\n",
|
394 |
| - "G = CausalGraph(\n", |
395 |
| - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", |
396 |
| - ")" |
| 409 | + "G = build_graph_from_str(graph_str)" |
397 | 410 | ]
|
398 | 411 | },
|
399 | 412 | {
|
|
409 | 422 | "try:\n",
|
410 | 423 | " results_eff = ident_eff.identify_effect(\n",
|
411 | 424 | " graph=G,\n",
|
412 |
| - " treatment_name=treatment_name,\n", |
413 |
| - " outcome_name=outcome_name,\n", |
| 425 | + " action_nodes=treatment_name,\n", |
| 426 | + " outcome_nodes=outcome_name,\n", |
| 427 | + " observed_nodes=observed_node_names\n", |
414 | 428 | " )\n",
|
415 | 429 | "except ValueError as e:\n",
|
416 | 430 | " print(e)"
|
|
475 | 489 | " (\"R\", {\"cost\": 2}),\n",
|
476 | 490 | " (\"T\", {\"cost\": 1}),\n",
|
477 | 491 | "]\n",
|
478 |
| - "G = CausalGraph(\n", |
479 |
| - " graph=graph_str, treatment_name=treatment_name, outcome_name=outcome_name, observed_node_names=observed_node_names\n", |
480 |
| - ")" |
| 492 | + "G = build_graph_from_str(graph_str)" |
481 | 493 | ]
|
482 | 494 | },
|
483 | 495 | {
|
|
504 | 516 | ")\n",
|
505 | 517 | "print(\n",
|
506 | 518 | " ident_mincost_eff.identify_effect(\n",
|
507 |
| - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", |
| 519 | + " graph=G, \n", |
| 520 | + " action_nodes=treatment_name, \n", |
| 521 | + " outcome_nodes=outcome_name, \n", |
| 522 | + " observed_nodes=observed_node_names,\n", |
| 523 | + " conditional_node_names=conditional_node_names\n", |
508 | 524 | " )\n",
|
509 | 525 | ")"
|
510 | 526 | ]
|
|
528 | 544 | ")\n",
|
529 | 545 | "print(\n",
|
530 | 546 | " ident_minimal_eff.identify_effect(\n",
|
531 |
| - " graph=G, treatment_name=treatment_name, outcome_name=outcome_name, conditional_node_names=conditional_node_names\n", |
| 547 | + " graph=G, \n", |
| 548 | + " action_nodes=treatment_name,\n", |
| 549 | + " outcome_nodes=outcome_name, \n", |
| 550 | + " observed_nodes=observed_node_names,\n", |
| 551 | + " conditional_node_names=conditional_node_names\n", |
532 | 552 | " )\n",
|
533 | 553 | ")"
|
534 | 554 | ]
|
535 |
| - }, |
536 |
| - { |
537 |
| - "cell_type": "code", |
538 |
| - "execution_count": null, |
539 |
| - "metadata": {}, |
540 |
| - "outputs": [], |
541 |
| - "source": [] |
542 | 555 | }
|
543 | 556 | ],
|
544 | 557 | "metadata": {
|
545 | 558 | "kernelspec": {
|
546 |
| - "display_name": "Python 3.8.10 ('dowhy-_zBapv7Q-py3.8')", |
| 559 | + "display_name": "Python 3 (ipykernel)", |
547 | 560 | "language": "python",
|
548 | 561 | "name": "python3"
|
549 | 562 | },
|
|
0 commit comments