Skip to content

Commit

Permalink
SNOW-1747461: Support on parameter with Resampler (#2466)
Browse files Browse the repository at this point in the history
SNOW-1747461

[x] I acknowledge that I have ensured my changes to be thread-safe

This PR adds support for the `on` parameter with `Resampler` and does
some small cleanup removing repeated calls to utility functions for
Resample methods.

---------

Signed-off-by: Naren Krishna <naren.krishna@snowflake.com>
  • Loading branch information
sfc-gh-nkrishna authored Oct 17, 2024
1 parent 50a9dcf commit 47cadd2
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Added support for tracking usages of `__array_ufunc__`.
- Added numpy compatibility support for `np.float_power`, `np.mod`, `np.remainder`, `np.greater`, `np.greater_equal`, `np.less`, `np.less_equal`, `np.not_equal`, and `np.equal`.
- Added support for `DataFrameGroupBy.bfill`, `SeriesGroupBy.bfill`, `DataFrameGroupBy.ffill`, and `SeriesGroupBy.ffill`.
- Added support for `on` parameter with `Resampler`.

#### Improvements

Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/dataframe_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ Methods
| | | ``limit`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``resample`` | P | ``axis``, ``label``, | Only DatetimeIndex is supported and its ``freq`` |
| | | ``convention``, ``kind``, ``on`` | will be lost. ``rule`` frequencies 's', 'min', |
| | | ``convention``, ``kind``, | will be lost. ``rule`` frequencies 's', 'min', |
| | | , ``level``, ``origin``, | 'h', and 'D' are supported. ``rule`` frequencies |
| | | , ``offset``, ``group_keys`` | 'W', 'ME', and 'YE' are supported with |
| | | | `closed = "left"` |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ Methods
| ``replace`` | P | ``method``, ``limit`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``resample`` | P | ``axis``, ``label``, | Only DatetimeIndex is supported and its ``freq`` |
| | | ``convention``, ``kind``, ``on`` | will be lost. ``rule`` frequencies 's', 'min', |
| | | ``convention``, ``kind``, | will be lost. ``rule`` frequencies 's', 'min', |
| | | , ``level``, ``origin``, | 'h', and 'D' are supported. ``rule`` frequencies |
| | | , ``offset``, ``group_keys`` | 'W', 'ME', and 'YE' are supported with |
| | | | `closed = "left"` |
Expand Down
38 changes: 22 additions & 16 deletions src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,6 @@ def validate_resample_supported_by_snowflake(
if kind is not None: # pragma: no cover
_argument_not_implemented("kind", kind)

on = resample_kwargs.get("on")
if on is not None: # pragma: no cover
_argument_not_implemented("on", on)

level = resample_kwargs.get("level")
if level is not None: # pragma: no cover
_argument_not_implemented("level", level)
Expand Down Expand Up @@ -304,7 +300,7 @@ def get_snowflake_quoted_identifier_for_resample_index_col(frame: InternalFrame)
sf_type = frame.get_snowflake_type(index_col)

if not isinstance(sf_type, (TimestampType, DateType)):
raise TypeError("Only valid with DatetimeIndex.")
raise TypeError("Only valid with DatetimeIndex or TimedeltaIndex")

return index_col

Expand Down Expand Up @@ -347,7 +343,11 @@ def time_slice(


def perform_resample_binning_on_frame(
frame: InternalFrame, start_date: str, bin_size: str
frame: InternalFrame,
datetime_index_col_identifier: str,
start_date: str,
slice_width: int,
slice_unit: str,
) -> InternalFrame:
"""
Returns a new dataframe where each item of the index column
Expand All @@ -359,21 +359,25 @@ def perform_resample_binning_on_frame(
The internal frame with a single DatetimeIndex column
to perform resample binning on.
datetime_index_col_identifier : str
The datetime-like column snowflake quoted identifier to use for resampling.
start_date : str
The earliest date in the Datetime index column of
`frame`.
bin_size : str
The offset string or object representing target conversion.
slice_width : int
Width of the slice (i.e. how many units of time are contained in the slice).
slice_unit : str
Time unit for the slice length.
Returns
-------
frame : InternalFrame
A new internal frame where items in the index column are
placed in a bin based on `bin_length` and `bin_unit`
placed in a bin based on `slice_width` and `slice_unit`
"""

slice_width, slice_unit = rule_to_snowflake_width_and_slice_unit(bin_size)
# Consider the following example:
# frame:
# data_col
Expand All @@ -387,9 +391,7 @@ def perform_resample_binning_on_frame(
# 2023-08-15 7
# 2023-08-16 8
# 2023-08-17 9
# start_date = 2023-08-07, bin_size = 3D (3 days)

datetime_index_col = get_snowflake_quoted_identifier_for_resample_index_col(frame)
# start_date = 2023-08-07, rule = 3D (3 days)

# Time slices in Snowflake are aligned to snowflake_timeslice_alignment_date,
# so we must normalize input datetimes.
Expand All @@ -399,7 +401,11 @@ def perform_resample_binning_on_frame(

# Subtract the normalization amount in seconds from the input datetime.
normalized_dates = to_timestamp_ntz(
datediff("second", to_timestamp_ntz(lit(normalization_amt)), datetime_index_col)
datediff(
"second",
to_timestamp_ntz(lit(normalization_amt)),
datetime_index_col_identifier,
)
)
# frame:
# data_col
Expand Down Expand Up @@ -460,7 +466,7 @@ def perform_resample_binning_on_frame(
# 2023-08-16 9

return frame.update_snowflake_quoted_identifiers_with_expressions(
{datetime_index_col: unnormalized_dates_set_to_bins}
{datetime_index_col_identifier: unnormalized_dates_set_to_bins}
).frame


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12322,7 +12322,17 @@ def resample(

validate_resample_supported_by_snowflake(resample_kwargs)

frame = self._modin_frame
axis = resample_kwargs.get("axis", 0)
rule = resample_kwargs.get("rule")
on = resample_kwargs.get("on")

# Supplying 'on' to Resampler replaces the existing index of the DataFrame with the 'on' column
if on is not None:
if on not in self._modin_frame.data_column_pandas_labels:
raise KeyError(f"{on}")
frame = self.set_index(keys=[on])._modin_frame
else:
frame = self._modin_frame

if resample_method in ("var", np.var) and any(
isinstance(t, TimedeltaType)
Expand All @@ -12334,8 +12344,6 @@ def resample(
get_snowflake_quoted_identifier_for_resample_index_col(frame)
)

rule = resample_kwargs.get("rule")

slice_width, slice_unit = rule_to_snowflake_width_and_slice_unit(rule)

min_max_index_column_quoted_identifier = (
Expand Down Expand Up @@ -12405,13 +12413,19 @@ def resample(
)
return SnowflakeQueryCompiler(output_frame).set_index_names(index_name)
elif resample_method in IMPLEMENTED_AGG_METHODS:
frame = perform_resample_binning_on_frame(frame, start_date, rule)
resampled_frame = perform_resample_binning_on_frame(
frame=frame,
datetime_index_col_identifier=snowflake_index_column_identifier,
start_date=start_date,
slice_width=slice_width,
slice_unit=slice_unit,
)
if resample_method == "indices":
# Convert groupby_indices output of dict[Hashable, np.ndarray] to
# collections.defaultdict
result_dict = SnowflakeQueryCompiler(frame).groupby_indices(
by=self._modin_frame.index_column_pandas_labels,
axis=resample_kwargs.get("axis", 0),
result_dict = SnowflakeQueryCompiler(resampled_frame).groupby_indices(
by=frame.index_column_pandas_labels,
axis=axis,
groupby_kwargs=dict(),
values_as_np_array=False,
)
Expand All @@ -12420,41 +12434,42 @@ def resample(
# Call groupby_size directly on the dataframe or series with the index reset
# to ensure we perform count aggregation on row positions which cannot be null
qc = (
SnowflakeQueryCompiler(frame)
SnowflakeQueryCompiler(resampled_frame)
.reset_index()
.groupby_size(
by="index",
axis=resample_kwargs.get("axis", 0),
by=on if on is not None else "index",
axis=axis,
groupby_kwargs=dict(),
agg_args=resample_method_args,
agg_kwargs=resample_method_kwargs,
)
.set_index_names([None])
.set_index_names(frame.index_column_pandas_labels)
)
elif resample_method in ("first", "last"):
# Call groupby_first or groupby_last directly
qc = getattr(
SnowflakeQueryCompiler(frame), f"groupby_{resample_method}"
SnowflakeQueryCompiler(resampled_frame),
f"groupby_{resample_method}",
)(
by=self._modin_frame.index_column_pandas_labels,
axis=resample_kwargs.get("axis", 0),
by=frame.index_column_pandas_labels,
axis=axis,
groupby_kwargs=dict(),
agg_args=resample_method_args,
agg_kwargs=resample_method_kwargs,
)
else:
qc = SnowflakeQueryCompiler(frame).groupby_agg(
by=self._modin_frame.index_column_pandas_labels,
qc = SnowflakeQueryCompiler(resampled_frame).groupby_agg(
by=frame.index_column_pandas_labels,
agg_func=resample_method,
axis=resample_kwargs.get("axis", 0),
axis=axis,
groupby_kwargs=dict(),
agg_args=resample_method_args,
agg_kwargs=resample_method_kwargs,
numeric_only=resample_method_kwargs.get("numeric_only", False),
is_series_groupby=is_series,
)

frame = fill_missing_resample_bins_for_frame(
resampled_frame_all_bins = fill_missing_resample_bins_for_frame(
qc._modin_frame, rule, start_date, end_date
)
if resample_method in ("sum", "count", "size", "nunique"):
Expand All @@ -12463,7 +12478,9 @@ def resample(
# For sum(), we need to fill NaN values as Timedelta(0)
# for timedelta columns and as 0 for other columns.
values_arg = {}
for pandas_label in frame.data_column_pandas_labels:
for (
pandas_label
) in resampled_frame_all_bins.data_column_pandas_labels:
label_dtypes: native_pd.Series = self.dtypes[[pandas_label]]
# query compiler's fillna() takes a dictionary mapping
# pandas labels to values. When we have two columns
Expand All @@ -12485,15 +12502,15 @@ def resample(
values_arg = list(values_arg.values())[0]
else:
values_arg = 0
return SnowflakeQueryCompiler(frame).fillna(
return SnowflakeQueryCompiler(resampled_frame_all_bins).fillna(
value=values_arg, self_is_series=is_series
)
else:
ErrorMessage.not_implemented(
f"Resample Method {resample_method} has not been implemented."
)

return SnowflakeQueryCompiler(frame)
return SnowflakeQueryCompiler(resampled_frame_all_bins)

def value_counts_index(
self,
Expand Down
96 changes: 96 additions & 0 deletions tests/integ/modin/resample/test_resample_on.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark.modin.plugin._internal.resample_utils import (
IMPLEMENTED_AGG_METHODS,
IMPLEMENTED_DATEOFFSET_STRINGS,
)
from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
from tests.integ.utils.sql_counter import sql_count_checker

# Parametrize over every implemented resample aggregation except "indices",
# whose dict-shaped output is exercised by its own dedicated tests.
agg_func = pytest.mark.parametrize(
    "agg_func",
    [method for method in IMPLEMENTED_AGG_METHODS if method not in ["indices"]],
)
# Parametrize over every date-offset frequency the Snowpark backend implements.
freq = pytest.mark.parametrize("freq", IMPLEMENTED_DATEOFFSET_STRINGS)


@freq
@agg_func
# One extra query to convert index to native pandas for dataframe constructor
@sql_count_checker(query_count=3, join_count=1)
def test_resample_on(freq, agg_func):
    """Resampling with ``on`` set to a datetime data column matches native pandas
    for every implemented aggregation and frequency."""
    rule = f"2{freq}"
    data = {
        "A": np.random.randn(15),
        "B": native_pd.date_range("2020-01-01", periods=15, freq=f"1{freq}"),
    }
    index = native_pd.date_range("2020-10-01", periods=15, freq=f"1{freq}")

    def do_resample(df):
        # Note that supplying 'on' to Resampler replaces the existing index of the DataFrame with the 'on' column
        resampler = df.resample(rule=rule, on="B", closed="left")
        return getattr(resampler, agg_func)()

    dfs = create_test_dfs(data, index=index)
    eval_snowpark_pandas_result(*dfs, do_resample, check_freq=False)


# One extra query to convert index to native pandas for dataframe constructor
@sql_count_checker(query_count=3, join_count=1)
def test_resample_hashable_on():
    """``on`` accepts any hashable column label, not just strings — here the
    integer label 1 names the datetime column to resample on."""
    data = {
        "A": np.random.randn(15),
        1: native_pd.date_range("2020-01-01", periods=15, freq="1s"),
    }
    index = native_pd.date_range("2020-10-01", periods=15, freq="1s")

    def do_resample(df):
        return df.resample(rule="2s", on=1, closed="left").min()

    dfs = create_test_dfs(data, index=index)
    eval_snowpark_pandas_result(*dfs, do_resample, check_freq=False)


@sql_count_checker(query_count=0)
def test_resample_non_datetime_on():
    """Resampling on a non-datetime column raises TypeError in both native
    pandas and Snowpark pandas (the two raise slightly different messages)."""
    data = {
        "A": np.random.randn(15),
        "B": native_pd.date_range("2020-01-01", periods=15, freq="1s"),
    }
    index = native_pd.date_range("2020-10-01", periods=15, freq="1s")
    native_df = native_pd.DataFrame(data=data, index=index)
    snow_df = pd.DataFrame(native_df)

    # Native pandas rejects the numeric column "A" with its own message...
    with pytest.raises(
        TypeError,
        match="Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, but got an instance of 'Index'",
    ):
        native_df.resample(rule="2s", on="A").min()
    # ...and Snowpark pandas rejects it with its backend-specific message.
    with pytest.raises(
        TypeError, match="Only valid with DatetimeIndex or TimedeltaIndex"
    ):
        snow_df.resample(rule="2s", on="A").min().to_pandas()


@sql_count_checker(query_count=1)
# One query to get the Modin frame data column pandas labels
def test_resample_invalid_on():
    """Passing an ``on`` label absent from the columns raises KeyError, matching
    native pandas."""
    data = {
        "A": np.random.randn(15),
        "B": native_pd.date_range("2020-01-01", periods=15, freq="1s"),
    }
    index = native_pd.date_range("2020-10-01", periods=15, freq="1s")

    def do_resample(df):
        return df.resample(rule="2s", on="invalid", closed="left").min()

    dfs = create_test_dfs(data, index=index)
    eval_snowpark_pandas_result(
        *dfs,
        do_resample,
        expect_exception=True,
        expect_exception_type=KeyError,
        expect_exception_match="invalid",
    )

0 comments on commit 47cadd2

Please sign in to comment.