updates to graphreduce
wesmadrigal committed Oct 30, 2024
1 parent e841e7d commit f98105f
Showing 3 changed files with 294 additions and 40 deletions.
100 changes: 100 additions & 0 deletions graphreduce/common.py
@@ -0,0 +1,100 @@
#!/usr/bin/env python

import pytz
from datetime import datetime
import pandas as pd
import dask.dataframe as dd
from pyspark.sql import functions as F
import pyspark
from torch_frame import stype


stype_map = {
'numerical': [
'min',
'max',
'median',
'mean',
'sum',
],
'categorical': [
'nunique',
'count',
'mode',
],
'text_embedded': [
'length'
],
'text_tokenized': [
'length'
],
'multicategorical': [
'length'
],
'sequence_numerical': [
'sum',
'min',
'max',
'median',
],
'timestamp': [
'min',
'max',
'delta'
],
'image_embedded': [],
'embedding': []
}


def clean_datetime_pandas(df: pd.DataFrame, col: str) -> pd.DataFrame:
df[col] = pd.to_datetime(df[col], errors="coerce", utc=True)

# Count the number of rows before removing invalid dates
total_before = len(df)

# Remove rows where timestamp is NaT (indicating parsing failure)
df = df.dropna(subset=[col])

# Count the number of rows after removing invalid dates
total_after = len(df)

# Calculate the percentage of rows removed
percentage_removed = ((total_before - total_after) / total_before) * 100

    # Report the percentage of rows removed due to invalid dates
print(
f"Percentage of rows removed due to invalid dates: "
f"{percentage_removed:.2f}%"
)
return df


def clean_datetime_dask(df: dd.DataFrame, col: str) -> dd.DataFrame:
    # Mirror the pandas helper: coerce unparseable values to NaT, then drop them.
    df[col] = dd.to_datetime(df[col], errors="coerce", utc=True)
    # Note: len() on a Dask DataFrame triggers a compute.
    total_before = len(df)
    df = df.dropna(subset=[col])
    total_after = len(df)
    percentage_removed = ((total_before - total_after) / total_before) * 100
    print(
        f"Percentage of rows removed due to invalid dates: "
        f"{percentage_removed:.2f}%"
    )
    return df


def clean_datetime_spark(df: pyspark.sql.DataFrame, col: str) -> pyspark.sql.DataFrame:
    # Minimal sketch mirroring the pandas helper: coerce to timestamp, drop rows that fail to parse.
    df = df.withColumn(col, F.to_timestamp(F.col(col)))
    return df.dropna(subset=[col])



def convert_to_utc(dt):
"""Converts a datetime object to UTC.
Args:
dt: The datetime object to convert.
Returns:
The datetime object converted to UTC.
"""
if dt.tzinfo is None: # Naive datetime
        # Assume a default timezone for naive datetimes
local_tz = pytz.timezone('US/Pacific') # Replace with the actual timezone if known
dt = local_tz.localize(dt)
return dt.astimezone(pytz.UTC)
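The helpers above are small but load-bearing for the rest of the commit: `stype_map` declares which aggregation functions apply to each `torch_frame` semantic type, and the `clean_datetime_*` functions coerce a timestamp column and drop rows that fail to parse. A minimal usage sketch, assuming the module's dependencies (pandas, dask, pyspark, torch_frame) are installed; the sample data and column names are made up for illustration:

import pandas as pd

from graphreduce.common import clean_datetime_pandas, stype_map

# Illustrative event data; the values here are not from the repository.
events = pd.DataFrame({
    'customer_id': [1, 1, 2],
    'amount': [10.0, 25.5, 7.25],
    'ts': ['2024-01-03', 'not a date', '2024-02-14'],
})

# Coerce the timestamp column and drop the unparseable row
# (the helper prints the percentage of rows removed).
events = clean_datetime_pandas(events, 'ts')

# Apply the aggregations registered for the 'numerical' stype to a numeric column.
aggs = {f'amount_{fn}': ('amount', fn) for fn in stype_map['numerical']}
print(events.groupby('customer_id').agg(**aggs))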
73 changes: 54 additions & 19 deletions graphreduce/graph_reduce.py
@@ -38,18 +38,28 @@ def __init__(
auto_feature_hops_back: int = 2,
auto_feature_hops_front: int = 1,
feature_typefunc_map : typing.Dict[str, typing.List[str]] = {
'int64' : ['count'],
'int64' : ['median', 'mean', 'sum', 'min', 'max'],
'str' : ['min', 'max', 'count'],
#'object' : ['first', 'count'],
'object': ['count'],
'float64' : ['min', 'max', 'sum'],
'float64' : ['median', 'min', 'max', 'sum', 'mean'],
'float32': ['median','min','max','sum','mean'],
#'bool' : ['first'],
#'datetime64' : ['first', 'min', 'max'],
'datetime64': ['min', 'max'],
'datetime64[ns]': ['min', 'max'],
},
feature_stype_map: typing.Dict[str, typing.List[str]] = {
'numerical': ['median', 'mean', 'sum', 'min', 'max'],
'categorical': ['count', 'nunique'],
'embedding': ['first'],
'image_embedded': ['first'],
'multicategorical': ['mode'],
'sequence_numerical': ['min', 'max'],
'timestamp': ['min','max']
},
# Label parameters.
label_node: typing.Optional[GraphReduceNode] = None,
label_node: typing.Optional[typing.Union[GraphReduceNode, typing.List[GraphReduceNode]]] = None,
label_operation: typing.Optional[typing.Union[callable, str]] = None,
# Field on the node.
label_field: typing.Optional[str] = None,
@@ -113,6 +123,7 @@ def __init__(
self.auto_feature_hops_back = auto_feature_hops_back
self.auto_feature_hops_front = auto_feature_hops_front
self.feature_typefunc_map = feature_typefunc_map
self.feature_stype_map = feature_stype_map

# SQL dialect parameters.
self._lazy_execution = lazy_execution
@@ -216,6 +227,8 @@ def hydrate_graph_data (
Hydrate the nodes in the graph with their data
"""
for node in self.nodes():
if self.debug:
logger.debug(f'hydrating {node} data')
node.do_data()


@@ -227,11 +240,14 @@ def add_entity_edge (
relation_key : str,
# need to enforce this better
relation_type : str = 'parent_child',
reduce : bool = True
reduce : bool = True,
reduce_after_join: bool = False,
):
"""
Add an entity relation
"""
if reduce and reduce_after_join:
raise Exception(f'only one can be true: `reduce` or `reduce_after_join`')
if not self.has_edge(parent_node, relation_node):
self.add_edge(
parent_node,
@@ -240,7 +256,8 @@
'parent_key' : parent_key,
'relation_key' : relation_key,
'relation_type' : relation_type,
'reduce' : reduce
'reduce' : reduce,
'reduce_after_join': reduce_after_join
}
)

@@ -610,7 +627,8 @@ def do_transformations_sql(self):

sql_ops = relation_node.auto_features(
reduce_key=edge_data['relation_key'],
type_func_map=self.feature_typefunc_map,
#type_func_map=self.feature_typefunc_map,
type_func_map=self.feature_stype_map,
compute_layer=self.compute_layer
)
logger.info(f"{sql_ops}")
@@ -619,7 +637,8 @@
relation_node.build_query(
relation_node.auto_features(
reduce_key=edge_data['relation_key'],
type_func_map=self.feature_typefunc_map,
#type_func_map=self.feature_typefunc_map,
type_func_map=self.feature_stype_map,
compute_layer=self.compute_layer
)
),
@@ -634,8 +653,6 @@
reduce_sql = relation_node.build_query(reduce_ops)
logger.info(f"reduce SQL: {reduce_sql}")
reduce_ref = relation_node.create_ref(reduce_sql, relation_node.do_reduce)


else:
# in this case we will join the entire relation's dataframe
logger.info(f"doing nothing with relation node {relation_node}")
@@ -649,13 +666,12 @@
)

# Target variables.
if self.label_node and self.label_node == relation_node:
if self.label_node and (self.label_node == relation_node or relation_node.label_field is not None):
logger.info(f"Had label node {self.label_node}")

# Get the reference right before `do_reduce`
# so the records are not aggregated yet.
data_ref = relation_node.get_ref_name(relation_node.do_filters, lookup=True)

#TODO: don't need to reduce if it's 1:1 cardinality.
if self.auto_features:
@@ -741,7 +757,7 @@ def do_transformations(self):
# This may be an incorrect assumption
# so this implementation is currently brittle.
if self.debug:
logger.debug(f'Performing an auto_features front join from {from_node} to {to_node}')
logger.debug(f'Performing FRONT auto_features front join from {from_node} to {to_node}')
joined_df = self.join_any(
to_node,
from_node
@@ -761,10 +777,11 @@
join_df = relation_node.do_reduce(edge_data['relation_key'])
# only relevant when reducing
if self.auto_features:
logger.info(f"performing auto_features on node {relation_node}")
logger.info(f"performing auto_features on node {relation_node} with reduce key {edge_data['relation_key']}")
child_df = relation_node.auto_features(
reduce_key=edge_data['relation_key'],
type_func_map=self.feature_typefunc_map,
#type_func_map=self.feature_typefunc_map,
type_func_map=self.feature_stype_map,
compute_layer=self.compute_layer
)

@@ -779,6 +796,8 @@
)
else:
join_df = child_df
if self.debug:
logger.debug(f'assigned join_df to be {child_df.columns}')
elif self.compute_layer == ComputeLayerEnum.spark:
if isinstance(join_df, pyspark.sql.dataframe.DataFrame):
join_df = join_df.join(
Expand All @@ -788,6 +807,8 @@ def do_transformations(self):
)
else:
join_df = child_df
if self.debug:
logger.debug(f'assigned join_df to be {child_df.columns}')

else:
# in this case we will join the entire relation's dataframe
@@ -805,14 +826,24 @@
parent_node.df = joined_df

# Target variables.
if self.label_node and self.label_node == relation_node:
if self.label_node and (self.label_node == relation_node or relation_node.label_field is not None):
logger.info(f"Had label node {self.label_node}")
# Automatic label generation.
if isinstance(relation_node, DynamicNode):
if self.label_node == relation_node:
    label_df = relation_node.default_label(
        op=self.label_operation,
        field=self.label_field,
        reduce_key=edge_data['relation_key']
    )
elif relation_node.label_field is not None:
label_df = relation_node.default_label(
op=relation_node.label_operation if relation_node.label_operation else 'count',
field=relation_node.label_field,
reduce_key=edge_data['relation_key']
)
# There should be an implementation of `do_labels`
# when the instance is a `GraphReduceNode`.
elif isinstance(relation_node, GraphReduceNode):
label_df = relation_node.do_labels(edge_data['relation_key'])

Expand All @@ -829,4 +860,8 @@ def do_transformations(self):
parent_node.do_post_join_annotate()
# post-join filters (if any)
if hasattr(parent_node, 'do_post_join_filters'):
    parent_node.do_post_join_filters()

# post-join aggregation
if edge_data['reduce_after_join']:
parent_node.do_post_join_reduce(edge_data['relation_key'], type_func_map=self.feature_stype_map)
(diff for the third changed file not shown)
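Taken together, the `graph_reduce.py` changes swap the dtype-keyed `feature_typefunc_map` for the stype-keyed `feature_stype_map` in automatic feature generation, let `label_node` be a single node or a list, and add a `reduce_after_join` edge option that defers aggregation until after the join via `do_post_join_reduce`. A hedged sketch of how a caller might exercise the new surface; the module paths, `DynamicNode` arguments, and several `GraphReduce` keyword arguments are assumptions for illustration rather than taken from this diff:

from graphreduce.graph_reduce import GraphReduce
from graphreduce.node import DynamicNode
from graphreduce.enum import ComputeLayerEnum

# Hypothetical nodes and file paths, for illustration only.
cust = DynamicNode(fpath='cust.csv', fmt='csv', prefix='cust', date_key=None, pk='id')
orders = DynamicNode(fpath='orders.csv', fmt='csv', prefix='ord', date_key='ts', pk='id')

gr = GraphReduce(
    name='customer_orders',
    parent_node=cust,
    compute_layer=ComputeLayerEnum.pandas,
    auto_features=True,
    # New in this commit: aggregations keyed by torch_frame semantic type.
    feature_stype_map={
        'numerical': ['min', 'max', 'mean'],
        'categorical': ['count', 'nunique'],
    },
)
gr.add_node(cust)
gr.add_node(orders)

# New in this commit: defer aggregation until after the join.
# `reduce` and `reduce_after_join` cannot both be true, and the parent node
# is expected to implement `do_post_join_reduce`.
gr.add_entity_edge(
    parent_node=cust,
    relation_node=orders,
    parent_key='id',
    relation_key='customer_id',
    relation_type='parent_child',
    reduce=False,
    reduce_after_join=True,
)

gr.do_transformations()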
