Skip to content

Commit

Permalink
[SNOW-902943]: Add support for pd.NamedAgg in DataFrame and Series.agg (
Browse files Browse the repository at this point in the history
#1652)

Co-authored-by: Jonathan Shi <149419494+sfc-gh-joshi@users.noreply.github.com>
  • Loading branch information
sfc-gh-rdurrani and sfc-gh-joshi authored May 29, 2024
1 parent dbb0713 commit 6a1bd53
Show file tree
Hide file tree
Showing 13 changed files with 806 additions and 112 deletions.
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
# Release History

## 1.19.0 (TBD)

### Snowpark Python API Updates

#### Improvements

### Snowpark pandas API Updates

#### New Features

#### Bug Fixes

- Fixed a bug that causes output of GroupBy.aggregate's columns to be ordered incorrectly.

#### Improvements

- Added support for named aggregations in `DataFrame.aggregate` and `Series.aggregate` with `axis=0`.

## 1.18.0 (2024-05-28)

### Snowpark Python API Updates
Expand All @@ -15,6 +33,8 @@

- Added `DataFrame.cache_result` and `Series.cache_result` methods for users to persist DataFrames and Series to a temporary table lasting the duration of the session to improve latency of subsequent operations.

#### Bug Fixes

#### Improvements

- Added partial support for `DataFrame.pivot_table` with no `index` parameter, as well as for `margins` parameter.
Expand Down
45 changes: 38 additions & 7 deletions src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
)
from pandas.core.dtypes.inference import is_integer
from pandas.core.indexes.api import ensure_index
from pandas.errors import SpecificationError
from pandas.util._validators import (
validate_ascending,
validate_bool_kwarg,
Expand All @@ -76,6 +77,7 @@
from snowflake.snowpark.modin import pandas as pd
from snowflake.snowpark.modin.pandas.utils import (
_doc_binary_op,
extract_validate_and_try_convert_named_aggs_from_kwargs,
get_as_shape_compatible_dataframe_or_series,
is_scalar,
raise_if_native_pandas_objects,
Expand Down Expand Up @@ -703,16 +705,45 @@ def aggregate(
# native pandas raise error with message "no result", here we raise a more readable error.
raise ValueError("No column to aggregate on.")

func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)
# If we are using named kwargs, then we do not clear the kwargs (need them in the QC for processing
# order, as well as formatting error messages.)
uses_named_kwargs = False
# If aggregate is called on a Series, named aggregations can be passed in via a dictionary
# to func.
if func is None or (is_dict_like(func) and not self._is_dataframe):
if axis == 1:
raise ValueError(
"`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`."
)
if func is not None:
# If named aggregations are passed in via a dictionary to func, then we
# ignore the kwargs.
if any(is_dict_like(value) for value in func.values()):
# We can only get to this codepath if self is a Series, and func is a dictionary.
# In this case, if any of the values of func are themselves dictionaries, we must raise
# a Specification Error, as that is what pandas does.
raise SpecificationError("nested renamer is not supported")
kwargs = func
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
self, allow_duplication=False, axis=axis, **kwargs
)
uses_named_kwargs = True
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)

