diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b98bc5d40..9e6038a832 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@
 - Added support for tracking usages of `__array_ufunc__`.
 - Added numpy compatibility support for `np.float_power`, `np.mod`, `np.remainder`, `np.greater`, `np.greater_equal`, `np.less`, `np.less_equal`, `np.not_equal`, and `np.equal`.
 - Added support for `DataFrameGroupBy.bfill`, `SeriesGroupBy.bfill`, `DataFrameGroupBy.ffill`, and `SeriesGroupBy.ffill`.
+- Added support for `on` parameter with `Resampler`.
 
 #### Improvements
 
diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst
index 5689aec7d3..3c41dc0639 100644
--- a/docs/source/modin/supported/dataframe_supported.rst
+++ b/docs/source/modin/supported/dataframe_supported.rst
@@ -349,7 +349,7 @@ Methods
 |                             |                                 | ``limit``                        |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``resample``                | P                               | ``axis``, ``label``,             | Only DatetimeIndex is supported and its ``freq``   |
-|                             |                                 | ``convention``, ``kind``, ``on`` | will be lost. ``rule`` frequencies 's', 'min',     |
+|                             |                                 | ``convention``, ``kind``,        | will be lost. ``rule`` frequencies 's', 'min',     |
 |                             |                                 | , ``level``, ``origin``,         | 'h', and 'D' are supported. ``rule`` frequencies   |
 |                             |                                 | , ``offset``, ``group_keys``     | 'W', 'ME', and 'YE' are supported with             |
 |                             |                                 |                                  | `closed = "left"`                                  |
diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst
index b74bd40432..d11d3303d6 100644
--- a/docs/source/modin/supported/series_supported.rst
+++ b/docs/source/modin/supported/series_supported.rst
@@ -342,7 +342,7 @@ Methods
 | ``replace``                 | P                               | ``method``, ``limit``            |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``resample``                | P                               | ``axis``, ``label``,             | Only DatetimeIndex is supported and its ``freq``   |
-|                             |                                 | ``convention``, ``kind``, ``on`` | will be lost. ``rule`` frequencies 's', 'min',     |
+|                             |                                 | ``convention``, ``kind``,        | will be lost. ``rule`` frequencies 's', 'min',     |
 |                             |                                 | , ``level``, ``origin``,         | 'h', and 'D' are supported. ``rule`` frequencies   |
 |                             |                                 | , ``offset``, ``group_keys``     | 'W', 'ME', and 'YE' are supported with             |
 |                             |                                 |                                  | `closed = "left"`                                  |
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
index e9a29a032c..ab4bf81cbe 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
@@ -249,10 +249,6 @@ def validate_resample_supported_by_snowflake(
     if kind is not None:  # pragma: no cover
         _argument_not_implemented("kind", kind)
 
-    on = resample_kwargs.get("on")
-    if on is not None:  # pragma: no cover
-        _argument_not_implemented("on", on)
-
     level = resample_kwargs.get("level")
     if level is not None:  # pragma: no cover
         _argument_not_implemented("level", level)
@@ -304,7 +300,7 @@ def get_snowflake_quoted_identifier_for_resample_index_col(frame: InternalFrame)
     sf_type = frame.get_snowflake_type(index_col)
 
     if not isinstance(sf_type, (TimestampType, DateType)):
-        raise TypeError("Only valid with DatetimeIndex.")
+        raise TypeError("Only valid with DatetimeIndex or TimedeltaIndex")
 
     return index_col
 
@@ -347,7 +343,11 @@ def time_slice(
 
 
 def perform_resample_binning_on_frame(
-    frame: InternalFrame, start_date: str, bin_size: str
+    frame: InternalFrame,
+    datetime_index_col_identifier: str,
+    start_date: str,
+    slice_width: int,
+    slice_unit: str,
 ) -> InternalFrame:
     """
     Returns a new dataframe where each item of the index column
@@ -359,21 +359,25 @@
         The internal frame with a single DatetimeIndex column to
         perform resample binning on.
 
+    datetime_index_col_identifier : str
+        The datetime-like column snowflake quoted identifier to use for resampling.
+
     start_date : str
         The earliest date in the Datetime index column of `frame`.
 
-    bin_size : str
-        The offset string or object representing target conversion.
+    slice_width : int
+        Width of the slice (i.e. how many units of time are contained in the slice).
+
+    slice_unit : str
+        Time unit for the slice length.
 
     Returns
     -------
     frame : InternalFrame
         A new internal frame where items in the index column are
-        placed in a bin based on `bin_length` and `bin_unit`
+        placed in a bin based on `slice_width` and `slice_unit`
     """
-
-    slice_width, slice_unit = rule_to_snowflake_width_and_slice_unit(bin_size)
     # Consider the following example:
     # frame:
     #      data_col
@@ -387,9 +391,7 @@
     # 2023-08-15      7
     # 2023-08-16      8
     # 2023-08-17      9
-    # start_date = 2023-08-07, bin_size = 3D (3 days)
-
-    datetime_index_col = get_snowflake_quoted_identifier_for_resample_index_col(frame)
+    # start_date = 2023-08-07, rule = 3D (3 days)
 
     # Time slices in Snowflake are aligned to snowflake_timeslice_alignment_date,
     # so we must normalize input datetimes.
@@ -399,7 +401,11 @@
 
     # Subtract the normalization amount in seconds from the input datetime.
     normalized_dates = to_timestamp_ntz(
-        datediff("second", to_timestamp_ntz(lit(normalization_amt)), datetime_index_col)
+        datediff(
+            "second",
+            to_timestamp_ntz(lit(normalization_amt)),
+            datetime_index_col_identifier,
+        )
     )
     # frame:
     #      data_col
@@ -460,7 +466,7 @@
     # 2023-08-16      9
 
     return frame.update_snowflake_quoted_identifiers_with_expressions(
-        {datetime_index_col: unnormalized_dates_set_to_bins}
+        {datetime_index_col_identifier: unnormalized_dates_set_to_bins}
     ).frame
 
 
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index f48f328217..ba0c792cea 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -12322,7 +12322,17 @@ def resample(
 
         validate_resample_supported_by_snowflake(resample_kwargs)
 
-        frame = self._modin_frame
+        axis = resample_kwargs.get("axis", 0)
+        rule = resample_kwargs.get("rule")
+        on = resample_kwargs.get("on")
+
+        # Supplying 'on' to Resampler replaces the existing index of the DataFrame with the 'on' column
+        if on is not None:
+            if on not in self._modin_frame.data_column_pandas_labels:
+                raise KeyError(f"{on}")
+            frame = self.set_index(keys=[on])._modin_frame
+        else:
+            frame = self._modin_frame
 
         if resample_method in ("var", np.var) and any(
             isinstance(t, TimedeltaType)
@@ -12334,8 +12344,6 @@
             get_snowflake_quoted_identifier_for_resample_index_col(frame)
         )
 
-        rule = resample_kwargs.get("rule")
-
         slice_width, slice_unit = rule_to_snowflake_width_and_slice_unit(rule)
 
         min_max_index_column_quoted_identifier = (
@@ -12405,13 +12413,19 @@
             )
             return SnowflakeQueryCompiler(output_frame).set_index_names(index_name)
         elif resample_method in IMPLEMENTED_AGG_METHODS:
-            frame = perform_resample_binning_on_frame(frame, start_date, rule)
+            resampled_frame = perform_resample_binning_on_frame(
+                frame=frame,
+                datetime_index_col_identifier=snowflake_index_column_identifier,
+                start_date=start_date,
+                slice_width=slice_width,
+                slice_unit=slice_unit,
+            )
             if resample_method == "indices":
                 # Convert groupby_indices output of dict[Hashable, np.ndarray] to
                 # collections.defaultdict
-                result_dict = SnowflakeQueryCompiler(frame).groupby_indices(
-                    by=self._modin_frame.index_column_pandas_labels,
-                    axis=resample_kwargs.get("axis", 0),
+                result_dict = SnowflakeQueryCompiler(resampled_frame).groupby_indices(
+                    by=frame.index_column_pandas_labels,
+                    axis=axis,
                     groupby_kwargs=dict(),
                     values_as_np_array=False,
                 )
@@ -12420,33 +12434,34 @@
                 # Call groupby_size directly on the dataframe or series with the index reset
                 # to ensure we perform count aggregation on row positions which cannot be null
                 qc = (
-                    SnowflakeQueryCompiler(frame)
+                    SnowflakeQueryCompiler(resampled_frame)
                     .reset_index()
                     .groupby_size(
-                        by="index",
-                        axis=resample_kwargs.get("axis", 0),
+                        by=on if on is not None else "index",
+                        axis=axis,
                         groupby_kwargs=dict(),
                         agg_args=resample_method_args,
                         agg_kwargs=resample_method_kwargs,
                     )
-                    .set_index_names([None])
+                    .set_index_names(frame.index_column_pandas_labels)
                 )
             elif resample_method in ("first", "last"):
                 # Call groupby_first or groupby_last directly
                 qc = getattr(
-                    SnowflakeQueryCompiler(frame), f"groupby_{resample_method}"
+                    SnowflakeQueryCompiler(resampled_frame),
+                    f"groupby_{resample_method}",
                 )(
-                    by=self._modin_frame.index_column_pandas_labels,
-                    axis=resample_kwargs.get("axis", 0),
+                    by=frame.index_column_pandas_labels,
+                    axis=axis,
                     groupby_kwargs=dict(),
                     agg_args=resample_method_args,
                     agg_kwargs=resample_method_kwargs,
                 )
             else:
-                qc = SnowflakeQueryCompiler(frame).groupby_agg(
-                    by=self._modin_frame.index_column_pandas_labels,
+                qc = SnowflakeQueryCompiler(resampled_frame).groupby_agg(
+                    by=frame.index_column_pandas_labels,
                     agg_func=resample_method,
-                    axis=resample_kwargs.get("axis", 0),
+                    axis=axis,
                     groupby_kwargs=dict(),
                     agg_args=resample_method_args,
                     agg_kwargs=resample_method_kwargs,
@@ -12454,7 +12469,7 @@
                     is_series_groupby=is_series,
                 )
 
-            frame = fill_missing_resample_bins_for_frame(
+            resampled_frame_all_bins = fill_missing_resample_bins_for_frame(
                 qc._modin_frame, rule, start_date, end_date
             )
             if resample_method in ("sum", "count", "size", "nunique"):
@@ -12463,7 +12478,9 @@
                     # For sum(), we need to fill NaN values as Timedelta(0)
                     # for timedelta columns and as 0 for other columns.
                     values_arg = {}
-                    for pandas_label in frame.data_column_pandas_labels:
+                    for (
+                        pandas_label
+                    ) in resampled_frame_all_bins.data_column_pandas_labels:
                         label_dtypes: native_pd.Series = self.dtypes[[pandas_label]]
                         # query compiler's fillna() takes a dictionary mapping
                         # pandas labels to values. When we have two columns
@@ -12485,7 +12502,7 @@
                         values_arg = list(values_arg.values())[0]
                 else:
                     values_arg = 0
-                return SnowflakeQueryCompiler(frame).fillna(
+                return SnowflakeQueryCompiler(resampled_frame_all_bins).fillna(
                     value=values_arg, self_is_series=is_series
                 )
             else:
@@ -12493,7 +12510,7 @@
                     f"Resample Method {resample_method} has not been implemented."
                 )
 
-        return SnowflakeQueryCompiler(frame)
+        return SnowflakeQueryCompiler(resampled_frame_all_bins)
 
     def value_counts_index(
         self,
diff --git a/tests/integ/modin/resample/test_resample_on.py b/tests/integ/modin/resample/test_resample_on.py
new file mode 100644
index 0000000000..40ee75d8a9
--- /dev/null
+++ b/tests/integ/modin/resample/test_resample_on.py
@@ -0,0 +1,96 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+import modin.pandas as pd
+import numpy as np
+import pandas as native_pd
+import pytest
+
+import snowflake.snowpark.modin.plugin  # noqa: F401
+from snowflake.snowpark.modin.plugin._internal.resample_utils import (
+    IMPLEMENTED_AGG_METHODS,
+    IMPLEMENTED_DATEOFFSET_STRINGS,
+)
+from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
+from tests.integ.utils.sql_counter import sql_count_checker
+
+agg_func = pytest.mark.parametrize(
+    "agg_func", list(filter(lambda x: x not in ["indices"], IMPLEMENTED_AGG_METHODS))
+)
+freq = pytest.mark.parametrize("freq", IMPLEMENTED_DATEOFFSET_STRINGS)
+
+
+@freq
+@agg_func
+# One extra query to convert index to native pandas for dataframe constructor
+@sql_count_checker(query_count=3, join_count=1)
+def test_resample_on(freq, agg_func):
+    rule = f"2{freq}"
+    # Note that supplying 'on' to Resampler replaces the existing index of the DataFrame with the 'on' column
+    eval_snowpark_pandas_result(
+        *create_test_dfs(
+            {
+                "A": np.random.randn(15),
+                "B": native_pd.date_range("2020-01-01", periods=15, freq=f"1{freq}"),
+            },
+            index=native_pd.date_range("2020-10-01", periods=15, freq=f"1{freq}"),
+        ),
+        lambda df: getattr(df.resample(rule=rule, on="B", closed="left"), agg_func)(),
+        check_freq=False,
+    )
+
+
+# One extra query to convert index to native pandas for dataframe constructor
+@sql_count_checker(query_count=3, join_count=1)
+def test_resample_hashable_on():
+    eval_snowpark_pandas_result(
+        *create_test_dfs(
+            {
+                "A": np.random.randn(15),
+                1: native_pd.date_range("2020-01-01", periods=15, freq="1s"),
+            },
+            index=native_pd.date_range("2020-10-01", periods=15, freq="1s"),
+        ),
+        lambda df: df.resample(rule="2s", on=1, closed="left").min(),
+        check_freq=False,
+    )
+
+
+@sql_count_checker(query_count=0)
+def test_resample_non_datetime_on():
+    native_df = native_pd.DataFrame(
+        data={
+            "A": np.random.randn(15),
+            "B": native_pd.date_range("2020-01-01", periods=15, freq="1s"),
+        },
+        index=native_pd.date_range("2020-10-01", periods=15, freq="1s"),
+    )
+    snow_df = pd.DataFrame(native_df)
+    with pytest.raises(
+        TypeError,
+        match="Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, but got an instance of 'Index'",
+    ):
+        native_df.resample(rule="2s", on="A").min()
+    with pytest.raises(
+        TypeError, match="Only valid with DatetimeIndex or TimedeltaIndex"
+    ):
+        snow_df.resample(rule="2s", on="A").min().to_pandas()
+
+
+@sql_count_checker(query_count=1)
+# One query to get the Modin frame data column pandas labels
+def test_resample_invalid_on():
+    eval_snowpark_pandas_result(
+        *create_test_dfs(
+            {
+                "A": np.random.randn(15),
+                "B": native_pd.date_range("2020-01-01", periods=15, freq="1s"),
+            },
+            index=native_pd.date_range("2020-10-01", periods=15, freq="1s"),
+        ),
+        lambda df: df.resample(rule="2s", on="invalid", closed="left").min(),
+        expect_exception=True,
+        expect_exception_type=KeyError,
+        expect_exception_match="invalid",
+    )
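
Usage note (not part of the patch): the sketch below illustrates the behavior this change enables, i.e. resampling a Snowpark pandas DataFrame on a datetime column rather than its index. The column names and data are made up, and running it assumes an active Snowflake session for Snowpark pandas; the 'D' frequency and the sum aggregation are used because both are listed as supported in the docs tables above.

```python
# Hypothetical example of the new `on` parameter for Resampler.
# Requires an active Snowflake session; names/data are illustrative only.
import modin.pandas as pd
import numpy as np
import pandas as native_pd

import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame(
    {
        "sales": np.arange(6),
        # 'ts' is the resampling key; passing on="ts" replaces the existing index.
        "ts": native_pd.date_range("2024-01-01", periods=6, freq="1D"),
    }
)

# Group rows into 2-day bins keyed on the 'ts' column and sum each bin.
result = df.resample(rule="2D", on="ts").sum()
print(result)
```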