-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SNOW-1478924]: Ensure `__getitem__` on DataFrameGroupBy returns SeriesGroupBy when appropriate even if `as_index=False` (#2475)

<!--- Please answer these questions before creating your pull request. Thanks! --->

1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

<!--- In this section, please add a Snowflake Jira issue number. Note that if a corresponding GitHub issue exists, you should still include the Snowflake Jira issue number. For example, for GitHub issue #1400, you should add "SNOW-1335071" here. --->

Fixes SNOW-1478924

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my new code
- [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
- [ ] I am adding new logging messages
- [ ] I am adding a new telemetry message
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
- [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://docs.google.com/document/d/162d_i4zZ2AfcGRXojj0jByt8EUq-DrSHPPnTa4QvwbA/edit#bookmark=id.e82u4nekq80k)

3. Please describe how your code solves the related issue.

Ensure that getitem on a DataFrameGroupBy returns a SeriesGroupBy if only one column is indexed, even if `as_index=False`.
- Loading branch information
1 parent
22f8398
commit afa7d7d
Showing
3 changed files
with
73 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
import modin.pandas as pd | ||
import numpy as np | ||
import pandas as native_pd | ||
import pytest | ||
from pandas.core.groupby.generic import ( | ||
DataFrameGroupBy as native_df_groupby, | ||
SeriesGroupBy as native_ser_groupby, | ||
) | ||
|
||
import snowflake.snowpark.modin.plugin # noqa: F401 | ||
from snowflake.snowpark.modin.plugin.extensions.groupby_overrides import ( | ||
DataFrameGroupBy as snow_df_groupby, | ||
SeriesGroupBy as snow_ser_groupby, | ||
) | ||
from tests.integ.utils.sql_counter import sql_count_checker | ||
|
||
# Shared input for the tests in this module: one string grouping column
# ("col1_grp") and four integer columns, each with nine rows. The integer
# columns are built with explicit int64 / int32 / int16 dtypes so that
# columns of different integer widths are available as group keys and as
# selected columns.
data_dictionary = {
    "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"],
    # 0, 0, 0, 1, 1, 1, 2, 2, 2
    "col2_int64": np.arange(9, dtype="int64") // 3,
    # A constant column: every row has the same value.
    "col3_int_identical": [2] * 9,
    # 0, 0, 0, 0, 1, 1, 1, 1, 2
    "col4_int32": np.arange(9, dtype="int32") // 4,
    # 0, 0, 0, 1, 1, 1, 2, 2, 2
    "col5_int16": np.arange(9, dtype="int16") // 3,
}
|
||
|
||
def check_groupby_types_same(native_groupby, snow_groupby):
    """Assert that a Snowpark pandas GroupBy matches the kind of its native counterpart.

    The native object decides which kind is expected: a native
    ``DataFrameGroupBy`` requires ``snow_groupby`` to be a Snowpark pandas
    ``DataFrameGroupBy``, and a native ``SeriesGroupBy`` requires a Snowpark
    pandas ``SeriesGroupBy``.

    Args:
        native_groupby: GroupBy object produced by native pandas.
        snow_groupby: GroupBy object produced by Snowpark pandas for the
            same operation.

    Raises:
        AssertionError: If ``snow_groupby`` is not of the kind that
            corresponds to ``native_groupby``.
        ValueError: If ``native_groupby`` is neither a native
            ``DataFrameGroupBy`` nor a native ``SeriesGroupBy``.
    """
    if isinstance(native_groupby, native_df_groupby):
        assert isinstance(
            snow_groupby, snow_df_groupby
        ), f"Expected a Snowpark pandas DataFrameGroupBy, got {type(snow_groupby)}"
    elif isinstance(native_groupby, native_ser_groupby):
        assert isinstance(
            snow_groupby, snow_ser_groupby
        ), f"Expected a Snowpark pandas SeriesGroupBy, got {type(snow_groupby)}"
    else:
        # The native object is the reference; if it is not a recognised
        # GroupBy kind there is nothing meaningful to compare against.
        raise ValueError(
            f"Unknown GroupBy type for native pandas: {type(native_groupby)}. Snowpark pandas GroupBy type: {type(snow_groupby)}"
        )
|
||
|
||
@pytest.mark.parametrize(
    "by",
    [
        "col1_grp",
        ["col1_grp", "col2_int64"],
        ["col1_grp", "col2_int64", "col3_int_identical"],
    ],
)
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize("indexer", ["col5_int16", ["col5_int16", "col4_int32"]])
@sql_count_checker(query_count=0)
def test_groupby_getitem(by, as_index, indexer):
    """Check that indexing a GroupBy yields the same GroupBy kind as native pandas.

    The same frame is built in Snowpark pandas and in native pandas, grouped
    by ``by`` with the given ``as_index``, and then indexed with ``indexer``
    (a single column label or a list of labels). The two results are compared
    with ``check_groupby_types_same``.

    The test is decorated with ``sql_count_checker(query_count=0)``;
    presumably this asserts that no SQL queries are issued -- confirm against
    the helper's definition.
    """
    snow_df = pd.DataFrame(data_dictionary)
    native_df = native_pd.DataFrame(data_dictionary)
    check_groupby_types_same(
        native_df.groupby(by=by, as_index=as_index)[indexer],
        snow_df.groupby(by=by, as_index=as_index)[indexer],
    )