Skip to content

Commit

Permalink
black and isort timeseries
Browse files (browse the repository at this point in the history)
  • Loading branch information
srivhash committed Jul 29, 2024
1 parent f518c09 commit 8b621be
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions dowhy/timeseries/temporal_shift.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Optional, Tuple

import networkx as nx
import pandas as pd
from typing import List, Tuple, Optional

def find_lagged_parents(graph:nx.DiGraph, node:str) -> Tuple[List[str], List[int]]:

def find_lagged_parents(graph: nx.DiGraph, node: str) -> Tuple[List[str], List[int]]:
"""
Given a graph and a node, this function returns the parent nodes of the node and the time lags associated with the edges between the parent nodes and the node.
Expand All @@ -17,12 +19,15 @@ def find_lagged_parents(graph:nx.DiGraph, node:str) -> Tuple[List[str], List[int
time_lags = []
for n in graph.predecessors(node):
edge_data = graph.get_edge_data(n, node)
if 'time_lag' in edge_data:
if "time_lag" in edge_data:
parent_nodes.append(n)
time_lags.append(edge_data['time_lag'])
time_lags.append(edge_data["time_lag"])
return parent_nodes, time_lags

def shift_columns_by_lag(df: pd.DataFrame, columns: List[str], lag: List[int], filter: bool, child_node: Optional[str] = None) -> pd.DataFrame:

def shift_columns_by_lag(
df: pd.DataFrame, columns: List[str], lag: List[int], filter: bool, child_node: Optional[str] = None
) -> pd.DataFrame:
"""
Given a dataframe, a list of columns, and a list of time lags, this function shifts the columns in the dataframe by the corresponding time lags, creating a new unique column for each shifted version.
Optionally, it can filter the dataframe to keep only the columns of the child node, the parent nodes, and their shifted versions.
Expand All @@ -42,17 +47,21 @@ def shift_columns_by_lag(df: pd.DataFrame, columns: List[str], lag: List[int], f
"""
if len(columns) != len(lag):
raise ValueError("The size of 'columns' and 'lag' lists must be the same.")

new_df = df.copy()
for column, max_lag in zip(columns, lag):
max_lag = int(max_lag)
for shift in range(1, max_lag + 1):
new_column_name = f"{column}_lag{shift}"
new_df[new_column_name] = new_df[column].shift(shift, axis=0, fill_value=0)

if filter and child_node is not None:
relevant_columns = [child_node] + columns + [f"{col}_lag{shift}" for col in columns for shift in range(1, lag[columns.index(col)] + 1)]
relevant_columns = (
[child_node]
+ columns
+ [f"{col}_lag{shift}" for col in columns for shift in range(1, lag[columns.index(col)] + 1)]
)
relevant_columns = list(dict.fromkeys(relevant_columns)) # Ensure unique and maintain order
new_df = new_df[relevant_columns]

return new_df

0 comments on commit 8b621be

Please sign in to comment.