Library functions for temporal causal functionality #1218

Merged on Aug 3, 2024 (67 commits)
Commits
8d6cdd5
Library functions for temporal causal functionality
srivhash Jun 28, 2024
87a90b4
shifting plotter function
srivhash Jul 2, 2024
ec73c1a
modified helper functions
srivhash Jul 2, 2024
6d22554
tutorial notebook
srivhash Jul 2, 2024
17f5822
graph creation into utils
srivhash Jul 4, 2024
4d880b5
parents to columns
srivhash Jul 4, 2024
1330658
helper
srivhash Jul 4, 2024
4daf89f
updated deprecated plotter
srivhash Jul 4, 2024
74f8985
removed plotter
srivhash Jul 5, 2024
8010145
printing graph: best practices
srivhash Jul 5, 2024
9a09e6d
updated imports
srivhash Jul 5, 2024
3fa8786
renamed module
srivhash Jul 5, 2024
e0a4ae3
added docstrings
srivhash Jul 9, 2024
979bc32
moved datasets
srivhash Jul 12, 2024
c495ddb
updated tutorial notebook
srivhash Jul 12, 2024
74ffbda
improved notebook quality
srivhash Jul 12, 2024
49bad76
sphinx documentation
srivhash Jul 15, 2024
fbfd7d2
removed pretty print
srivhash Jul 15, 2024
348391e
dot format
srivhash Jul 15, 2024
f6df32a
better node names
srivhash Jul 16, 2024
2bf8c3d
edge validation
srivhash Jul 22, 2024
eaba762
updated shifting columns with 0,1,..,max_lag
srivhash Jul 25, 2024
01c275f
support for dot format
srivhash Jul 25, 2024
599419a
tigramite support
srivhash Jul 25, 2024
848d58c
updated filter to be a hidden function
srivhash Jul 26, 2024
aa8221c
pdated notebook, corrections in data type
srivhash Jul 26, 2024
0690fe5
better readability of parents and time lags
srivhash Jul 26, 2024
41d3246
graph creation test
srivhash Jul 26, 2024
7a3557f
removed wrong comment
srivhash Jul 26, 2024
5e03ddc
moved test files
srivhash Jul 26, 2024
8ef07d0
moved tutorial notebook & updated the path
srivhash Jul 26, 2024
a2785e6
removed repeated imports
srivhash Jul 27, 2024
8422ac6
time_lag parameter explained
srivhash Jul 27, 2024
e2813c8
better function name
srivhash Jul 27, 2024
19d53e2
updated imports
srivhash Jul 27, 2024
38f0883
tests for temporal shift
srivhash Jul 27, 2024
4cd5ede
removed filter_columns function
srivhash Jul 27, 2024
f518c09
black and isort utils
srivhash Jul 29, 2024
8b621be
black and isort timeseries
srivhash Jul 29, 2024
2f097ef
updated notebook text
amit-sharma Jul 29, 2024
559e1fb
integer range fix
srivhash Jul 30, 2024
a52999d
correction in timestamp : notebook text
srivhash Jul 30, 2024
e3b8e22
time lagged causal estimation
srivhash Jul 30, 2024
0317026
removed cell outputs
srivhash Jul 30, 2024
bf731bb
find ancestors
srivhash Jul 30, 2024
e488713
include ancestors in notebook
srivhash Jul 30, 2024
0fc85cc
formatting changes
srivhash Jul 30, 2024
f87f74d
comments : notebook
srivhash Jul 30, 2024
cdbb9af
multiple time lags : csv graph'
srivhash Jul 31, 2024
139eeaa
multiple time lags
srivhash Jul 31, 2024
1357399
unrolled graph using bfs
srivhash Jul 31, 2024
9cefd87
cleanup of functions
srivhash Jul 31, 2024
ed5962e
removed find parents and ancestors
srivhash Aug 1, 2024
af92dce
updated graph creation functions
srivhash Aug 1, 2024
e4dbc3b
multiple time lags : dot file
srivhash Aug 1, 2024
74a88b7
import update
srivhash Aug 1, 2024
d0bfc64
tests for causal graph creation
srivhash Aug 1, 2024
afb9914
tests for adding lagged edges
srivhash Aug 1, 2024
a99558a
tests for shifting columns
srivhash Aug 1, 2024
64cb310
formatting
srivhash Aug 1, 2024
30b986b
formatting
srivhash Aug 1, 2024
8c80c02
inter node links
srivhash Aug 2, 2024
04adfb6
clear outputs & formatting
srivhash Aug 2, 2024
cd54fab
updated tests for lagged shift
srivhash Aug 2, 2024
5c75284
formatting for tests
srivhash Aug 2, 2024
c94357a
formatting
srivhash Aug 2, 2024
81f1b34
tigramite dependency added
srivhash Aug 2, 2024
15 changes: 15 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_dataset.csv
@@ -0,0 +1,15 @@
V1,V2,V3,V4,V5,V6,V7
1,2,3,4,5,6,7
2,3,4,5,6,7,8
3,4,5,6,7,8,9
4,5,6,7,8,9,10
0,1,5,7,8,9,7
3,5,4,1,2,6,5
6,7,1,2,4,5,9
12,3,5,7,3,8,9
3,2,1,6,3,8,9
4,6,3,5,8,9,1
3,5,9,6,2,1,3
5,2,6,8,11,3,4
2,2,4,1,1,4,6
5,6,4,3,4,6,2
8 changes: 8 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_graph.csv
@@ -0,0 +1,8 @@
node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V4,V5,2
V7,V6,3
V7,V6,5
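A note on the edge list above: the pair V7,V6 appears twice (lags 3 and 5), so a reader of this file has to collect several lags per edge. The sketch below shows one way to do that using only the standard library. `read_lagged_edges` is a hypothetical helper written for illustration; it is not DoWhy's `create_graph_from_csv`, whose implementation is not shown in this diff.

```python
import csv
import io
from collections import defaultdict

# Same rows as temporal_graph.csv, inlined so the sketch is self-contained.
CSV_TEXT = """node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V4,V5,2
V7,V6,3
V7,V6,5
"""


def read_lagged_edges(csv_text):
    """Hypothetical helper: map (node1, node2) to a sorted tuple of time lags."""
    lags = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        # A repeated (node1, node2) pair contributes another lag to the same edge.
        lags[(row["node1"], row["node2"])].append(int(row["time_lag"]))
    return {edge: tuple(sorted(values)) for edge, values in lags.items()}


edges = read_lagged_edges(CSV_TEXT)
print(edges[("V7", "V6")])
```

Keying on the (node1, node2) pair keeps one entry per directed edge, which matches how the DOT version of the same graph writes both lags in a single label.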
8 changes: 8 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_graph.dot
@@ -0,0 +1,8 @@
digraph G {
V1 -> V2 [label="(3)"];
V2 -> V3 [label="(4)"];
V5 -> V6 [label="(1)"];
V4 -> V7 [label="(4)"];
V4 -> V5 [label="(2)"];
V7 -> V6 [label="(3, 5)"];
}
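The DOT file encodes the same lags in edge labels, with several lags sharing one label, as in `(3, 5)`. Below is a minimal sketch of reading that convention. It assumes the simple one-edge-per-line layout used in this file and is not a general DOT parser; `parse_lag_label` and `read_dot_edges` are hypothetical names, not part of DoWhy.

```python
import re

# Same content as temporal_graph.dot, inlined so the sketch is self-contained.
DOT_TEXT = """digraph G {
V1 -> V2 [label="(3)"];
V2 -> V3 [label="(4)"];
V5 -> V6 [label="(1)"];
V4 -> V7 [label="(4)"];
V4 -> V5 [label="(2)"];
V7 -> V6 [label="(3, 5)"];
}
"""

# Matches lines of the form: SRC -> DST [label="..."]
EDGE_PATTERN = re.compile(r'(\w+)\s*->\s*(\w+)\s*\[label="([^"]*)"\]')


def parse_lag_label(label):
    """Hypothetical helper: turn a label such as "(3, 5)" into the tuple (3, 5)."""
    return tuple(int(number) for number in re.findall(r"\d+", label))


def read_dot_edges(dot_text):
    """Hypothetical helper: map (source, target) to its tuple of time lags."""
    return {
        (src, dst): parse_lag_label(label)
        for src, dst, label in EDGE_PATTERN.findall(dot_text)
    }


dot_edges = read_dot_edges(DOT_TEXT)
print(dot_edges[("V7", "V6")])
```

Extracting the digits with a regular expression avoids the pitfall that `"(3)"` evaluates to a plain integer rather than a one-element tuple.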
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Effect inference with timeseries data\n",
"\n",
"In this notebook, we will look at an example of causal effect inference from timeseries data. We will use DoWhy's functionality to add temporal dependencies to a causal graph and estimate causal effect based on the augmented graph. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"import pandas as pd\n",
"from dowhy.utils.timeseries import create_graph_from_csv, create_graph_from_user\n",
"from dowhy.utils.plotting import plot, pretty_print_graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading timeseries data and causal graph"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_path=\"../datasets/temporal_dataset.csv\"\n",
"\n",
"dataframe=pd.read_csv(dataset_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if $node_1$ influences $node_2$ with a time lag of 5 timestamps, we represent this dependency as $node_1^{t-5}$ -> $node_2^{t}$.\n",
"\n",
"We can provide the causal graph as a networkx DAG or as a dot file. Each edge's attributes should specify the exact `time_lag` associated with that edge (if any)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.utils.timeseries import create_graph_from_dot_format\n",
"\n",
"file_path = \"../datasets/temporal_graph.dot\"\n",
"\n",
"graph = create_graph_from_dot_format(file_path)\n",
"plot(graph)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also create a csv file with the edges in the temporal graph. The columns in the csv are node1, node2, and time_lag; each row represents a directed edge node1 -> node2 with a time lag of time_lag. Let us consider the following graph as the input:\n",
"\n",
"| node1 | node2 | time_lag |\n",
"|--------|--------|----------|\n",
"| V1 | V2 | 3 |\n",
"| V2 | V3 | 4 |\n",
"| V5 | V6 | 1 |\n",
"| V4 | V7 | 4 |\n",
"| V4 | V5 | 2 |\n",
"| V7 | V6 | 3 |\n",
"| V7 | V6 | 5 |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Input a csv file with the edges in the graph with the columns: node1,node2,time_lag\n",
"file_path = \"../datasets/temporal_graph.csv\"\n",
"\n",
"# Create the graph from the CSV file\n",
"graph = create_graph_from_csv(file_path)\n",
"plot(graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset Shifting and Filtering\n",
"\n",
"To prepare the dataset for temporal causal inference, we need to shift the columns by the given time lag.\n",
"\n",
"For example, consider a lagged edge such as $node_1^{t-5}$ -> $node_2^{t}$, which has a lag of 5. When considering $node_2$ as the target node, the data for $node_1$ should be shifted down by 5 timestamps. This adjustment ensures that the edge $node_1$ -> $node_2$ accurately represents the lagged dependency. Shifting the data in this manner creates additional columns and allows downstream estimators to access the correct values in the same row of a dataframe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.timeseries.temporal_shift import shift_columns_by_lag_using_unrolled_graph, add_lagged_edges"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the outcome node for which effect estimation has to be done: V6\n",
"target_node = 'V6'\n",
"unrolled_graph = add_lagged_edges(graph, target_node)\n",
"plot(unrolled_graph)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_shifted_df = shift_columns_by_lag_using_unrolled_graph(dataframe, unrolled_graph)\n",
"time_shifted_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Causal Effect Estimation\n",
"\n",
"Once you have the new dataframe, causal effect estimation can be performed on the target node with respect to the action nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"target_node = 'V6_0'\n",
"# include all the treatments\n",
"treatment_columns = list(time_shifted_df.columns)\n",
"treatment_columns.remove(target_node)\n",
"treatment_columns"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# perform causal effect estimation on this new dataset\n",
"import dowhy\n",
"from dowhy import CausalModel\n",
"\n",
"model = CausalModel(\n",
" data=time_shifted_df,\n",
" treatment='V5_-1',\n",
" outcome=target_node,\n",
" graph = unrolled_graph\n",
")\n",
"\n",
"identified_estimand = model.identify_effect()\n",
"\n",
"estimate = model.estimate_effect(identified_estimand,\n",
" method_name=\"backdoor.linear_regression\",\n",
" test_significance=True)\n",
"\n",
"\n",
"print(estimate)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Importing a temporal causal graph from the Tigramite library\n",
"\n",
"Tigramite is a popular temporal causal discovery library. In this section, we highlight how a causal graph can be obtained by applying the PCMCI+ algorithm from Tigramite and then imported into DoWhy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install tigramite"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tigramite\n",
"import tigramite.data_processing as pp\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"# convert the dataframe values to float\n",
"dataframe = dataframe.astype(float)\n",
"var_names = dataframe.columns\n",
"# wrap the values in a tigramite DataFrame (this rebinds the name `dataframe`)\n",
"dataframe = pp.DataFrame(dataframe.values, var_names=var_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tigramite import plotting as tp\n",
"tp.plot_timeseries(dataframe, figsize=(15, 5)); plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tigramite.pcmci import PCMCI\n",
"from tigramite.independence_tests.parcorr import ParCorr\n",
"import numpy as np\n",
"parcorr = ParCorr(significance='analytic')\n",
"pcmci = PCMCI(\n",
" dataframe=dataframe, \n",
" cond_ind_test=parcorr,\n",
" verbosity=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"correlations = pcmci.run_bivci(tau_max=3, val_only=True)['val_matrix']\n",
"matrix_lags = np.argmax(np.abs(correlations), axis=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tau_max = 3\n",
"pc_alpha = None\n",
"pcmci.verbosity = 2\n",
"\n",
"results = pcmci.run_pcmciplus(tau_min=0, tau_max=tau_max, pc_alpha=pc_alpha)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.utils.timeseries import create_graph_from_networkx_array\n",
"\n",
"graph = create_graph_from_networkx_array(results['graph'], var_names)\n",
"\n",
"plot(graph)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
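A note on the "Dataset Shifting" step in the notebook above: the idea of moving a parent column down by its lag so that cause and effect sit in the same row can be illustrated with plain pandas. This is only a sketch under stated assumptions, not the implementation of `shift_columns_by_lag_using_unrolled_graph`. The helper `add_lagged_column` is hypothetical, the `V5_-1` naming simply mirrors the column names the notebook uses, and leaving the first rows as NaN is an assumption about how missing history is handled.

```python
import pandas as pd

# First five rows of V5 and V6 from temporal_dataset.csv.
df = pd.DataFrame({"V5": [5, 6, 7, 8, 8], "V6": [6, 7, 8, 9, 9]})


def add_lagged_column(frame, column, lag):
    """Hypothetical helper: add `column` shifted down by `lag` rows as `<column>_-<lag>`."""
    out = frame.copy()  # leave the caller's dataframe untouched
    # shift(lag) moves values down, so row t holds the value observed at t - lag.
    out[f"{column}_-{lag}"] = out[column].shift(lag)
    return out


# The graph has V5 -> V6 with lag 1, so V6 at row t should see V5 from row t - 1.
shifted = add_lagged_column(df, "V5", 1)
print(shifted)
```

After the shift, each row pairs V6 at time t with V5 at time t - 1, which is what lets a row-wise estimator such as linear regression use the lagged value as a regressor.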