From 9a4f455f72c2749c31e32fe503df16e5fe32b8e5 Mon Sep 17 00:00:00 2001
From: Rehan Durrani
Date: Thu, 17 Oct 2024 15:57:08 -0700
Subject: [PATCH] [SNOW-1748174]: Add support for `size` in `groupby.agg`

---
 .../plugin/_internal/aggregation_utils.py     |  16 +++
 .../plugin/extensions/groupby_overrides.py    |  50 ++------
 .../modin/groupby/test_groupby_named_agg.py   |  17 +++
 .../integ/modin/groupby/test_groupby_size.py  | 107 ++++++++++++++++++
 4 files changed, 147 insertions(+), 43 deletions(-)

diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
index 0005df924d..23eba32cc1 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -396,6 +396,16 @@ def _columns_count(*cols: SnowparkColumn) -> Callable:
     return sum(builtin("nvl2")(col, pandas_lit(1), pandas_lit(0)) for col in cols)
 
 
+def _columns_count_keep_nulls(*cols: SnowparkColumn) -> Callable:
+    """
+    Counts the number of values (including NULL) in each row.
+    """
+    # As in _columns_count, use the Python builtin sum to invoke __add__ on each
+    # column expression rather than Snowpark sum_, since Snowpark sum_ aggregates
+    # all rows within a single column. Every column contributes 1, NULL or not.
+    return sum(pandas_lit(1) for col in cols)
+
+
 def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
     """
     Sums all non-NaN elements in each row. If all elements are NaN, returns 0.
@@ -447,6 +457,12 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
         axis_1_aggregation_skipna=_columns_count,
         preserves_snowpark_pandas_types=False,
     ),
+    "size": _SnowparkPandasAggregation(
+        # We must count the total number of rows regardless of whether they're null.
+ axis_0_aggregation=lambda cols: builtin("count_if")(pandas_lit(True)), + axis_1_aggregation_keepna=_columns_count_keep_nulls, + preserves_snowpark_pandas_types=False, + ), **_create_pandas_to_snowpark_pandas_aggregation_map( ("mean", np.mean), _SnowparkPandasAggregation( diff --git a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py index 6d3728265d..53351dcc92 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py @@ -434,25 +434,7 @@ def any(self, skipna: bool = True): ) def bfill(self, limit=None): - is_series_groupby = self.ndim == 1 - - # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions - query_compiler = self._query_compiler.groupby_fillna( - self._by, - self._axis, - self._kwargs, - value=None, - method="bfill", - fill_axis=None, - inplace=False, - limit=limit, - downcast=None, - ) - return ( - pd.Series(query_compiler=query_compiler) - if is_series_groupby - else pd.DataFrame(query_compiler=query_compiler) - ) + ErrorMessage.method_not_implemented_error(name="bfill", class_="GroupBy") def corr(self, **kwargs): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions @@ -525,25 +507,7 @@ def diff(self): ErrorMessage.method_not_implemented_error(name="diff", class_="GroupBy") def ffill(self, limit=None): - is_series_groupby = self.ndim == 1 - - # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions - query_compiler = self._query_compiler.groupby_fillna( - self._by, - self._axis, - self._kwargs, - value=None, - method="ffill", - fill_axis=None, - inplace=False, - limit=limit, - downcast=None, - ) - return ( - pd.Series(query_compiler=query_compiler) - if is_series_groupby - else pd.DataFrame(query_compiler=query_compiler) - ) + 
ErrorMessage.method_not_implemented_error(name="ffill", class_="GroupBy") def fillna( self, @@ -1092,10 +1056,7 @@ def __getitem__(self, key): if is_list_like(key): make_dataframe = True else: - if self._as_index: - make_dataframe = False - else: - make_dataframe = True + make_dataframe = False key = [key] column_index = self._df.columns @@ -1427,7 +1388,10 @@ def unique(self): def size(self): # TODO: Remove this once SNOW-1478924 is fixed - return super().size().rename(self._df.columns[-1]) + if self._as_index: + return super().size().rename(self._df.columns[-1]) + else: + return pd.DataFrame(super().size()).T def value_counts( self, diff --git a/tests/integ/modin/groupby/test_groupby_named_agg.py b/tests/integ/modin/groupby/test_groupby_named_agg.py index 53e3354bf6..7220fea0f1 100644 --- a/tests/integ/modin/groupby/test_groupby_named_agg.py +++ b/tests/integ/modin/groupby/test_groupby_named_agg.py @@ -4,6 +4,7 @@ import re import modin.pandas as pd +import numpy as np import pandas as native_pd import pytest @@ -114,6 +115,22 @@ def test_named_agg_passed_in_via_star_kwargs(basic_df_data): ) +@sql_count_checker(query_count=1) +def test_named_agg_count_vs_size(): + data = [[1, 2, 3], [1, 5, np.nan], [7, np.nan, 9]] + native_df = native_pd.DataFrame( + data, columns=["a", "b", "c"], index=["owl", "toucan", "eagle"] + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby("a").agg( + l=("b", "size"), j=("c", "size"), m=("c", "count"), n=("b", "count") + ), + ) + + @sql_count_checker(query_count=0) def test_named_agg_with_invalid_function_raises_not_implemented( basic_df_data, diff --git a/tests/integ/modin/groupby/test_groupby_size.py b/tests/integ/modin/groupby/test_groupby_size.py index da881754e6..0f1c203061 100644 --- a/tests/integ/modin/groupby/test_groupby_size.py +++ b/tests/integ/modin/groupby/test_groupby_size.py @@ -85,6 +85,92 @@ def test_groupby_size(by, as_index): ) +@pytest.mark.parametrize( 
+ "by", + [ + "col1_grp", + "col2_int64", + "col3_int_identical", + "col4_int32", + "col6_mixed", + "col7_bool", + "col8_bool_missing", + "col9_int_missing", + "col10_mixed_missing", + ["col1_grp", "col2_int64"], + ["col6_mixed", "col7_bool", "col3_int_identical"], + ], +) +@pytest.mark.parametrize("as_index", [True, False]) +def test_groupby_agg_size(by, as_index): + snowpark_pandas_df = pd.DataFrame( + { + "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"], + "col2_int64": np.arange(9, dtype="int64") // 3, + "col3_int_identical": [2] * 9, + "col4_int32": np.arange(9, dtype="int32") // 4, + "col5_int16": np.arange(9, dtype="int16") // 3, + "col6_mixed": np.concatenate( + [ + np.arange(3, dtype="int64") // 3, + np.arange(3, dtype="int32") // 3, + np.arange(3, dtype="int16") // 3, + ] + ), + "col7_bool": [True] * 5 + [False] * 4, + "col8_bool_missing": [ + True, + None, + False, + False, + None, + None, + True, + False, + None, + ], + "col9_int_missing": [5, 6, np.nan, 2, 1, np.nan, 5, np.nan, np.nan], + "col10_mixed_missing": np.concatenate( + [ + np.arange(2, dtype="int64") // 3, + [np.nan], + np.arange(2, dtype="int32") // 3, + [np.nan], + np.arange(2, dtype="int16") // 3, + [np.nan], + ] + ), + } + ) + + pandas_df = snowpark_pandas_df.to_pandas() + with SqlCounter(query_count=1): + eval_snowpark_pandas_result( + snowpark_pandas_df, + pandas_df, + lambda df: df.groupby(by, as_index=as_index).agg({"col5_int16": "size"}), + ) + + # DataFrame with __getitem__ + # if as_index: + # In this case, calling get item returns a DataFrame. 
+ with SqlCounter(query_count=1): + eval_snowpark_pandas_result( + snowpark_pandas_df, + pandas_df, + lambda df: df.groupby(by, as_index=as_index)["col5_int16"].agg( + new_result="size" + ), + ) + # else: + # with SqlCounter(query_count=1): + # eval_snowpark_pandas_result( + # snowpark_pandas_df, + # pandas_df, + # lambda df: df.groupby(by, as_index=as_index)["col5_int16"].agg(new_result=('col5_int16', 'size')), + # ) + + @sql_count_checker(query_count=0) def test_error_checking(): s = pd.Series(list("abc") * 4) @@ -111,3 +197,24 @@ def test_timedelta(by): native_df, lambda df: df.groupby(by).size(), ) + + +@pytest.mark.parametrize("by", ["A", "B"]) +@sql_count_checker(query_count=1) +def test_timedelta_agg(by): + native_df = native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": ["the", "name", "is", "bond"], + } + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby(by).agg(d=pd.NamedAgg("A" if by != "A" else "C", "size")), + )