Merge pull request #13 from wesmadrigal/read-encoding
updates
wesmadrigal authored Nov 9, 2024
2 parents 212dec2 + f6616d6 commit 4d79790
Showing 4 changed files with 50 additions and 12 deletions.
graphreduce/node.py (39 additions, 8 deletions)
@@ -416,7 +416,7 @@ def pandas_auto_features (
                     continue
                 elif type_func_map.get(_type):
                     for func in type_func_map[_type]:
-                        if (_type == 'numerical' or 'timestamp') and dict(self.df.dtypes)[col].__str__() == 'object':
+                        if (_type == 'numerical' or 'timestamp') and dict(self.df.dtypes)[col].__str__() == 'object' and func in ['min','max','median','mean']:
                             logger.info(f'skipped aggregation on {col} because semantic numerical but physical object')
                             continue
                         col_new = f"{col}_{func}"
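For context on the new guard: `_type == 'numerical' or 'timestamp'` is always truthy in Python (the bare string 'timestamp' is truthy), so in practice the skip is driven by the physical dtype check and the new function allowlist. A minimal standalone sketch of the effective behavior, using a hypothetical DataFrame and type map:

    import pandas as pd

    # Hypothetical example: a column that is semantically numerical but physically object.
    df = pd.DataFrame({'amount': ['12.5', 'n/a', '7']})
    type_func_map = {'numerical': ['count', 'min', 'max', 'median', 'mean']}

    col, _type = 'amount', 'numerical'
    for func in type_func_map[_type]:
        # Order/average aggregations fail or mislead on string-typed values;
        # 'count' still works, which is why only min/max/median/mean are skipped.
        if str(df.dtypes[col]) == 'object' and func in ['min', 'max', 'median', 'mean']:
            print(f'skipped {func} on {col}')
            continue
        print(f'would compute {col}_{func}')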
@@ -535,20 +535,48 @@ def sql_auto_features (
         data type inference libraries.
         """
         agg_funcs = []
-        for col, _type in dict(table_df_sample.dtypes).items():
-            _type = str(_type)
-            if type_func_map.get(_type):
+        if not self._stypes:
+            self._stypes = infer_df_stype(table_df_sample)
+        for col, stype in self._stypes.items():
+            _type = str(stype)
+            if self._is_identifier(col) and col != reduce_key:
+                # We only perform counts for identifiers.
+                func = "count"
+                col_new = f"{col}_{func}"
+                agg_funcs.append(
+                    sqlop(
+                        optype=SQLOpType.aggfunc,
+                        opval=f"{func}" + f"({col}) as {col_new}"
+                    )
+                )
+            elif self._is_identifier(col) and col == reduce_key:
+                continue
+            elif type_func_map.get(_type):
                 for func in type_func_map[_type]:
+                    # There should be a better top-level mapping
+                    # but for now this will do. SQL engines typically
+                    # don't have 'median' and 'mean'. 'mean' is typically
+                    # just called 'avg'.
+                    if (_type == 'numerical' or 'timestamp') and dict(table_df_sample)[col].__str__() == 'object' and func in ['min','max','mean', 'median']:
+                        logger.info(f'skipped aggregation on {col} because semantic numerical but physical object')
+                        continue
+                    if func in self.FUNCTION_MAPPING:
+                        func = self.FUNCTION_MAPPING.get(func)
+
+                    if not func:
+                        continue
                     col_new = f"{col}_{func}"
                     agg_funcs.append(
                         sqlop(
                             optype=SQLOpType.aggfunc,
                             opval=f"{func}" + f"({col}) as {col_new}"
                         )
                     )
+        # Need the aggregation and time-based filtering.
+        if not len(agg_funcs):
+            logger.info(f'No aggregations for {self}')
+            return self.df
         agg = sqlop(optype=SQLOpType.agg, opval=f"{self.colabbr(reduce_key)}")
 
         # Need the aggregation and time-based filtering.
         tfilt = self.prep_for_features() if self.prep_for_features() else []
 
         return tfilt + agg_funcs + [agg]
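Taken together, the new sql_auto_features logic emits one count per non-key identifier column and fans the remaining columns out through the type-to-function map, with FUNCTION_MAPPING (below) rewriting or dropping pandas-only functions. As a rough illustration of the SQL the aggfunc/agg ops correspond to — plain sqlite3 with hypothetical table and column names, not graphreduce's actual op compiler:

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.executescript("""
        CREATE TABLE orders (id INTEGER, customer_id INTEGER, amount REAL);
        INSERT INTO orders VALUES (1, 1, 10.0), (2, 1, 5.0), (3, 2, 7.5);
    """)

    # One aggfunc op per column/function pair ('mean' already rewritten to 'avg'),
    # plus a single agg op that becomes the GROUP BY on the reduce key.
    agg_funcs = ['count(id) as id_count', 'avg(amount) as amount_avg',
                 'min(amount) as amount_min', 'max(amount) as amount_max']
    sql = f"SELECT customer_id, {', '.join(agg_funcs)} FROM orders GROUP BY customer_id"
    for row in conn.execute(sql):
        print(row)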
@@ -966,8 +994,12 @@ class SQLNode(GraphReduceNode):
     AWS Athena, which requires additional params.
     Subclasses should simply extend the `SQLNode` interface:
     """
+    FUNCTION_MAPPING = {
+        'mean': 'avg',
+        'median': None,
+        'nunique': None,
+    }
     def __init__ (
             self,
             *args,
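The None entries in FUNCTION_MAPPING double as a blocklist: after the lookup, func becomes None and the `if not func` check in sql_auto_features drops the aggregation entirely. A tiny sketch of that translation step in isolation:

    FUNCTION_MAPPING = {
        'mean': 'avg',    # SQL engines call mean 'avg'
        'median': None,   # no portable SQL equivalent, so drop it
        'nunique': None,  # pandas-only, so drop it
    }

    for func in ['count', 'mean', 'median', 'nunique']:
        if func in FUNCTION_MAPPING:
            func = FUNCTION_MAPPING[func]
        if not func:
            continue
        print(func)  # prints 'count' then 'avg'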
@@ -1091,7 +1123,6 @@ def create_temp_view (
         Create a view with the results of
         the query.
         """
-
         try:
             sql = f"""
             CREATE VIEW {view_name} AS
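For reference, the view creation this method wraps is ordinary DDL; a minimal sqlite3 sketch with hypothetical view and query names:

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute("CREATE TABLE cust (id INTEGER, name TEXT)")
    conn.execute("INSERT INTO cust VALUES (1, 'ada')")

    view_name = 'cust_view'   # hypothetical
    qry = 'SELECT id, name FROM cust'
    conn.execute(f"CREATE VIEW {view_name} AS {qry}")
    print(conn.execute(f"SELECT * FROM {view_name}").fetchall())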
requirements.txt (1 addition, 0 deletions)
@@ -12,3 +12,4 @@ structlog>=23.1.0
 pytest>=8.0.2
 woodwork==0.29.0
 pydantic==1.10.5
+pytorch_frame
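The new pytorch_frame dependency provides the infer_df_stype helper that sql_auto_features now calls for semantic typing. A minimal sketch of what it returns, assuming pytorch_frame exposes it under torch_frame.utils (hedged; check the installed version's layout):

    import pandas as pd
    from torch_frame.utils import infer_df_stype  # supplied by the new pytorch_frame dependency

    df = pd.DataFrame({
        'id': [1, 2, 3],
        'amount': [10.0, 5.0, 7.5],
        'ts': pd.to_datetime(['2024-01-01', '2024-01-02', '2024-01-03']),
    })

    # Maps each column name to an inferred semantic type (stype),
    # e.g. numerical for 'amount' and timestamp for 'ts'.
    stypes = infer_df_stype(df)
    print({col: str(stype) for col, stype in stypes.items()})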
setup.py (1 addition, 1 deletion)
@@ -15,7 +15,7 @@
 
 setuptools.setup(
     name="graphreduce",
-    version = "1.6.8",
+    version = "1.6.9",
     url="https://github.com/wesmadrigal/graphreduce",
     packages = setuptools.find_packages(exclude=[ "docs", "examples" ]),
     install_requires = [
tests/test_graph_reduce.py (9 additions, 3 deletions)
@@ -151,14 +151,16 @@ def test_multi_node():
         fpath=os.path.join(data_path, 'cust.csv'),
         fmt='csv',
         prefix='cust',
-        date_key=None
+        date_key=None,
+        pk='id',
     )
 
     order_node = DynamicNode(
         fpath=os.path.join(data_path, 'orders.csv'),
         fmt='csv',
         prefix='ord',
-        date_key='ts'
+        date_key='ts',
+        pk='id',
     )
 
     gr = GraphReduce(
@@ -308,26 +310,30 @@ def test_sql_graph_transform():
 def test_sql_graph_auto_fe():
     conn = _setup_sqlite()
     cust = SQLNode(fpath='cust',
+            pk='id',
             prefix='cust',
             client=conn,
             compute_layer=ComputeLayerEnum.sqlite,
             columns=['id','name'])
 
     notif = SQLNode(fpath='notifications',
             prefix='not',
+            pk='id',
             client=conn,
             compute_layer=ComputeLayerEnum.sqlite,
             columns=['id','customer_id','ts'],
             date_key='ts')
 
     ni = SQLNode(fpath='notification_interactions',
             prefix='ni',
+            pk='id',
             client=conn,
             compute_layer=ComputeLayerEnum.sqlite,
             columns=['id','notification_id','interaction_type_id','ts'],
             date_key='ts')
 
     order = SQLNode(fpath='orders',
+            pk='id',
             prefix='ord',
             client=conn,
@@ -351,7 +357,6 @@ def test_sql_graph_auto_fe():
             compute_layer=ComputeLayerEnum.sqlite,
             use_temp_tables=True,
             lazy_execution=False,
-
             # Auto feature engineering params.
             auto_features=True,
             auto_feature_hops_back=3,
@@ -385,6 +390,7 @@ def test_sql_graph_auto_fe():
         relation_key='customer_id',
         reduce=True
     )
+    gr.plot_graph('cust_graph.html')
     gr.do_transformations_sql()
     d = pd.read_sql_query(f"select * from {gr.parent_node._cur_data_ref}", conn)
     ic(d)
