
[SNOW-1748174]: Add support for size in groupby.agg
sfc-gh-rdurrani committed Oct 17, 2024
1 parent 50a9dcf commit 9a4f455
Showing 4 changed files with 147 additions and 43 deletions.
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -396,6 +396,16 @@ def _columns_count(*cols: SnowparkColumn) -> Callable:
     return sum(builtin("nvl2")(col, pandas_lit(1), pandas_lit(0)) for col in cols)
 
 
+def _columns_count_keep_nulls(*cols: SnowparkColumn) -> Callable:
+    """
+    Counts the number of values (including NULL) in each row.
+    """
+    # IMPORTANT: like count and sum, this uses the Python builtin sum to combine Column
+    # expressions via __add__ rather than Snowpark sum_, since Snowpark sum_ takes the
+    # sum of all rows within a single column. Each column contributes a literal 1,
+    # whether or not its value is NULL.
+    return sum(pandas_lit(1) for col in cols)
+
+
 def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
     """
     Sums all non-NaN elements in each row. If all elements are NaN, returns 0.
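Note: the helper above encodes pandas `size` semantics, which differ from `count` only in NULL handling. A minimal native-pandas sketch of the behavior being matched (data values are illustrative):

import numpy as np
import pandas as native_pd

df = native_pd.DataFrame({"a": [1, 1, 2], "b": [10.0, np.nan, 30.0]})
# count skips NaN within each group: a=1 -> 1, a=2 -> 1
print(df.groupby("a")["b"].count())
# size counts every row, NaN included: a=1 -> 2, a=2 -> 1
print(df.groupby("a")["b"].size())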
@@ -447,6 +457,12 @@ def _create_pandas_to_snowpark_pandas_aggregation_map(
         axis_1_aggregation_skipna=_columns_count,
         preserves_snowpark_pandas_types=False,
     ),
+    "size": _SnowparkPandasAggregation(
+        # We must count the total number of rows regardless of whether they are null.
+        axis_0_aggregation=lambda cols: builtin("count_if")(pandas_lit(True)),
+        axis_1_aggregation_keepna=_columns_count_keep_nulls,
+        preserves_snowpark_pandas_types=False,
+    ),
     **_create_pandas_to_snowpark_pandas_aggregation_map(
         ("mean", np.mean),
         _SnowparkPandasAggregation(
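Note: on axis 0, the mapping above implements `size` as SQL COUNT_IF(TRUE), which evaluates to 1 for every row, so rows containing NULLs still count. A rough equivalent in plain Snowpark (not Snowpark pandas; `snowpark_df` and the grouping column "a" are illustrative):

from snowflake.snowpark.functions import builtin, lit

# COUNT_IF(TRUE) adds one per row, so NULL cells do not reduce the result.
sizes = snowpark_df.group_by("a").agg(builtin("count_if")(lit(True)).alias("size"))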
@@ -434,25 +434,7 @@ def any(self, skipna: bool = True):
         )
 
     def bfill(self, limit=None):
-        is_series_groupby = self.ndim == 1
-
-        # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
-        query_compiler = self._query_compiler.groupby_fillna(
-            self._by,
-            self._axis,
-            self._kwargs,
-            value=None,
-            method="bfill",
-            fill_axis=None,
-            inplace=False,
-            limit=limit,
-            downcast=None,
-        )
-        return (
-            pd.Series(query_compiler=query_compiler)
-            if is_series_groupby
-            else pd.DataFrame(query_compiler=query_compiler)
-        )
+        ErrorMessage.method_not_implemented_error(name="bfill", class_="GroupBy")
 
     def corr(self, **kwargs):
         # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
@@ -525,25 +507,7 @@ def diff(self):
         ErrorMessage.method_not_implemented_error(name="diff", class_="GroupBy")
 
     def ffill(self, limit=None):
-        is_series_groupby = self.ndim == 1
-
-        # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
-        query_compiler = self._query_compiler.groupby_fillna(
-            self._by,
-            self._axis,
-            self._kwargs,
-            value=None,
-            method="ffill",
-            fill_axis=None,
-            inplace=False,
-            limit=limit,
-            downcast=None,
-        )
-        return (
-            pd.Series(query_compiler=query_compiler)
-            if is_series_groupby
-            else pd.DataFrame(query_compiler=query_compiler)
-        )
+        ErrorMessage.method_not_implemented_error(name="ffill", class_="GroupBy")
 
     def fillna(
         self,
@@ -1092,10 +1056,7 @@ def __getitem__(self, key):
         if is_list_like(key):
             make_dataframe = True
         else:
-            if self._as_index:
-                make_dataframe = False
-            else:
-                make_dataframe = True
+            make_dataframe = False
             key = [key]
 
         column_index = self._df.columns
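After this change, `__getitem__` with a scalar key always returns a SeriesGroupBy. This matches native pandas, where `as_index` does not affect the type returned by indexing; the `as_index=False` output shape is instead handled in `size()` below. A quick native-pandas illustration:

import pandas as native_pd

df = native_pd.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
# Both are SeriesGroupBy; as_index only changes the shape of aggregated results.
print(type(df.groupby("a")["b"]))
print(type(df.groupby("a", as_index=False)["b"]))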
@@ -1427,7 +1388,10 @@ def unique(self):
 
     def size(self):
         # TODO: Remove this once SNOW-1478924 is fixed
-        return super().size().rename(self._df.columns[-1])
+        if self._as_index:
+            return super().size().rename(self._df.columns[-1])
+        else:
+            return pd.DataFrame(super().size()).T
 
     def value_counts(
         self,
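The new `as_index=False` branch turns the Series of group sizes into a single-row DataFrame by wrapping and transposing. A plain-pandas sketch of that reshaping (values and the `new_result` name are illustrative):

import pandas as native_pd

sizes = native_pd.Series([2, 1], index=["g0", "g1"], name="new_result")
# DataFrame(sizes) yields one column named "new_result"; .T flips it into a
# single row whose columns are the original group labels.
print(native_pd.DataFrame(sizes).T)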
17 changes: 17 additions & 0 deletions tests/integ/modin/groupby/test_groupby_named_agg.py
@@ -4,6 +4,7 @@
 import re
 
 import modin.pandas as pd
+import numpy as np
 import pandas as native_pd
 import pytest
@@ -114,6 +115,22 @@ def test_named_agg_passed_in_via_star_kwargs(basic_df_data):
     )
 

+@sql_count_checker(query_count=1)
+def test_named_agg_count_vs_size():
+    data = [[1, 2, 3], [1, 5, np.nan], [7, np.nan, 9]]
+    native_df = native_pd.DataFrame(
+        data, columns=["a", "b", "c"], index=["owl", "toucan", "eagle"]
+    )
+    snow_df = pd.DataFrame(native_df)
+    eval_snowpark_pandas_result(
+        snow_df,
+        native_df,
+        lambda df: df.groupby("a").agg(
+            l=("b", "size"), j=("c", "size"), m=("c", "count"), n=("b", "count")
+        ),
+    )


 @sql_count_checker(query_count=0)
 def test_named_agg_with_invalid_function_raises_not_implemented(
     basic_df_data,
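For the `test_named_agg_count_vs_size` data above, the groups are a=1 (two rows) and a=7 (one row); `size` counts the NaN cells while `count` skips them, so both libraries should produce (computed with native pandas):

#    l  j  m  n
# a
# 1  2  2  1  2
# 7  1  1  1  0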
107 changes: 107 additions & 0 deletions tests/integ/modin/groupby/test_groupby_size.py
@@ -85,6 +85,92 @@ def test_groupby_size(by, as_index):
     )


+@pytest.mark.parametrize(
+    "by",
+    [
+        "col1_grp",
+        "col2_int64",
+        "col3_int_identical",
+        "col4_int32",
+        "col6_mixed",
+        "col7_bool",
+        "col8_bool_missing",
+        "col9_int_missing",
+        "col10_mixed_missing",
+        ["col1_grp", "col2_int64"],
+        ["col6_mixed", "col7_bool", "col3_int_identical"],
+    ],
+)
+@pytest.mark.parametrize("as_index", [True, False])
+def test_groupby_agg_size(by, as_index):
+    snowpark_pandas_df = pd.DataFrame(
+        {
+            "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"],
+            "col2_int64": np.arange(9, dtype="int64") // 3,
+            "col3_int_identical": [2] * 9,
+            "col4_int32": np.arange(9, dtype="int32") // 4,
+            "col5_int16": np.arange(9, dtype="int16") // 3,
+            "col6_mixed": np.concatenate(
+                [
+                    np.arange(3, dtype="int64") // 3,
+                    np.arange(3, dtype="int32") // 3,
+                    np.arange(3, dtype="int16") // 3,
+                ]
+            ),
+            "col7_bool": [True] * 5 + [False] * 4,
+            "col8_bool_missing": [
+                True,
+                None,
+                False,
+                False,
+                None,
+                None,
+                True,
+                False,
+                None,
+            ],
+            "col9_int_missing": [5, 6, np.nan, 2, 1, np.nan, 5, np.nan, np.nan],
+            "col10_mixed_missing": np.concatenate(
+                [
+                    np.arange(2, dtype="int64") // 3,
+                    [np.nan],
+                    np.arange(2, dtype="int32") // 3,
+                    [np.nan],
+                    np.arange(2, dtype="int16") // 3,
+                    [np.nan],
+                ]
+            ),
+        }
+    )
+
+    pandas_df = snowpark_pandas_df.to_pandas()
+    with SqlCounter(query_count=1):
+        eval_snowpark_pandas_result(
+            snowpark_pandas_df,
+            pandas_df,
+            lambda df: df.groupby(by, as_index=as_index).agg({"col5_int16": "size"}),
+        )
+
+    # DataFrame with __getitem__
+    # if as_index:
+    # In this case, calling __getitem__ returns a DataFrame.
+    with SqlCounter(query_count=1):
+        eval_snowpark_pandas_result(
+            snowpark_pandas_df,
+            pandas_df,
+            lambda df: df.groupby(by, as_index=as_index)["col5_int16"].agg(
+                new_result="size"
+            ),
+        )
+    # else:
+    # with SqlCounter(query_count=1):
+    #     eval_snowpark_pandas_result(
+    #         snowpark_pandas_df,
+    #         pandas_df,
+    #         lambda df: df.groupby(by, as_index=as_index)["col5_int16"].agg(new_result=('col5_int16', 'size')),
+    #     )


 @sql_count_checker(query_count=0)
 def test_error_checking():
     s = pd.Series(list("abc") * 4)
@@ -111,3 +197,24 @@ def test_timedelta(by):
         native_df,
         lambda df: df.groupby(by).size(),
     )


@pytest.mark.parametrize("by", ["A", "B"])
@sql_count_checker(query_count=1)
def test_timedelta_agg(by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
"C": ["the", "name", "is", "bond"],
}
)
snow_df = pd.DataFrame(native_df)

eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.groupby(by).agg(d=pd.NamedAgg("A" if by != "A" else "C", "size")),
)
