Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Library functions for temporal causal functionality #1218

Merged
merged 67 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
8d6cdd5
Library functions for temporal causal functionality
srivhash Jun 28, 2024
87a90b4
shifting plotter function
srivhash Jul 2, 2024
ec73c1a
modified helper functions
srivhash Jul 2, 2024
6d22554
tutorial notebook
srivhash Jul 2, 2024
17f5822
graph creation into utils
srivhash Jul 4, 2024
4d880b5
parents to columns
srivhash Jul 4, 2024
1330658
helper
srivhash Jul 4, 2024
4daf89f
updated deprecated plotter
srivhash Jul 4, 2024
74f8985
removed plotter
srivhash Jul 5, 2024
8010145
printing graph: best practices
srivhash Jul 5, 2024
9a09e6d
updated imports
srivhash Jul 5, 2024
3fa8786
renamed module
srivhash Jul 5, 2024
e0a4ae3
added docstrings
srivhash Jul 9, 2024
979bc32
moved datasets
srivhash Jul 12, 2024
c495ddb
updated tutorial notebook
srivhash Jul 12, 2024
74ffbda
improved notebook quality
srivhash Jul 12, 2024
49bad76
sphinx documentation
srivhash Jul 15, 2024
fbfd7d2
removed pretty print
srivhash Jul 15, 2024
348391e
dot format
srivhash Jul 15, 2024
f6df32a
better node names
srivhash Jul 16, 2024
2bf8c3d
edge validation
srivhash Jul 22, 2024
eaba762
updated shifting columns with 0,1,..,max_lag
srivhash Jul 25, 2024
01c275f
support for dot format
srivhash Jul 25, 2024
599419a
tigramite support
srivhash Jul 25, 2024
848d58c
updated filter to be a hidden function
srivhash Jul 26, 2024
aa8221c
pdated notebook, corrections in data type
srivhash Jul 26, 2024
0690fe5
better readability of parents and time lags
srivhash Jul 26, 2024
41d3246
graph creation test
srivhash Jul 26, 2024
7a3557f
removed wrong comment
srivhash Jul 26, 2024
5e03ddc
moved test files
srivhash Jul 26, 2024
8ef07d0
moved tutorial notebook & updated the path
srivhash Jul 26, 2024
a2785e6
removed repeated imports
srivhash Jul 27, 2024
8422ac6
time_lag parameter explained
srivhash Jul 27, 2024
e2813c8
better function name
srivhash Jul 27, 2024
19d53e2
updated imports
srivhash Jul 27, 2024
38f0883
tests for temporal shift
srivhash Jul 27, 2024
4cd5ede
removed filter_columns function
srivhash Jul 27, 2024
f518c09
black and isort utils
srivhash Jul 29, 2024
8b621be
black and isort timeseries
srivhash Jul 29, 2024
2f097ef
updated notebook text
amit-sharma Jul 29, 2024
559e1fb
integer range fix
srivhash Jul 30, 2024
a52999d
correction in timestamp : notebook text
srivhash Jul 30, 2024
e3b8e22
time lagged causal estimation
srivhash Jul 30, 2024
0317026
removed cell outputs
srivhash Jul 30, 2024
bf731bb
find ancestors
srivhash Jul 30, 2024
e488713
include ancestors in notebook
srivhash Jul 30, 2024
0fc85cc
formatting changes
srivhash Jul 30, 2024
f87f74d
comments : notebook
srivhash Jul 30, 2024
cdbb9af
multiple time lags : csv graph'
srivhash Jul 31, 2024
139eeaa
multiple time lags
srivhash Jul 31, 2024
1357399
unrolled graph using bfs
srivhash Jul 31, 2024
9cefd87
cleanup of functions
srivhash Jul 31, 2024
ed5962e
removed find parents and ancestors
srivhash Aug 1, 2024
af92dce
updated graph creation functions
srivhash Aug 1, 2024
e4dbc3b
multiple time lags : dot file
srivhash Aug 1, 2024
74a88b7
import update
srivhash Aug 1, 2024
d0bfc64
tests for causal graph creation
srivhash Aug 1, 2024
afb9914
tests for adding lagged edges
srivhash Aug 1, 2024
a99558a
tests for shifting columns
srivhash Aug 1, 2024
64cb310
formatting
srivhash Aug 1, 2024
30b986b
formatting
srivhash Aug 1, 2024
8c80c02
inter node links
srivhash Aug 2, 2024
04adfb6
clear outputs & formatting
srivhash Aug 2, 2024
cd54fab
updated tests for lagged shift
srivhash Aug 2, 2024
5c75284
formatting for tests
srivhash Aug 2, 2024
c94357a
formatting
srivhash Aug 2, 2024
81f1b34
tigramite dependency added
srivhash Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_dataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
V1,V2,V3,V4,V5,V6,V7
1,2,3,4,5,6,7
2,3,4,5,6,7,8
3,4,5,6,7,8,9
4,5,6,7,8,9,10
0,1,5,7,8,9,7
3,5,4,1,2,6,5
6,7,1,2,4,5,9
12,3,5,7,3,8,9
3,2,1,6,3,8,9
4,6,3,5,8,9,1
6 changes: 6 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_graph.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
node1,node2,time_lag
V1,V2,3
V2,V3,4
V5,V6,1
V4,V7,4
V7,V6,3
7 changes: 7 additions & 0 deletions docs/source/example_notebooks/datasets/temporal_graph.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
digraph G {
V1 -> V2 [label=3];
V2 -> V3 [label=4];
V5 -> V6 [label=1];
V4 -> V7 [label=4];
V7 -> V6 [label=3];
}
189 changes: 189 additions & 0 deletions dowhy/timeseries/causal_effect_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
{
srivhash marked this conversation as resolved.
Show resolved Hide resolved
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Temporal Causal Inference\n",
"\n",
"In this notebook, we will look at an example of causal effect inference for Temporal Dependencies within a Graphical Causal Model. We demonstrate the use of temporal shift to normalise temporal datasets and establish causal effects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Estimating the lag between the parent and action node is quite challenging in temporal causal inference. In this tutorial, we will assume we have the ground truth causal graph as an input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"import pandas as pd\n",
"from dowhy.utils.timeseries import create_graph_from_csv,create_graph_from_user\n",
"from dowhy.utils.plotting import plot, pretty_print_graph"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The user can create a csv file with the edges in the temporal graph. The columns in the csv are node_1, node_2, time_lag which represents an directed edge node_1 -> node_2 with the time lag of time_lag. Let us consider the following graph as the input\n",
"\n",
"| node1 | node2 | time_lag |\n",
"|--------|--------|----------|\n",
"| V1 | V2 | 3 |\n",
"| V2 | V3 | 4 |\n",
"| V5 | V6 | 1 |\n",
"| V4 | V7 | 4 |\n",
"| V7 | V6 | 3 |"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Input a csv file with the edges in the graph with the columns: node_1,node_2,time_lag\n",
"file_path = \"../../docs/source/example_notebooks/datasets/temporal_graph.csv\"\n",
srivhash marked this conversation as resolved.
Show resolved Hide resolved
"\n",
"# Create the graph from the CSV file\n",
"graph = create_graph_from_csv(file_path)\n",
"plot(graph)"
srivhash marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset Shifting and Filtering\n",
"\n",
"In temporal causal inference, accurately estimating causal effects often requires accounting for time lags between nodes in a graph. For instance, if node_1 influences node_2 with a time lag of 5 timestamps, we represent this dependency as node_1(t-5) -> node_2(t). To maintain this causal relationship in our analysis, we need to shift the columns by the given time lag.\n",
"\n",
"When considering node_2 as the target node, the data for node_1 should be shifted down by 5 timestamps. This adjustment ensures that the new edge node_1(t) -> node_2(t) accurately represents the same lagged dependency as before. Shifting the data in this manner aligns the time series data correctly, allowing us to properly analyze and estimate the causal effects."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.timeseries.temporal_shift import find_lagged_parent_nodes,shift_columns,_filter_columns"
srivhash marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the dataframe in a csv format from the user\n",
"dataset_path=\"../../docs/source/example_notebooks/datasets/temporal_dataset.csv\"\n",
"dataframe=pd.read_csv(dataset_path)\n",
"\n",
"# the node for which effect estimation has to be done, node:6\n",
"target_node = 'V6'\n",
"\n",
"# find the action nodes of the given target node with respective lag times\n",
"parents = find_lagged_parent_nodes(graph, target_node)\n",
srivhash marked this conversation as resolved.
Show resolved Hide resolved
"parents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"time_shifted_df = shift_columns(dataframe,parents[0],parents[1])\n",
"time_shifted_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"filtered_ts_df = _filter_columns(time_shifted_df,target_node,parents[0])\n",
"filtered_ts_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cause Estimation using Dowhy\n",
"\n",
"Once you have the new dataframe, causal effect estimation can be performed on the target node with respect to the action nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# perform causal effect estimation on this new dataset\n",
"import dowhy\n",
"from dowhy import CausalModel\n",
"\n",
"model = CausalModel(\n",
" data=filtered_ts_df,\n",
" treatment='V5',\n",
" outcome='V6',\n",
" proceed_when_unidentifiable=True # Proceed even if the causal graph is not fully identifiable\n",
")\n",
"\n",
"identified_estimand = model.identify_effect()\n",
"\n",
"estimate = model.estimate_effect(identified_estimand,\n",
" method_name=\"backdoor.linear_regression\",\n",
" test_significance=True)\n",
"\n",
"print(estimate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dowhy.utils.timeseries import create_graph_from_dot_format\n",
"\n",
"file_path = \"../../docs/source/example_notebooks/datasets/temporal_graph.dot\"\n",
"\n",
"graph = create_graph_from_dot_format(file_path)\n",
"plot(graph)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
72 changes: 72 additions & 0 deletions dowhy/timeseries/temporal_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import networkx as nx
import pandas as pd
from typing import List, Tuple

def find_lagged_parent_nodes(graph:nx.DiGraph, node:str) -> Tuple[List[str], List[int]]:
srivhash marked this conversation as resolved.
Show resolved Hide resolved
"""
Given a graph and a node, this function returns the parent nodes of the node and the time lags associated with the edges between the parent nodes and the node.

:param graph: The graph object.
:type graph: networkx.Graph
:param node: The node for which we want to find the parent nodes.
:type node: string
:return: A tuple containing a list of parent nodes of the node and a list of time lags associated with the edges between the parent nodes and the node.
:rtype: tuple (list, list)
"""
parent_nodes = []
time_lags = []
for n in graph.predecessors(node):
edge_data = graph.get_edge_data(n, node)
if 'time_lag' in edge_data:
parent_nodes.append(n)
time_lags.append(edge_data['time_lag'])
return parent_nodes, time_lags

def shift_columns(df: pd.DataFrame, columns: List[str], lag: List[int]) -> pd.DataFrame:
srivhash marked this conversation as resolved.
Show resolved Hide resolved
"""
Given a dataframe, a list of columns, and a list of time lags, this function shifts the columns in the dataframe by the corresponding time lags, creating a new unique column for each shifted version.

:param df: The dataframe to shift.
:type df: pandas.DataFrame
:param columns: A list of columns to shift.
:type columns: list
:param lags: A list of time lags to shift the columns by.
:type lags: list
:return: The dataframe with the columns shifted by the corresponding time lags.
:rtype: pandas.DataFrame
"""
if len(columns) != len(lag):
raise ValueError("The size of 'columns' and 'lag' lists must be the same.")

new_df = df.copy()
for column, max_lag in zip(columns, lag):
max_lag = int(max_lag)
for shift in range(1, max_lag + 1):
new_column_name = f"{column}_lag{shift}"
new_df[new_column_name] = new_df[column].shift(shift, axis=0, fill_value=0)

return new_df

def _filter_columns(df: pd.DataFrame, child_node: int, parent_nodes: List[int]) -> pd.DataFrame:
"""
Given a dataframe, a target node, and a list of action/parent nodes, this function filters the dataframe to keep only the columns of the target node, the parent nodes, and their shifted versions.
srivhash marked this conversation as resolved.
Show resolved Hide resolved

:param df: The dataframe to filter.
:type df: pandas.DataFrame
:param child_node: The child node.
:type child_node: int
:param parent_nodes: A list of parent nodes.
:type parent_nodes: list
:return: The dataframe with only the columns of the child node, parent nodes, and their shifted versions.
:rtype: pandas.DataFrame
"""
columns_to_keep = [str(child_node)]
for node in parent_nodes:
columns_to_keep.append(str(node))
# Include all shifted versions of the parent node
shifted_columns = [col for col in df.columns if col.startswith(f"{node}_lag")]
columns_to_keep.extend(shifted_columns)

# Filter the dataframe to keep only the relevant columns
filtered_df = df[columns_to_keep]
return filtered_df
13 changes: 13 additions & 0 deletions dowhy/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,16 @@ def _plot_as_pyplot_figure(pygraphviz_graph: Any, figure_size: Optional[Tuple[in

if figure_size is not None:
plt.rcParams["figure.figsize"] = org_fig_size

def pretty_print_graph(graph: nx.DiGraph) -> None:
"""
Pretty print the graph edges with time lags.

:param graph: The networkx graph.
:type graph: networkx.Graph
:return: None
:rtype: None
"""
print("\nGraph edges with time lags:")
for edge in graph.edges(data=True):
print(f"{edge[0]} -> {edge[1]} with time-lagged dependency {edge[2]['time_lag']}")
Loading