Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise GCM example notebooks #1283

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified docs/source/_static/sales_attribution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ef2e40dc",
"metadata": {},
"outputs": [],
"source": [
"# !pip install dowhy\n",
"# !pip install scipy"
]
},
{
"cell_type": "markdown",
"id": "777695bc",
Expand Down Expand Up @@ -653,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion docs/source/example_notebooks/gcm_falsify_dag.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
"from dowhy.gcm.util import plot\n",
"from dowhy.gcm.util.general import set_random_seed\n",
"from dowhy.gcm.ml import SklearnRegressionModel\n",
"from dowhy.gcm.util.general import set_random_seed\n",
"set_random_seed(0)\n",
"\n",
"# Set random seed\n",
"set_random_seed(1332)"
Expand Down Expand Up @@ -328,7 +330,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions docs/source/example_notebooks/gcm_icc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"\n",
"from dowhy import gcm\n",
"from dowhy.utils.plotting import plot, bar_plot\n",
"gcm.util.general.set_random_seed(0)\n",
"\n",
"# Load a modified version of the Auto MPG data: Quinlan,R.. (1993). Auto MPG. UCI Machine Learning Repository. https://doi.org/10.24432/C5859H.\n",
"auto_mpg_data = pd.read_csv(\"datasets/auto_mpg.csv\", index_col=0)\n",
Expand Down
9 changes: 5 additions & 4 deletions docs/source/example_notebooks/gcm_online_shop.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,13 @@
"outputs": [],
"source": [
"from dowhy import gcm\n",
"gcm.util.general.set_random_seed(0)\n",
"\n",
"# Create the structural causal model object\n",
"scm = gcm.StructuralCausalModel(causal_graph)\n",
"\n",
"# Automatically assign generative models to each node based on the given data\n",
"auto_assignment_summary = gcm.auto.assign_causal_mechanisms(scm, data_2021, override_models=True, quality=gcm.auto.AssignmentQuality.GOOD)"
"auto_assignment_summary = gcm.auto.assign_causal_mechanisms(scm, data_2021)"
]
},
{
Expand Down Expand Up @@ -274,7 +275,7 @@
"metadata": {},
"outputs": [],
"source": [
"print(gcm.evaluate_causal_model(scm, data_2021, compare_mechanism_baselines=True, evaluate_invertibility_assumptions=False))"
"print(gcm.evaluate_causal_model(scm, data_2021, compare_mechanism_baselines=True, evaluate_invertibility_assumptions=False, evaluate_causal_structure=False))"
]
},
{
Expand All @@ -290,7 +291,7 @@
"id": "a2e5acb0-f40e-4478-bee8-db3ff58bea21",
"metadata": {},
"source": [
"> The selection of baseline models or the p-value for graph falsification can be configured as well. For more details, take a look at the corresponding evaluate_causal_model documentation."
"> The selection of baseline models can be configured as well. For more details, take a look at the corresponding evaluate_causal_model documentation."
]
},
{
Expand Down Expand Up @@ -718,7 +719,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"import networkx as nx\n",
"from dowhy import gcm\n",
"from dowhy.utils import plot, bar_plot\n",
"gcm.util.general.set_random_seed(0)\n",
"\n",
"causal_graph = nx.DiGraph([('www', 'Website'),\n",
" ('Auth Service', 'www'),\n",
Expand Down Expand Up @@ -309,7 +310,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The attributions indicate that `Caching Service` is the main driver of high latency in `Website` which is expected as we perturb the causal mechanism of `Caching Service` to generate an outlier latency in `Website` (see Appendix below). Attributions to `Customer DB` and `Product Service` can be explained by misspecification of causal models. First, some of the parent-child relationships in the causal graph are non-linear (by looking at the scatter matrix). Second, the parent child-relationship between `Caching Service` and `Product DB` seems to indicate two mechanisms. This could be due to an unobserved binary variable (e.g., Cache hit/miss) that has a multiplicative effect on `Caching Service`. An additive noise cannot capture the multiplicative effect of this unobserved variable."
"The attributions indicate that `Caching Service` is the main driver of high latency in `Website` which is expected as we perturb the causal mechanism of `Caching Service` to generate an outlier latency in `Website` (see Appendix below). Interestingly, `Customer DB` has a negative contribution, indicating that it was particularly fast, decreasing the outlier in the `Website`. Note that some of the attributions are also due to model misspecifications. For instance, the parent child-relationship between `Caching Service` and `Product DB` seems to indicate two mechanisms. This could be due to an unobserved binary variable (e.g., Cache hit/miss) that has a multiplicative effect on `Caching Service`. An additive noise cannot capture the multiplicative effect of this unobserved variable."
]
},
{
Expand Down Expand Up @@ -596,7 +597,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"import networkx as nx\n",
"import dowhy.gcm as gcm\n",
"from dowhy.utils import plot\n",
"gcm.util.general.set_random_seed(0)\n",
"\n",
"causal_graph = nx.DiGraph([('demand', 'submitted'),\n",
" ('constraint', 'submitted'),\n",
Expand Down Expand Up @@ -397,7 +398,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
112 changes: 56 additions & 56 deletions docs/source/example_notebooks/sales_attribution_intervention.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"from scipy import stats\n",
"from statsmodels.stats.multitest import multipletests\n",
"\n",
"gcm.util.general.set_random_seed(0)\n",
"%matplotlib inline"
]
},
Expand All @@ -75,7 +76,7 @@
},
"outputs": [],
"source": [
"df = pd.read_csv('datasets/sale_attribution.csv', index_col=0)\n",
"df = pd.read_csv('datasets/sales_attribution.csv', index_col=0)\n",
"df.head()"
]
},
Expand Down Expand Up @@ -387,7 +388,7 @@
"tags": []
},
"source": [
"Then we remove insignificant edges and their associated nodes, resulting a refined causal graph. "
"Then we remove insignificant edges and their associated nodes, resulting a refined causal graph."
]
},
{
Expand All @@ -402,6 +403,9 @@
"for insignificant_parent in test_causal_minimality(causal_graph, 'sale', df):\n",
" causal_graph.remove_edge(insignificant_parent, 'sale')\n",
"\n",
"for insignificant_parent in test_causal_minimality(causal_graph, 'dpv', df):\n",
" causal_graph.remove_edge(insignificant_parent, 'dpv')\n",
"\n",
"cols_to_remove=[]\n",
"cols_to_remove.extend([node for node in causal_graph.nodes if causal_graph.in_degree(node) + causal_graph.out_degree(node) == 0])"
]
Expand All @@ -415,10 +419,18 @@
},
"outputs": [],
"source": [
"causal_graph.remove_nodes_from(cols_to_remove)\n",
"causal_graph.remove_nodes_from(set(cols_to_remove))\n",
"plot(causal_graph)"
]
},
{
"cell_type": "markdown",
"id": "b165f1c3-d63b-4e3e-93f7-786da73ad60f",
"metadata": {},
"source": [
"Interestingly, the 'other_shopping_event' variable has no significant impact on either 'dpv' or 'sale'."
]
},
{
"cell_type": "markdown",
"id": "630fed0a",
Expand All @@ -432,7 +444,7 @@
"id": "4d7468ec",
"metadata": {},
"source": [
"Next, we need to assign functional causal models (FCMs) to each node, which describe the data generation process from x to y with an error term. The auto assignment method compares different prediction models for each node and takes the one with the smallest error. The `quality` parameter controls the set of model types that are tested, where `BETTER` indicates some of the most common regression and classification models, such as trees, support vector regression etc. You can also use 'Good' which fits fewer models to speed up, or 'Best' that is computationally heavy. After assigning the models, we can fit them to the data:"
"Next, we need to assign functional causal models (FCMs) to each node, which describe the data generation process from x to y with an error term. The auto assignment method compares different prediction models for each node and takes the one with the smallest error. The `quality` parameter controls the set of model types that are tested, where `BETTER` indicates some of the most common regression and classification models, such as trees, support vector regression etc. You can also use `GOOD` which fits fewer models to speed up, or `BEST` that is computationally heavy (and requres AutoGluon to be installed). After assigning the models, we can fit them to the data:"
]
},
{
Expand Down Expand Up @@ -496,18 +508,16 @@
},
"outputs": [],
"source": [
"def calculate_difference_estimation(causal_model, df_old, df_new, target_column, difference_estimation_func, num_samples=2000, confidence_level=0.90, num_bootstrap_resamples=10):\n",
"def calculate_difference_estimation(causal_model, df_old, df_new, target_column, difference_estimation_func, num_samples=2000, confidence_level=0.90, num_bootstrap_resamples=4):\n",
"\n",
" difference_contribs, uncertainty_contribs = gcm.confidence_intervals(\n",
" lambda : gcm.distribution_change(causal_model, \n",
" df_old, \n",
" df_new, \n",
" target_column, \n",
" num_samples=num_samples,\n",
" independence_test=lambda x, y: gcm.kernel_based(x.astype(float), y.astype(float)),\n",
" conditional_independence_test=lambda x, y, z: gcm.kernel_based(x.astype(float), y.astype(float), z.astype(float)),\n",
" difference_estimation_func=difference_estimation_func,\n",
" shapley_config=gcm.shapley.ShapleyConfig(approximation_method=gcm.shapley.ShapleyApproximationMethods.PERMUTATION, num_permutations=20)),\n",
" shapley_config=gcm.shapley.ShapleyConfig(approximation_method=gcm.shapley.ShapleyApproximationMethods.PERMUTATION, num_permutations=50)),\n",
" confidence_level=confidence_level,\n",
" num_bootstrap_resamples=num_bootstrap_resamples\n",
" )\n",
Expand Down Expand Up @@ -537,6 +547,16 @@
"gcm.util.bar_plot(median_diff_contribs, median_diff_uncertainty, 'Contribution', figure_size=(10,5))"
]
},
{
"cell_type": "markdown",
"id": "6d9d2ba9-e78d-406d-beed-f2895aa05d65",
"metadata": {},
"source": [
"Here, we see that 'sp_spend' has the larges contribution to the change in the mean of 'sale', while the 'discount' and 'special_shopping_event' has little to none contribution. This aligns with the way how the data was generated.\n",
"\n",
"Taking a look at the tabular overview to look for significance based on the confidence intervals:"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -564,8 +584,6 @@
"id": "d2ab2655-0d80-4112-8e65-1a27a3622b37",
"metadata": {},
"source": [
"Here, we see that 'sp_spend' has the larges contrigution to the change in the mean of 'sale', while the 'discount', 'other_shopping_event' and 'special_shopping_event' has little to none contribution. This aligns with the way how the data was generated.\n",
"\n",
"Next, we remove all variables that have a 0 as part of their confidence interval or are negative, i.e., do not have a clear significant positive contribution:"
]
},
Expand Down Expand Up @@ -626,7 +644,7 @@
"id": "5908d58e-a928-4a85-9a72-2ec2a53fa3d0",
"metadata": {},
"source": [
"Section 4 above helps us to understand drivers for past growth. Now, looking forward to business planning, we conduct interventions to understand incremental contributions to KPIs. Intuitively, the spend types resulting in higher returns should be doubled down on. "
"Section 4 above helps us to understand drivers for past growth. Now, looking forward to business planning, we conduct interventions to understand incremental contributions to KPIs. Intuitively, the spend types resulting in higher returns should be doubled down on. Here, we explicitly remove 'sale', 'dpv' and 'special_shopping_event' as a possible intervention targets."
]
},
{
Expand All @@ -638,7 +656,7 @@
},
"outputs": [],
"source": [
"def intervention_influence(causal_model, target, data=None, step_size=1, non_interveneable_nodes=None, confidence_level=0.90, cap_value=None, prints=False):\n",
"def intervention_influence(causal_model, target, step_size=1, non_interveneable_nodes=None, confidence_level=0.95, prints=False, threshold_insignificant=0.0001):\n",
" progress_bar_was_on = gcm.config.show_progress_bars\n",
"\n",
" if progress_bar_was_on:\n",
Expand All @@ -665,58 +683,33 @@
" interventions_alternative = {node: intervention}\n",
" interventions_reference = {node: non_intervention}\n",
"\n",
" # Compute causal effects\n",
" if data is not None:\n",
" effect = gcm.confidence_intervals(\n",
" gcm.fit_and_compute(gcm.average_causal_effect,\n",
" causal_model,\n",
" auto_assign_quality=gcm.auto.AssignmentQuality.GOOD,\n",
" bootstrap_training_data=data,\n",
" target_node=target,\n",
" interventions_alternative=interventions_alternative,\n",
" interventions_reference=interventions_reference,\n",
" observed_data=data),\n",
" n_jobs=1,\n",
" num_bootstrap_resamples=10,\n",
" confidence_level=confidence_level)\n",
" else:\n",
" effect = gcm.confidence_intervals(\n",
" partial(gcm.average_causal_effect,\n",
" causal_model=causal_model,\n",
" target_node=target,\n",
" interventions_alternative=interventions_alternative,\n",
" interventions_reference=interventions_reference,\n",
" observed_data=data,\n",
" num_samples_to_draw=10000),\n",
" n_jobs=-1,\n",
" num_bootstrap_resamples=20,\n",
" confidence_level=confidence_level)\n",
"\n",
" effect = gcm.confidence_intervals(\n",
" partial(gcm.average_causal_effect,\n",
" causal_model=causal_model,\n",
" target_node=target,\n",
" interventions_alternative=interventions_alternative,\n",
" interventions_reference=interventions_reference,\n",
" num_samples_to_draw=10000),\n",
" n_jobs=-1,\n",
" num_bootstrap_resamples=40,\n",
" confidence_level=confidence_level)\n",
"\n",
" causal_effects[node] = effect[0][0]\n",
" causal_effects_confidence_interval[node] = effect[1].squeeze()\n",
" \n",
" # Apply capping constraint\n",
" if cap_value is not None and node.endswith('_spend'):\n",
" if causal_effects[node] > cap_value:\n",
" causal_effects[node] = cap_value\n",
" capped_effects.append(node)\n",
" elif causal_effects[node] < -cap_value:\n",
" causal_effects[node] = -cap_value\n",
" capped_effects.append(node)\n",
"\n",
" # Apply non-negativity constraint\n",
"\n",
" # Apply non-negativity constraint - Here, spend cannot be negative. However, small negative values can happen in the analysis due to misspecifications.\n",
" if node.endswith('_spend') and causal_effects[node] < 0:\n",
" causal_effects[node] = 0\n",
" causal_effects_confidence_interval[node] = [np.nan, np.nan]\n",
"\n",
" if progress_bar_was_on:\n",
" gcm.config.enable_progress_bars()\n",
"\n",
" print(causal_effects)\n",
" if prints:\n",
" for node in sorted(causal_effects, key=causal_effects.get, reverse=True):\n",
" if causal_effects[node] == 0:\n",
" print(f\"{'Increasing' if step_size > 0 else 'Decreasing'} {node} by {step_size} has no effect on {target}.\")\n",
" if abs(causal_effects[node]) < threshold_insignificant:\n",
" print(f\"{'Increasing' if step_size > 0 else 'Decreasing'} {node} by {step_size} has no significant effect on {target}.\")\n",
" else:\n",
" print(f\"{'Increasing' if step_size > 0 else 'Decreasing'} {node} by {step_size} {'increases' if causal_effects[node] > 0 else 'decreases'} {target} \"\n",
" f\"by around {causal_effects[node]} with a confidence interval ({confidence_level * 100}%) of {causal_effects_confidence_interval[node]}.\")\n",
Expand All @@ -743,7 +736,7 @@
},
"outputs": [],
"source": [
"interv_result = intervention_influence(causal_model=causal_model, target='sale', non_interveneable_nodes=['dpv', 'sale'], confidence_level=0.85, cap_value=20, prints=True)\n",
"interv_result = intervention_influence(causal_model=causal_model, target='sale', non_interveneable_nodes=['dpv', 'sale', 'special_shopping_event'], prints=True)\n",
"interv_result"
]
},
Expand All @@ -752,7 +745,7 @@
"id": "8bce37de-ca68-42ce-aac1-6a86d7834320",
"metadata": {},
"source": [
"We similarly filter to positively significant interventions, i.e., the spending with statistically significant positive returns. Note that with capping, 'causal effect' may not fall between lower and upper CI. The interpretation is, for each dollar spent on one type of ad, we receive X amount in return, with X indicated in the 'Causal Effect' column. "
"We similarly filter to positively significant interventions, i.e., the spending with statistically significant positive returns. The interpretation is, for each dollar spent on one type of ad, we receive X amount in return, with X indicated in the 'Causal Effect' column. "
]
},
{
Expand All @@ -764,8 +757,15 @@
},
"outputs": [],
"source": [
"interv_result_v2= filter_significant_rows(interv_result, 'positive', 'Upper CI', 'Lower CI')\n",
"interv_result_v2"
"filter_significant_rows(interv_result, 'positive', 'Upper CI', 'Lower CI')"
]
},
{
"cell_type": "markdown",
"id": "64f11dbf-cf72-42fa-a945-4450042c8588",
"metadata": {},
"source": [
"This tells us that there is a clear benefit in doubling down on 'sp_spend' and 'dsp_spend'. Note that the quantitative numbers here can be off due to model misspecifications, but they nevertheless provide some helpful insights."
]
},
{
Expand Down