From 4005f201687fffba47d03d6beb24b68c4fc34c2e Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 17 Oct 2024 17:24:18 -0700 Subject: [PATCH 1/3] [SNOW-1478924]: Ensure `__getitem__` on DataFrameGroupBy returns SeriesGroupBy when appropriate even if `as_index=False` --- CHANGELOG.md | 1 + .../modin/plugin/extensions/groupby_overrides.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a5e03e6324..5abb20149e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ - Fixed a bug where `replace()` would sometimes propagate `Timedelta` types incorrectly through `replace()`. Instead raise `NotImplementedError` for `replace()` on `Timedelta`. - Fixed a bug where `DataFrame` and `Series` `round()` would raise `AssertionError` for `Timedelta` columns. Instead raise `NotImplementedError` for `round()` on `Timedelta`. - Fixed a bug where `reindex` fails when the new index is a Series with non-overlapping types from the original index. +- Fixed a bug where calling `__getitem__` on a DataFrameGroupBy object always returned a DataFrameGroupBy object if `as_index=False`. ## 1.23.0 (2024-10-09) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py index 6d3728265d..9ca2d6faa0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py @@ -1092,10 +1092,7 @@ def __getitem__(self, key): if is_list_like(key): make_dataframe = True else: - if self._as_index: - make_dataframe = False - else: - make_dataframe = True + make_dataframe = False key = [key] column_index = self._df.columns @@ -1267,7 +1264,7 @@ def _wrap_aggregation( numeric_only = False if is_result_dataframe is None: - is_result_dataframe = not is_series_groupby + is_result_dataframe = not is_series_groupby or not self._as_index result_type = pd.DataFrame if is_result_dataframe else pd.Series result = result_type( query_compiler=qc_method( @@ -1427,7 +1424,11 @@ def unique(self): def size(self): # TODO: Remove this once SNOW-1478924 is fixed - return super().size().rename(self._df.columns[-1]) + result = super().size() + if isinstance(result, Series): + return result.rename(self._df.columns[-1]) + else: + return result def value_counts( self, From 75b9cb3ca689c42e9d48f261c4f6f3e9ad8565a5 Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Thu, 17 Oct 2024 17:28:45 -0700 Subject: [PATCH 2/3] Fix comments --- .../modin/plugin/extensions/groupby_overrides.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py index 9ca2d6faa0..6558aa1cb3 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py @@ -1082,12 +1082,13 @@ def __getitem__(self, key): "idx_name": self._idx_name, } # The rules of type deduction for the resulted object is the following: - # 1. If `key` is a list-like or `as_index is False`, then the resulted object is a DataFrameGroupBy + # 1. If `key` is a list-like, then the resulted object is a DataFrameGroupBy # 2. Otherwise, the resulted object is SeriesGroupBy # 3. Result type does not depend on the `by` origin # Examples: # - drop: any, as_index: any, __getitem__(key: list_like) -> DataFrameGroupBy - # - drop: any, as_index: False, __getitem__(key: any) -> DataFrameGroupBy + # - drop: any, as_index: False, __getitem__(key: list_like) -> DataFrameGroupBy + # - drop: any, as_index: False, __getitem__(key: label) -> SeriesGroupBy # - drop: any, as_index: True, __getitem__(key: label) -> SeriesGroupBy if is_list_like(key): make_dataframe = True @@ -1264,6 +1265,11 @@ def _wrap_aggregation( numeric_only = False if is_result_dataframe is None: + # If the GroupBy object is a SeriesGroupBy, we generally return a Series + # after an aggregation - unless `as_index` is False, in which case we + # return a DataFrame with N columns, where the first N-1 columns are + # the grouping columns (by), and the Nth column is the aggregation + # result. is_result_dataframe = not is_series_groupby or not self._as_index result_type = pd.DataFrame if is_result_dataframe else pd.Series result = result_type( From 4395f750bec9d47cb011a3910f4fd0273ae98f4f Mon Sep 17 00:00:00 2001 From: Rehan Durrani Date: Fri, 18 Oct 2024 12:04:26 -0700 Subject: [PATCH 3/3] Add tests --- .../modin/groupby/test_groupby_getitem.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 tests/integ/modin/groupby/test_groupby_getitem.py diff --git a/tests/integ/modin/groupby/test_groupby_getitem.py b/tests/integ/modin/groupby/test_groupby_getitem.py new file mode 100644 index 0000000000..e07cb238a4 --- /dev/null +++ b/tests/integ/modin/groupby/test_groupby_getitem.py @@ -0,0 +1,57 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest +from pandas.core.groupby.generic import ( + DataFrameGroupBy as native_df_groupby, + SeriesGroupBy as native_ser_groupby, +) + +import snowflake.snowpark.modin.plugin # noqa: F401 +from snowflake.snowpark.modin.plugin.extensions.groupby_overrides import ( + DataFrameGroupBy as snow_df_groupby, + SeriesGroupBy as snow_ser_groupby, +) +from tests.integ.utils.sql_counter import sql_count_checker + +data_dictionary = { + "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"], + "col2_int64": np.arange(9, dtype="int64") // 3, + "col3_int_identical": [2] * 9, + "col4_int32": np.arange(9, dtype="int32") // 4, + "col5_int16": np.arange(9, dtype="int16") // 3, +} + + +def check_groupby_types_same(native_groupby, snow_groupby): + if isinstance(native_groupby, native_df_groupby): + assert isinstance(snow_groupby, snow_df_groupby) + elif isinstance(native_groupby, native_ser_groupby): + assert isinstance(snow_groupby, snow_ser_groupby) + else: + raise ValueError( + f"Unknown GroupBy type for native pandas: {type(native_groupby)}. Snowpark pandas GroupBy type: {type(snow_groupby)}" + ) + + +@pytest.mark.parametrize( + "by", + [ + "col1_grp", + ["col1_grp", "col2_int64"], + ["col1_grp", "col2_int64", "col3_int_identical"], + ], +) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("indexer", ["col5_int16", ["col5_int16", "col4_int32"]]) +@sql_count_checker(query_count=0) +def test_groupby_getitem(by, as_index, indexer): + snow_df = pd.DataFrame(data_dictionary) + native_df = native_pd.DataFrame(data_dictionary) + check_groupby_types_same( + native_df.groupby(by=by, as_index=as_index)[indexer], + snow_df.groupby(by, as_index=as_index)[indexer], + )