Skip to content

Commit

Permalink
SNOW-1727473: Test and fix timedelta handling in tests/integ/modin/fr…
Browse files Browse the repository at this point in the history
…ame: part 2/n. (#2435)

Fixes SNOW-1727473

Test and/or fix timedelta handling for the methods tested in
tests/integ/modin/frame/, alphabetically from test_pct_change.py through
test_set_index.py.

Make the following behavior changes:
- Fixed a bug where `DataFrame` and `Series` `pct_change()` would raise
`TypeError` when input contained timedelta columns.
- Fixed a bug where `replace()` would sometimes propagate `Timedelta`
types incorrectly through `replace()`. Instead raise
`NotImplementedError` for `replace()` on `Timedelta`.
- Fixed a bug where `DataFrame` and `Series` `round()` would raise
`AssertionError` for `Timedelta` columns. Instead raise
`NotImplementedError` for `round()` on `Timedelta`.

Apart from those changes, just test that we can handle timedelta inputs
correctly.

---------

Signed-off-by: sfc-gh-mvashishtha <mahesh.vashishtha@snowflake.com>
Co-authored-by: Naren Krishna <naren.krishna@snowflake.com>
  • Loading branch information
sfc-gh-mvashishtha and sfc-gh-nkrishna authored Oct 14, 2024
1 parent 9d60271 commit fa7c406
Show file tree
Hide file tree
Showing 16 changed files with 381 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
- Fixed a bug where `Resampler` methods on timedelta columns would produce integer results.
- Fixed a bug where `pd.to_numeric()` would leave `Timedelta` inputs as `Timedelta` instead of converting them to integers.
- Fixed `loc` set when setting a single row, or multiple rows, of a DataFrame with a Series value.
- Fixed a bug where `DataFrame` and `Series` `pct_change()` would raise `TypeError` when input contained timedelta columns.
- Fixed a bug where `replace()` would sometimes propagate `Timedelta` types incorrectly. `replace()` now raises `NotImplementedError` when `Timedelta` data is involved.
- Fixed a bug where `DataFrame` and `Series` `round()` would raise `AssertionError` for `Timedelta` columns. Instead raise `NotImplementedError` for `round()` on `Timedelta`.

### Snowpark Local Testing Updates

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14512,6 +14512,11 @@ def replace(
Returns:
SnowflakeQueryCompiler with all `to_replace` values replaced by `value`.
"""
# Propagating client-side types through replace() is complicated.
# Timedelta columns may change types after replace(), and non-timedelta
# columns may contain timedelta values after replace().
self._raise_not_implemented_error_for_timedelta()

if method is not lib.no_default:
ErrorMessage.not_implemented(
"Snowpark pandas replace API does not support 'method' parameter"
Expand Down Expand Up @@ -14621,6 +14626,23 @@ def replace(
elif value != lib.no_default:
raise TypeError(f"Unsupported value type: {type(value)}")

def _scalar_belongs_to_timedelta_classes(s: Any) -> bool:
    """Return True if the type of `s` subclasses any timedelta class.

    The candidate classes are those listed in
    TimedeltaType.types_to_convert_with_from_pandas.
    """
    timedelta_classes = tuple(TimedeltaType.types_to_convert_with_from_pandas)
    # issubclass() accepts a tuple and succeeds if any member matches.
    return issubclass(type(s), timedelta_classes)

# Raise if the new values in `value` include timedelta.
if any(
(
isinstance(v, list)
and any(_scalar_belongs_to_timedelta_classes(vv) for vv in v)
)
or _scalar_belongs_to_timedelta_classes(v)
for v in value_map.values()
):
ErrorMessage.not_implemented_for_timedelta("replace")

replaced_column_exprs = {}
for identifier, to_replace in replace_map.items():
if identifier not in value_map:
Expand Down Expand Up @@ -15111,7 +15133,6 @@ def create_lazy_type_functions(

return SnowflakeQueryCompiler(new_frame)

@snowpark_pandas_type_immutable_check
def round(
self, decimals: Union[int, Mapping, "pd.Series"] = 0, **kwargs: Any
) -> "SnowflakeQueryCompiler":
Expand All @@ -15130,6 +15151,14 @@ def round(
BaseQueryCompiler
QueryCompiler with rounded values.
"""
# DataFrame.round() and Series.round() ignore non-numeric columns like
# timedelta. We raise a Snowflake error for non-numeric, non-timedelta
# columns like strings, but we have to detect timedelta separately
# because its underlying representation is an integer. Without this
# check, we'd round the integer representation of the timedelta instead
# of leaving the timedelta unchanged.
self._raise_not_implemented_error_for_timedelta()

if isinstance(decimals, pd.Series):
raise ErrorMessage.not_implemented(
"round with decimals of type Series is not yet supported"
Expand Down
21 changes: 17 additions & 4 deletions src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
)
from snowflake.snowpark.modin.utils import validate_int_kwarg

# Message of the TypeError raised by pct_change(axis=1) when some, but not all,
# columns are Timedelta. Kept as a module-level constant so that tests can
# import it and match the exact text.
_TIMEDELTA_PCT_CHANGE_AXIS_1_MIXED_TYPE_ERROR_MESSAGE = (
    "pct_change(axis=1) is invalid when one column is Timedelta and another "
    "column is not."
)


def register_base_override(method_name: str):
"""
Expand Down Expand Up @@ -1889,16 +1893,25 @@ def pct_change(
if limit is lib.no_default:
limit = None

if "axis" in kwargs:
kwargs["axis"] = self._get_axis_number(kwargs["axis"])
kwargs["axis"] = self._get_axis_number(kwargs.get("axis", 0))

# Attempting to match pandas error behavior here
if not isinstance(periods, int):
raise TypeError(f"periods must be an int. got {type(periods)} instead")

column_is_timedelta_type = [
self._query_compiler.is_timedelta64_dtype(i, is_index=False)
for i in range(len(self._query_compiler.columns))
]

if kwargs["axis"] == 1:
if any(column_is_timedelta_type) and not all(column_is_timedelta_type):
# pct_change() between timedelta and a non-timedelta type is invalid.
raise TypeError(_TIMEDELTA_PCT_CHANGE_AXIS_1_MIXED_TYPE_ERROR_MESSAGE)

# Attempting to match pandas error behavior here
for dtype in self._get_dtypes():
if not is_numeric_dtype(dtype):
for i, dtype in enumerate(self._get_dtypes()):
if not is_numeric_dtype(dtype) and not column_is_timedelta_type[i]:
raise TypeError(
f"cannot perform pct_change on non-numeric column with dtype {dtype}"
)
Expand Down
73 changes: 73 additions & 0 deletions tests/integ/modin/frame/test_pct_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
#

import contextlib
import re

import modin.pandas as pd
import pandas as native_pd
import pytest
from pandas._libs.lib import no_default

import snowflake.snowpark.modin.plugin # noqa: F401
from snowflake.snowpark.modin.plugin.extensions.base_overrides import (
_TIMEDELTA_PCT_CHANGE_AXIS_1_MIXED_TYPE_ERROR_MESSAGE,
)
from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
from tests.integ.utils.sql_counter import sql_count_checker

Expand Down Expand Up @@ -71,6 +75,75 @@ def test_pct_change_simple(native_data, periods, fill_method, axis):
)


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [-1, 0, 1, 2])
@pytest.mark.parametrize(
    "fill_method", [None, no_default, "ffill", "pad", "backfill", "bfill"]
)
def test_axis_0_with_timedelta(periods, fill_method):
    """Compare pct_change() (no axis passed) on a timedelta + float frame.

    Each column contains one missing value so the fill methods have something
    to act on.
    """
    timedelta_values = [
        pd.Timedelta(100),
        pd.Timedelta(25),
        None,
        pd.Timedelta(75),
    ]
    float_values = [4.0405, 4.0963, 4.3149, None]
    frames = create_test_dfs({"timedelta": timedelta_values, "float": float_values})

    def pct_change(df):
        return df.pct_change(periods=periods, fill_method=fill_method)

    eval_snowpark_pandas_result(*frames, pct_change)


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("periods", [-1, 0, 1, 2])
@pytest.mark.parametrize(
    "fill_method", [None, no_default, "ffill", "pad", "backfill", "bfill"]
)
def test_axis_1_with_timedelta_columns(periods, fill_method):
    """Compare pct_change(axis=1) on a frame whose columns all hold timedeltas."""
    columns = {
        "col0": [pd.Timedelta(100), pd.Timedelta(25)],
        "col1": [pd.Timedelta(50), None],
        "col3": [pd.Timedelta(75), pd.Timedelta(75)],
        "col4": [None, pd.Timedelta(150)],
    }
    frames = create_test_dfs(columns)

    def pct_change_across_columns(df):
        return df.pct_change(axis=1, periods=periods, fill_method=fill_method)

    eval_snowpark_pandas_result(*frames, pct_change_across_columns)


@sql_count_checker(query_count=0)
@pytest.mark.parametrize(
    "data",
    [
        {"timedelta": [pd.Timedelta(1)], "string": ["value"]},
        {"timedelta": [pd.Timedelta(1)], "int": [0]},
    ],
)
def test_axis_1_with_timedelta_and_non_timedelta_column_invalid(data):
    """pct_change(axis=1) must raise when timedelta and other columns are mixed."""
    frames = create_test_dfs(data)
    expected_message = re.escape(_TIMEDELTA_PCT_CHANGE_AXIS_1_MIXED_TYPE_ERROR_MESSAGE)

    def pct_change_across_columns(df):
        return df.pct_change(axis=1)

    eval_snowpark_pandas_result(
        *frames,
        pct_change_across_columns,
        expect_exception=True,
        # pandas exception depends on the type of the non-timedelta column, so
        # we don't try to match the pandas exception.
        assert_exception_equal=False,
        expect_exception_match=expected_message,
    )


@pytest.mark.parametrize("params", [{"limit": 2}, {"freq": "ME"}])
@sql_count_checker(query_count=0)
def test_pct_change_unsupported_args(params):
Expand Down
6 changes: 6 additions & 0 deletions tests/integ/modin/frame/test_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
"b": [-5, -2, -1, 0],
"c": [89, np.nan, -540, 0.1],
"d": [0, 0, 0, 0],
"timedelta": [
pd.NaT,
pd.Timedelta(10),
pd.Timedelta(-5),
pd.Timedelta(7),
],
}

TEST_QUANTILES = [
Expand Down
6 changes: 6 additions & 0 deletions tests/integ/modin/frame/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as native_pd
import pytest
from pytest import param

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import (
Expand All @@ -15,6 +16,11 @@

TEST_RANK_DATA = [
({"a": [1, 2, 2, 2, 3, 3, 3]}, None),
param(
{"timedelta": native_pd.to_timedelta([1, 2, 2, 2, 3, 3, 3])},
None,
id="timedelta",
),
(
{
"a": [4, -2, 4, 8, 3],
Expand Down
24 changes: 24 additions & 0 deletions tests/integ/modin/frame/test_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,30 @@ def test_reindex_multiindex_negative(axis):
snow_df.T.reindex(columns=[1, 2, 3])


@sql_count_checker(query_count=0)
@pytest.mark.xfail(strict=True, raises=NotImplementedError)
def test_reindex_timedelta_axis_0_negative():
    """Reindex the rows of an all-timedelta frame.

    Marked as a strict xfail that must raise NotImplementedError.
    """
    integer_frame = native_pd.DataFrame(
        np.arange(9).reshape((3, 3)), index=list("ABC")
    )
    native_df = integer_frame.astype("timedelta64[ns]")
    snow_df = pd.DataFrame(native_df)

    def reorder_rows(df):
        return df.reindex(axis=0, labels=list("CAB"))

    eval_snowpark_pandas_result(snow_df, native_df, reorder_rows)


@sql_count_checker(query_count=0)
@pytest.mark.xfail(strict=True, raises=NotImplementedError)
def test_reindex_timedelta_axis_1_negative():
    """Reindex the columns of an all-timedelta frame.

    Marked as a strict xfail that must raise NotImplementedError.
    """
    integer_frame = native_pd.DataFrame(
        np.arange(9).reshape((3, 3)), columns=list("ABC")
    )
    native_df = integer_frame.astype("timedelta64[ns]")
    snow_df = pd.DataFrame(native_df)

    def reorder_columns(df):
        return df.reindex(axis=1, labels=list("CAB"))

    eval_snowpark_pandas_result(snow_df, native_df, reorder_columns)


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_index_name():
native_df = native_pd.DataFrame(
Expand Down
9 changes: 9 additions & 0 deletions tests/integ/modin/frame/test_rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tests.integ.modin.utils import (
assert_frame_equal,
assert_index_equal,
create_test_dfs,
eval_snowpark_pandas_result,
)
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
Expand Down Expand Up @@ -490,3 +491,11 @@ def test_rename_copy_warning(self, float_frame, caplog):

snow_float_frame.rename(columns={"C": "foo"}, copy=True)
assert msg in caplog.text

@pytest.mark.parametrize("axis, join_count", [(0, 1), (1, 0)])
def test_rename_timedelta_values(self, axis, join_count):
    """Rename label 0 to "a" along `axis` on a single-timedelta frame.

    The expected join count is parametrized per axis; one query is expected
    in both cases.
    """

    def rename_label(df):
        return df.rename(mapper={0: "a"}, axis=axis)

    with SqlCounter(query_count=1, join_count=join_count):
        # Frames are created inside the counter, as the counts above assume.
        frames = create_test_dfs([pd.Timedelta(1)])
        eval_snowpark_pandas_result(*frames, rename_label)
57 changes: 56 additions & 1 deletion tests/integ/modin/frame/test_replace.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import datetime

import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest
from pandas._libs.lib import no_default
from pytest import param

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import eval_snowpark_pandas_result
from tests.integ.modin.utils import create_test_dfs, eval_snowpark_pandas_result
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker


Expand Down Expand Up @@ -144,6 +147,58 @@ def test_replace_limit_negative(snow_df):
snow_df.replace("abc", "ABC", limit=10)


@pytest.mark.parametrize(
    "pandas_df",
    [
        native_pd.DataFrame([pd.Timedelta(1)]),
        native_pd.DataFrame([1], index=[pd.Timedelta(2)]),
    ],
)
@pytest.mark.xfail(strict=True, raises=NotImplementedError)
def test_replace_frame_with_timedelta_index_or_column_negative(pandas_df):
    """replace() on a frame with a timedelta column or a timedelta index.

    Marked as a strict xfail that must raise NotImplementedError.
    """
    frames = create_test_dfs(pandas_df)

    def replace_one_with_three(df):
        return df.replace({1: 3})

    eval_snowpark_pandas_result(*frames, replace_one_with_three)


@pytest.mark.xfail(strict=True, raises=NotImplementedError)
@pytest.mark.parametrize(
    "kwargs",
    [
        param(
            {"to_replace": {1: native_pd.Timedelta(1)}},
            id="to_replace_dict_with_pandas_timedelta",
        ),
        param(
            {"to_replace": {1: np.timedelta64(1)}},
            id="to_replace_dict_with_numpy_timedelta",
        ),
        param(
            {"to_replace": {1: datetime.timedelta(days=1)}},
            id="to_replace_dict_with_datetime_timedelta",
        ),
        param(
            {"to_replace": 1, "value": native_pd.Timedelta(1)},
            id="value_timedelta_scalar",
        ),
        param(
            {"to_replace": [1, 2], "value": [native_pd.Timedelta(1), 3]},
            id="value_timedelta_list",
        ),
    ],
)
def test_replace_integer_value_with_timedelta_negative(kwargs):
    """replace() of an integer with a timedelta replacement value.

    Each parameter set supplies the timedelta through a different argument
    shape (dict, scalar, list). Marked as a strict xfail that must raise
    NotImplementedError.
    """
    frames = create_test_dfs([1])

    def replace_with_kwargs(df):
        return df.replace(**kwargs)

    eval_snowpark_pandas_result(*frames, replace_with_kwargs)


@sql_count_checker(query_count=0)
def test_replace_no_value_negative(snow_df):
# pandas will not raise error instead uses 'pad' method to replace values.
Expand Down
10 changes: 10 additions & 0 deletions tests/integ/modin/frame/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,20 @@
{
"Animal": ["Falcon", "Falcon", "Parrot", "Parrot"],
"Max Speed": [380.0, 370.0, 24.0, 26.0],
"Timedelta": [
pd.Timedelta(1, unit="ns"),
pd.Timedelta(microseconds=1),
pd.Timedelta(milliseconds=-1),
pd.Timedelta(days=9999, hours=10, minutes=30, seconds=10),
],
}
),
1,
),
(
native_pd.DataFrame([1, 2], index=[pd.Timedelta(1), pd.Timedelta(-1)]),
1,
),
(
IRIS_DF,
6,
Expand Down
Loading

0 comments on commit fa7c406

Please sign in to comment.