updates to forward traversal
wesmadrigal committed Oct 18, 2024
1 parent 3f61d97 commit e841e7d
Showing 2 changed files with 44 additions and 13 deletions.
36 changes: 29 additions & 7 deletions graphreduce/graph_reduce.py
@@ -15,7 +15,6 @@
from structlog import get_logger
import pyspark
import pyvis
import woodwork as ww

# internal
from graphreduce.node import GraphReduceNode, DynamicNode, SQLNode
@@ -25,7 +24,6 @@
logger = get_logger('GraphReduce')



class GraphReduce(nx.DiGraph):
def __init__(
self,
@@ -62,6 +60,8 @@ def __init__(

# Only for SQL engines.
lazy_execution: bool = False,
# Debug
debug: bool = False,
*args,
**kwargs
):
@@ -83,6 +83,10 @@
auto_feature_hops_back: optional for automatically computing features
auto_feature_hops_front: optional for automatically computing features
feature_typefunc_map : optional mapping from type to a list of functions (e.g., {'int' : ['min', 'max', 'sum'], 'str' : ['first']})
label_node: optional GraphReduceNode for the label
label_operation: optional str or callable operation to call to compute the label
label_field: optional str field to compute the label
debug: bool whether to run debug logging
"""
super(GraphReduce, self).__init__(*args, **kwargs)

@@ -116,6 +120,8 @@ def __init__(
# if using Spark
self._sqlctx = spark_sqlctx
self._storage_client = storage_client

self.debug = debug

if self.compute_layer == ComputeLayerEnum.spark and self._sqlctx is None:
raise Exception(f"Must provide a `spark_sqlctx` kwarg if using {self.compute_layer.value} as compute layer")
@@ -451,15 +457,24 @@ def traverse_up (
"""
parents = [(start, n, 1) for n in self.predecessors(start)]
to_traverse = [(n, 1) for n in self.predecessors(start)]
while len(to_traverse):
cur_level = 1
while len(to_traverse) and cur_level <= self.auto_feature_hops_front:
cur_node, cur_level = to_traverse[0]
del to_traverse[0]

for node in self.predecessors(cur_node):
parents.append((cur_node, node, cur_level+1))
to_traverse.append((node, cur_level+1))

return parents
if cur_level+1 <= self.auto_feature_hops_front:
parents.append((cur_node, node, cur_level+1))
to_traverse.append((node, cur_level+1))
# Returns higher levels first so that
# when we iterate through these edges
# we will traverse from top to bottom
# where the bottom is our `start`.
parents_ordered = list(reversed(parents))
if self.debug:
for ix in range(len(parents_ordered)):
logger.debug(f"index {ix} is level {parents_ordered[ix][-1]}")
return parents_ordered
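To make the new behavior concrete, here is a minimal standalone sketch of the hop-capped upward walk and the reversed ordering, written directly against networkx; it mirrors the logic of this hunk rather than reproducing the method itself, and the graph and node names are illustrative.

    # Minimal sketch of the hop-capped predecessor walk; not the library's method.
    import networkx as nx

    def traverse_up_sketch(g: nx.DiGraph, start, max_hops: int):
        # Collect (child, parent, level) edges up to `max_hops` levels above `start`.
        edges = [(start, n, 1) for n in g.predecessors(start)]
        to_visit = [(n, 1) for n in g.predecessors(start)]
        while to_visit:
            node, level = to_visit.pop(0)
            for parent in g.predecessors(node):
                if level + 1 <= max_hops:
                    edges.append((node, parent, level + 1))
                    to_visit.append((parent, level + 1))
        # Deepest levels first, so iteration joins top-down toward `start`.
        return list(reversed(edges))

    g = nx.DiGraph()
    g.add_edge('grandparent', 'parent')
    g.add_edge('parent', 'child')
    print(traverse_up_sketch(g, 'child', max_hops=2))
    # [('parent', 'grandparent', 2), ('child', 'parent', 1)]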


def get_children (
Expand Down Expand Up @@ -720,6 +735,13 @@ def do_transformations(self):
if self.auto_features:
for to_node, from_node, level in self.traverse_up(start=self.parent_node):
if self.auto_feature_hops_front and level <= self.auto_feature_hops_front:
# It is assumed that front-facing relations
# are not one to many and therefore we
# won't have duplication on the join.
# This may be an incorrect assumption
# so this implementation is currently brittle.
if self.debug:
logger.debug(f'Performing an auto_features front join from {from_node} to {to_node}')
joined_df = self.join_any(
to_node,
from_node
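The assumption called out in the comment above is easy to see with a toy pandas join: if a front-facing relation is actually one-to-many, the merge fans out the parent rows. Column names below are illustrative placeholders.

    # Illustrates the brittleness noted above: a one-to-many "front" relation
    # duplicates rows when joined.
    import pandas as pd

    parents = pd.DataFrame({'cust_id': [1, 2], 'name': ['a', 'b']})
    children = pd.DataFrame({'cust_id': [1, 1, 2], 'order_id': [10, 11, 12]})

    joined = parents.merge(children, on='cust_id', how='left')
    print(len(parents), len(joined))   # 2 rows fan out to 3 after the join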
21 changes: 15 additions & 6 deletions graphreduce/node.py
@@ -14,7 +14,6 @@
import pyspark
from structlog import get_logger
from dateutil.parser import parse as date_parse
import woodwork as ww

# internal
from graphreduce.enum import ComputeLayerEnum, PeriodUnit, SQLOpType
@@ -85,6 +84,9 @@ def __init__ (
checkpoints: list = [],
# Only for SQL dialects at the moment.
lazy_execution: bool = False,
# Read encoding.
delimiter: str = None,
encoding: str = None,
):
"""
Constructor
@@ -112,6 +114,10 @@
self.spark_sqlctx = spark_sqlctx
self.columns = columns

# Read options
self.delimiter = delimiter if delimiter else ','
self.encoding = encoding

# Lazy execution for the SQL nodes.
self._lazy_execution = lazy_execution
self._storage_client = storage_client
@@ -168,11 +174,14 @@ def do_data (

if self.compute_layer.value == 'pandas':
if not hasattr(self, 'df') or (hasattr(self,'df') and not isinstance(self.df, pd.DataFrame)):
self.df = getattr(pd, f"read_{self.fmt}")(self.fpath)
if self.encoding and self.delimiter:
self.df = getattr(pd, f"read_{self.fmt}")(self.fpath, encoding=self.encoding, delimiter=self.delimiter)
else:
self.df = getattr(pd, f"read_{self.fmt}")(self.fpath)

# Initialize woodwork.
self.df.ww.init()
self._logical_types = self.df.ww.logical_types
#self.df.ww.init()
#self._logical_types = self.df.ww.logical_types

# Rename columns with prefixes.
if len(self.columns):
@@ -185,8 +194,8 @@
self.df = getattr(dd, f"read_{self.fmt}")(self.fpath)

# Initialize woodwork.
self.df.ww.init()
self._logical_types = self.df.ww.logical_types
#self.df.ww.init()
#self._logical_types = self.df.ww.logical_types

# Rename columns with prefixes.
if len(self.columns):
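A hedged sketch of the new read options in use on a node. The path, prefix, and key names are placeholders, and only `delimiter` and `encoding` are introduced by this commit; the remaining constructor arguments are assumptions about the existing node API.

    # Hedged sketch of the new read options; only delimiter/encoding are new here.
    from graphreduce.node import DynamicNode
    from graphreduce.enum import ComputeLayerEnum

    node = DynamicNode(
        fpath='customers_latin1.csv',
        fmt='csv',
        pk='id',
        prefix='cust',
        date_key=None,
        compute_layer=ComputeLayerEnum.pandas,
        delimiter=';',          # forwarded to pandas read_csv when both are set
        encoding='latin-1',
    )
    node.do_data()              # read_csv(fpath, encoding='latin-1', delimiter=';')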