# This is to stay consistent with pandas result format, when the func is single
# aggregation function in format of callable or str, reduce the result dimension to
# convert dataframe to series, or convert series to scalar.
# Note: When named aggregations are used, the result is not reduced, even if there
# is only a single function.
# needs_reduce_dimension cannot be True if we are using named aggregations, since
# the values for func in that case are either NamedTuples (AggFuncWithLabels) or
# lists of NamedTuples, both of which are list like.
need_reduce_dimension = (
(callable(func) or isinstance(func, str))
# A Series should be returned when a single scalar string/function aggregation function, or a
Expand Down Expand Up @@ -767,7 +798,7 @@ def aggregate(
# dtype: int8
# >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0)
# TypeError: got an unexpected keyword argument 'skipna'
if is_dict_like(func):
if is_dict_like(func) and not uses_named_kwargs:
kwargs.clear()

result = self.__constructor__(
Expand Down
36 changes: 32 additions & 4 deletions src/snowflake/snowpark/modin/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,37 @@ def aggregate(
"axis other than 0 is not supported"
) # pragma: no cover
if func is None:
# When func is None, we assume that the aggregation functions have been passed in via named aggregations,
# which can be of the form named_agg=('col_name', 'agg_func') or named_agg=pd.NamedAgg('col_name', 'agg_func').
# We need to parse out the following three things:
# 1. The new label to apply to the result of the aggregation.
# 2. The column to apply the aggregation over.
# 3. The aggregation to apply.
# This function checks that:
# 1. The kwargs contain named aggregations.
# 2. The kwargs do not contain anything besides named aggregations. (for pandas compatibility - see function for more details.)
# If both of these things are true, it then extracts the named aggregations from the kwargs, and returns a dictionary that contains
# a mapping from the column pandas labels to apply the aggregation over (2 above) to a tuple containing the aggregation to apply
# and the new label to assign it (1 and 3 above). Take for example, the following call:
# df.groupby(...).agg(new_col1=('A', 'min'), new_col2=('B', 'max'), new_col3=('A', 'max'))
# After this function returns, func will look like this:
# {
# "A": [AggFuncWithLabel(func="min", pandas_label="new_col1"), AggFuncWithLabel(func="max", pandas_label="new_col3")],
# "B": AggFuncWithLabel(func="max", pandas_label="new_col2")
# }
# This remapping causes an issue with ordering though - the dictionary above will be processed in the following order:
# 1. apply "min" to "A" and name it "new_col1"
# 2. apply "max" to "A" and name it "new_col3"
# 3. apply "max" to "B" and name it "new_col2"
# In other words - the order is slightly shifted so that named aggregations on the same column are contiguous in the ordering
# although the ordering of the kwargs is used to determine the ordering of named aggregations on the same columns. Since
# the reordering for groupby agg is a reordering of columns, its relatively cheap to do after the aggregation is over,
# rather than attempting to preserve the order of the named aggregations internally.
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
obj=self, allow_duplication=True, axis=self._axis, **kwargs
obj=self,
allow_duplication=True,
axis=self._axis,
**kwargs,
)
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
Expand Down Expand Up @@ -615,6 +644,7 @@ def aggregate(
how="axis_wise",
is_result_dataframe=is_result_dataframe,
)

return result

agg = aggregate
Expand Down Expand Up @@ -1170,9 +1200,7 @@ def aggregate(
):
# TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions
if is_dict_like(func):
raise SpecificationError(
"Value for func argument in dict format is not allowed for SeriesGroupBy."
)
raise SpecificationError("nested renamer is not supported")

return super().aggregate(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
Expand Down
7 changes: 0 additions & 7 deletions src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.core.series import _coerce_method
from pandas.errors import SpecificationError
from pandas.util._validators import validate_bool_kwarg

from snowflake.snowpark.modin.pandas.accessor import CachedAccessor, SparseAccessor
Expand Down Expand Up @@ -750,12 +749,6 @@ def drop(
def aggregate(
self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any
):
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
if is_dict_like(func):
raise SpecificationError(
"Value for func argument in dict format is not allowed for Series aggregate."
)

return super().aggregate(func, axis, *args, **kwargs)

agg = aggregate
Expand Down
55 changes: 40 additions & 15 deletions src/snowflake/snowpark/modin/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,21 +553,47 @@ def extract_validate_and_try_convert_named_aggs_from_kwargs(
A dictionary mapping columns to a tuple containing the aggregation to perform, as well
as the pandas label to give the aggregated column.
"""
from snowflake.snowpark.modin.pandas import Series
from snowflake.snowpark.modin.pandas.groupby import SeriesGroupBy

is_series_like = isinstance(obj, (Series, SeriesGroupBy))
named_aggs = {}
accepted_keys = []
columns = obj._query_compiler.columns
for key, value in kwargs.items():
if isinstance(value, pd.NamedAgg) or (
isinstance(value, tuple) and len(value) == 2
):
if is_series_like:
# pandas does not allow pd.NamedAgg or 2-tuples for named aggregations
# when the base object is a Series, but has different errors depending
# on whether we are doing a Series.agg or Series.groupby.agg.
if isinstance(obj, Series):
raise SpecificationError("nested renamer is not supported")
else:
value_type_str = (
"NamedAgg" if isinstance(value, pd.NamedAgg) else "tuple"
)
raise TypeError(
f"func is expected but received {value_type_str} in **kwargs."
)
if axis == 0:
# If axis == 1, we would need a query to materialize the index to check its existence
# so we defer the error checking to later.
if value[0] not in columns:
raise KeyError(f"Column(s) ['{value[0]}'] do not exist")

# This function converts our named aggregations dictionary from a mapping of
# new_label -> tuple[column_name, agg_func] to a mapping of
# column_name -> tuple[agg_func, new_label] in order to process
# the aggregation functions internally. One issue with this is that the order
# of the named aggregations can change - say we have the following aggregations:
# {new_col: ('A', min), new_col1: ('B', max), new_col2: ('A', max)}
# The output of this function will look like this:
# {A: [AggFuncWithLabel(func=min, label=new_col), AggFuncWithLabel(func=max, label=new_col2)]
# B: AggFuncWithLabel(func=max, label=new_col1)}
# And so our final dataframe will have the wrong order. We handle the reordering of the generated
# labels at the QC layer.
if value[0] in named_aggs:
if not isinstance(named_aggs[value[0]], list):
named_aggs[value[0]] = [named_aggs[value[0]]]
Expand All @@ -577,8 +603,11 @@ def extract_validate_and_try_convert_named_aggs_from_kwargs(
else:
named_aggs[value[0]] = AggFuncWithLabel(func=value[1], pandas_label=key)
accepted_keys += [key]
elif isinstance(obj, SeriesGroupBy):
col_name = obj._df._query_compiler.columns[0]
elif is_series_like:
if isinstance(obj, SeriesGroupBy):
col_name = obj._df._query_compiler.columns[0]
else:
col_name = obj._query_compiler.columns[0]
if col_name not in named_aggs:
named_aggs[col_name] = AggFuncWithLabel(func=value, pandas_label=key)
else:
Expand All @@ -587,14 +616,15 @@ def extract_validate_and_try_convert_named_aggs_from_kwargs(
named_aggs[col_name] += [AggFuncWithLabel(func=value, pandas_label=key)]
accepted_keys += [key]

if len(named_aggs.keys()) == 0:
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)

if any(key not in accepted_keys for key in kwargs.keys()):
# For compatibility with pandas errors. Otherwise, we would just ignore
# those kwargs.
if len(named_aggs.keys()) == 0 or any(
key not in accepted_keys for key in kwargs.keys()
):
# First check makes sure that some functions have been passed. If nothing has been passed,
# we raise the TypeError.
# The second check is for compatibility with pandas errors. Say the user does something like this:
# df.agg(x=pd.NamedAgg('A', 'min'), random_extra_kwarg=14). pandas errors out, since func is None
# and not every kwarg is a named aggregation. Without this check explicitly, we would just ignore
# the extraneous kwargs, so we include this check for parity with pandas.
raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).")

validated_named_aggs = {}
Expand Down Expand Up @@ -659,11 +689,6 @@ def validate_and_try_convert_agg_func_arg_func_to_str(
If nested dict configuration is used when agg_func is dict like or functions with duplicated names.
"""
if agg_func is None:
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)

if callable(agg_func):
result_agg_func = try_convert_builtin_func_to_str(agg_func, obj)
elif is_dict_like(agg_func):
Expand Down
36 changes: 36 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,39 @@ def generate_column_agg_info(
new_data_column_index_names += [None]

return column_agg_ops, new_data_column_index_names


def using_named_aggregations_for_func(func: Any) -> bool:
"""
Helper method to check if func is formatted in a way that indicates that we are using named aggregations.
If the user specifies named aggregations, we parse them into the func variable as a dictionary mapping
Hashable pandas labels to either a single AggFuncWithLabel or a list of AggFuncWithLabel NamedTuples. To know if
a SnowflakeQueryCompiler aggregation method (agg(), groupby_agg()) was called with named aggregations, we can check
if the `func` argument passed in obeys this formatting.
This function checks the following:
1. `func` is dict-like.
2. Every value in `func` is either:
a) an AggFuncWithLabel object
b) a list of AggFuncWithLabel objects.
If both conditions are met, that means that this func is the result of our internal processing of an aggregation
API with named aggregations specified by the user.
"""
return is_dict_like(func) and all(
isinstance(value, AggFuncWithLabel)
or (
isinstance(value, list)
and all(isinstance(v, AggFuncWithLabel) for v in value)
)
for value in func.values()
)


def format_kwargs_for_error_message(kwargs: dict[Any, Any]) -> str:
"""
Helper method to format a kwargs dictionary for an error message.
Returns a string containing the keys + values of kwargs formatted like so:
"key1=value1, key2=value2, ..."
"""
return ", ".join([f"{key}={value}" for key, value in kwargs.items()])
Loading

0 comments on commit 6a1bd53

Please sign in to comment.