Skip to content

Commit

Permalink
[SNOW-1478924]: Ensure __getitem__ on DataFrameGroupBy returns SeriesGroupBy when appropriate even if `as_index=False` (#2475)
Browse files Browse the repository at this point in the history

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.

Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1478924

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

Ensure that getitem on a DataFrameGroupBy returns a SeriesGroupBy if
only one column is indexed, even if `as_index=False`.
  • Loading branch information
sfc-gh-rdurrani authored Oct 18, 2024
1 parent 22f8398 commit afa7d7d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- Fixed a bug where `replace()` would sometimes propagate `Timedelta` types incorrectly through `replace()`. Instead raise `NotImplementedError` for `replace()` on `Timedelta`.
- Fixed a bug where `DataFrame` and `Series` `round()` would raise `AssertionError` for `Timedelta` columns. Instead raise `NotImplementedError` for `round()` on `Timedelta`.
- Fixed a bug where `reindex` fails when the new index is a Series with non-overlapping types from the original index.
- Fixed a bug where calling `__getitem__` on a DataFrameGroupBy object always returned a DataFrameGroupBy object if `as_index=False`.

### Snowpark Local Testing Updates

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1082,20 +1082,18 @@ def __getitem__(self, key):
"idx_name": self._idx_name,
}
# The rules of type deduction for the resulted object is the following:
# 1. If `key` is a list-like or `as_index is False`, then the resulted object is a DataFrameGroupBy
# 1. If `key` is a list-like, then the resulted object is a DataFrameGroupBy
# 2. Otherwise, the resulted object is SeriesGroupBy
# 3. Result type does not depend on the `by` origin
# Examples:
# - drop: any, as_index: any, __getitem__(key: list_like) -> DataFrameGroupBy
# - drop: any, as_index: False, __getitem__(key: any) -> DataFrameGroupBy
# - drop: any, as_index: False, __getitem__(key: list_like) -> DataFrameGroupBy
# - drop: any, as_index: False, __getitem__(key: label) -> SeriesGroupBy
# - drop: any, as_index: True, __getitem__(key: label) -> SeriesGroupBy
if is_list_like(key):
make_dataframe = True
else:
if self._as_index:
make_dataframe = False
else:
make_dataframe = True
make_dataframe = False
key = [key]

column_index = self._df.columns
Expand Down Expand Up @@ -1267,7 +1265,12 @@ def _wrap_aggregation(
numeric_only = False

if is_result_dataframe is None:
is_result_dataframe = not is_series_groupby
# If the GroupBy object is a SeriesGroupBy, we generally return a Series
# after an aggregation - unless `as_index` is False, in which case we
# return a DataFrame with N columns, where the first N-1 columns are
# the grouping columns (by), and the Nth column is the aggregation
# result.
is_result_dataframe = not is_series_groupby or not self._as_index
result_type = pd.DataFrame if is_result_dataframe else pd.Series
result = result_type(
query_compiler=qc_method(
Expand Down Expand Up @@ -1427,7 +1430,11 @@ def unique(self):

def size(self):
    # TODO: Remove this once SNOW-1478924 is fixed
    # Delegate to the parent implementation. Now that __getitem__ can
    # return a SeriesGroupBy even when as_index=False, the parent may
    # produce either a Series or a DataFrame, so handle both.
    result = super().size()
    if isinstance(result, Series):
        # Restore the column label on the Series result; presumably the
        # last column of the underlying frame is the selected one —
        # TODO confirm against the parent implementation.
        return result.rename(self._df.columns[-1])
    else:
        # DataFrame result (as_index=False case): return unchanged.
        return result

def value_counts(
self,
Expand Down
57 changes: 57 additions & 0 deletions tests/integ/modin/groupby/test_groupby_getitem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest
from pandas.core.groupby.generic import (
DataFrameGroupBy as native_df_groupby,
SeriesGroupBy as native_ser_groupby,
)

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark.modin.plugin.extensions.groupby_overrides import (
DataFrameGroupBy as snow_df_groupby,
SeriesGroupBy as snow_ser_groupby,
)
from tests.integ.utils.sql_counter import sql_count_checker

# Shared fixture data: one string grouping column plus several integer
# columns of differing dtypes, used to build both the native pandas and
# Snowpark pandas DataFrames in the tests below.
data_dictionary = {
    "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"],
    "col2_int64": np.arange(9, dtype="int64") // 3,
    "col3_int_identical": [2] * 9,
    "col4_int32": np.arange(9, dtype="int32") // 4,
    "col5_int16": np.arange(9, dtype="int16") // 3,
}


def check_groupby_types_same(native_groupby, snow_groupby):
if isinstance(native_groupby, native_df_groupby):
assert isinstance(snow_groupby, snow_df_groupby)
elif isinstance(native_groupby, native_ser_groupby):
assert isinstance(snow_groupby, snow_ser_groupby)
else:
raise ValueError(
f"Unknown GroupBy type for native pandas: {type(native_groupby)}. Snowpark pandas GroupBy type: {type(snow_groupby)}"
)


@pytest.mark.parametrize(
    "by",
    [
        "col1_grp",
        ["col1_grp", "col2_int64"],
        ["col1_grp", "col2_int64", "col3_int_identical"],
    ],
)
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize("indexer", ["col5_int16", ["col5_int16", "col4_int32"]])
@sql_count_checker(query_count=0)
def test_groupby_getitem(by, as_index, indexer):
    """__getitem__ on a GroupBy must yield the same object kind (DataFrameGroupBy
    vs SeriesGroupBy) in Snowpark pandas as in native pandas, regardless of
    ``as_index``."""
    native_df = native_pd.DataFrame(data_dictionary)
    snow_df = pd.DataFrame(data_dictionary)
    native_result = native_df.groupby(by=by, as_index=as_index)[indexer]
    snow_result = snow_df.groupby(by, as_index=as_index)[indexer]
    check_groupby_types_same(native_result, snow_result)

0 comments on commit afa7d7d

Please sign in to comment.