SNOW-1445416, SNOW-1445419: Implement DataFrame/Series.attrs (#2386)

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.


   Fixes SNOW-1445416 and SNOW-1445419

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe

3. Please describe how your code solves the related issue.

Implements `DataFrame`/`Series.attrs` by adding a new query compiler
variable `_attrs` that is read out by frontend objects. A new annotation
on the query compiler, `_propagate_attrs_on_methods`, will either copy
`_attrs` from `self` to the return value, or reset `_attrs` on the
return value.
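For context, `attrs` in native pandas is a free-form metadata dict attached to a frame or series, and pandas carries it through many operations that return a new object. A minimal illustration of the behavior being replicated (using native pandas; the `"source"` key is just an example):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2]})
df.attrs["source"] = "sensor-1"  # attach arbitrary metadata to the frame

# Methods that return a new object generally carry attrs along;
# copy() is a documented case of this propagation.
assert df.copy().attrs == {"source": "sensor-1"}
```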

I initially intended to implement this solely at the frontend layer with
the override system (similar to how telemetry is added to all methods),
but this created difficulties when preserving `attrs` across in-place
operations like `df.columns = [...]`, and could create ambiguity if the
frame had a column named `"_attrs"`. Implementing propagation at the
query compiler level is simpler.
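The propagation annotation can be sketched as a class decorator that wraps each public method so that, when the method returns another instance of the same class, the wrapper deep-copies `_attrs` from `self` onto the result. This is a simplified stand-in for the real decorator (which also handles properties and a reset list), with a hypothetical `Compiler` class for illustration:

```python
import copy
import functools
import inspect


def propagate_attrs_on_methods(cls):
    """Wrap public methods of ``cls`` to copy ``_attrs`` onto results of the same type."""

    def wrap(method):
        @functools.wraps(method)
        def inner(self, *args, **kwargs):
            result = method(self, *args, **kwargs)
            if isinstance(result, cls) and self._attrs:
                result._attrs = copy.deepcopy(self._attrs)
            return result

        return inner

    # Snapshot the dict items before mutating the class with setattr.
    for name, value in list(cls.__dict__.items()):
        if not name.startswith("_") and inspect.isfunction(value):
            setattr(cls, name, wrap(value))
    return cls


@propagate_attrs_on_methods
class Compiler:  # hypothetical stand-in for the query compiler
    def __init__(self):
        self._attrs = {}

    def transform(self):
        return Compiler()  # returns a fresh object with empty _attrs


c = Compiler()
c._attrs = {"k": "v"}
assert c.transform()._attrs == {"k": "v"}  # attrs copied by the wrapper
```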

This PR also adds a new `test_attrs=True` parameter to
`eval_snowpark_pandas_result`. `eval_snowpark_pandas_result` will set a
dummy value of `attrs` on its inputs, and ensure that if the result is a
DF/Series, the `attrs` field on the result matches that of pandas. Since
pandas isn't always consistent about whether it propagates attrs or
resets it (for some methods, the behavior depends on the input, and for
some methods, it is inconsistent between Series/DF), setting
`test_attrs=False` skips this check. When I encountered such
inconsistent methods, I elected to have Snowpark pandas always propagate
`attrs`, since it seems unlikely that users would rely on the `attrs` of
a result being empty if they did not explicitly set it.
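The shape of that check can be sketched as follows; this is not the real `eval_snowpark_pandas_result` signature, just a simplified helper that sets a dummy `attrs` on both inputs and compares the results (here both inputs are native pandas frames purely for illustration):

```python
import pandas as pd


def eval_with_attrs(make_obj, op, test_attrs=True):
    # Build two logically identical inputs (standing in for the Snowpark
    # pandas object and the native pandas object) and tag each with a
    # dummy attrs dict before applying the operation under test.
    lhs, rhs = make_obj(), make_obj()
    if test_attrs:
        lhs.attrs = {"test": "value"}
        rhs.attrs = {"test": "value"}
    lhs_res, rhs_res = op(lhs), op(rhs)
    # Only compare attrs when the result is a frame/series-like object.
    if test_attrs and hasattr(rhs_res, "attrs"):
        assert lhs_res.attrs == rhs_res.attrs
    return lhs_res, rhs_res


res, _ = eval_with_attrs(
    lambda: pd.DataFrame({"a": [1, 2]}),
    lambda df: df.copy(),
)
assert res.attrs == {"test": "value"}
```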
sfc-gh-joshi authored Oct 22, 2024
1 parent 84434f1 commit 0b56f4b
Showing 33 changed files with 420 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -37,6 +37,7 @@
- Added support for `on` parameter with `Resampler`.
- Added support for timedelta inputs in `value_counts()`.
- Added support for applying Snowpark Python function `snowflake_cortex_summarize`.
- Added support for `DataFrame`/`Series.attrs`.

#### Improvements

2 changes: 1 addition & 1 deletion docs/source/modin/supported/dataframe_supported.rst
@@ -19,7 +19,7 @@ Attributes
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``at`` | P | ``N`` for set with MultiIndex |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``attrs`` | N | |
| ``attrs`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``axes`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
4 changes: 1 addition & 3 deletions docs/source/modin/supported/series_supported.rst
@@ -21,9 +21,7 @@ Attributes
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``at`` | P | ``N`` for set with MultiIndex |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``attrs`` | N | Reading ``attrs`` always returns an empty dict, |
| | | and attempting to modify or set ``attrs`` will |
| | | fail. |
| ``attrs`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``axes`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
@@ -3,6 +3,7 @@
#
import calendar
import collections
import copy
import functools
import inspect
import itertools
@@ -14,7 +15,7 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta, tzinfo
from functools import reduce
from typing import Any, Callable, List, Literal, Optional, Union, get_args
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union, get_args

import modin.pandas as pd
import numpy as np
@@ -411,7 +412,95 @@
"ops for Rolling for this dtype timedelta64[ns] are not implemented"
)

# List of query compiler methods where attrs on the result should always be empty.
_RESET_ATTRS_METHODS = [
"compare",
"merge",
"value_counts",
"dataframe_to_datetime",
"series_to_datetime",
"to_numeric",
"dt_isocalendar",
"groupby_all",
"groupby_any",
"groupby_cumcount",
"groupby_cummax",
"groupby_cummin",
"groupby_cumsum",
"groupby_nunique",
"groupby_rank",
"groupby_size",
# expanding and rolling methods also do not propagate; we check them by prefix matching
# agg, crosstab, and concat depend on their inputs, and are handled separately
]


T = TypeVar("T", bound=Callable[..., Any])


def _propagate_attrs_on_methods(cls): # type: ignore
"""
Decorator that modifies all methods on the class to copy `_attrs` from `self`
to the output of the method, if the output is another query compiler.
"""

def propagate_attrs_decorator(method: T) -> T:
@functools.wraps(method)
def wrap(self, *args, **kwargs): # type: ignore
result = method(self, *args, **kwargs)
if isinstance(result, SnowflakeQueryCompiler) and len(self._attrs):
result._attrs = copy.deepcopy(self._attrs)
return result

return typing.cast(T, wrap)

def reset_attrs_decorator(method: T) -> T:
@functools.wraps(method)
def wrap(self, *args, **kwargs): # type: ignore
result = method(self, *args, **kwargs)
if isinstance(result, SnowflakeQueryCompiler) and len(self._attrs):
result._attrs = {}
return result

return typing.cast(T, wrap)

for attr_name, attr_value in cls.__dict__.items():
# concat is handled explicitly because it checks all of its arguments
# agg is handled explicitly because it sometimes resets and sometimes propagates
if attr_name.startswith("_") or attr_name in ["concat", "agg"]:
continue
if attr_name in _RESET_ATTRS_METHODS or any(
attr_name.startswith(prefix) for prefix in ["expanding", "rolling"]
):
setattr(cls, attr_name, reset_attrs_decorator(attr_value))
elif isinstance(attr_value, property):
setattr(
cls,
attr_name,
property(
propagate_attrs_decorator(
attr_value.fget
if attr_value.fget is not None
else attr_value.__get__
),
propagate_attrs_decorator(
attr_value.fset
if attr_value.fset is not None
else attr_value.__set__
),
propagate_attrs_decorator(
attr_value.fdel
if attr_value.fdel is not None
else attr_value.__delete__
),
),
)
elif inspect.isfunction(attr_value):
setattr(cls, attr_name, propagate_attrs_decorator(attr_value))
return cls


@_propagate_attrs_on_methods
class SnowflakeQueryCompiler(BaseQueryCompiler):
"""based on: https://modin.readthedocs.io/en/0.11.0/flow/modin/backends/base/query_compiler.html
this class is best explained by looking at https://github.com/modin-project/modin/blob/a8be482e644519f2823668210cec5cf1564deb7e/modin/experimental/core/storage_formats/hdk/query_compiler.py
@@ -429,6 +518,7 @@ def __init__(self, frame: InternalFrame) -> None:
# self.snowpark_pandas_api_calls a list of lazy Snowpark pandas telemetry api calls
# Copying and modifying self.snowpark_pandas_api_calls is taken care of in telemetry decorators
self.snowpark_pandas_api_calls: list = []
self._attrs: dict[Any, Any] = {}

def _raise_not_implemented_error_for_timedelta(
self, frame: InternalFrame = None
@@ -854,7 +944,10 @@ def to_pandas(
The QueryCompiler converted to pandas.

"""
return self._modin_frame.to_pandas(statement_params, **kwargs)
result = self._modin_frame.to_pandas(statement_params, **kwargs)
if self._attrs:
result.attrs = self._attrs
return result

def finalize(self) -> None:
pass
@@ -6065,6 +6158,7 @@ def agg(
)

query_compiler = self
initial_attrs = self._attrs
if numeric_only:
# drop off the non-numeric data columns if the data column if numeric_only is configured to be True
query_compiler = drop_non_numeric_data_columns(
@@ -6481,6 +6575,11 @@ def generate_single_agg_column_func_map(
result = result.transpose_single_row()
# Set the single column's name to MODIN_UNNAMED_SERIES_LABEL
result = result.set_columns([MODIN_UNNAMED_SERIES_LABEL])
# native pandas clears attrs if the aggregation was a list, but propagates it otherwise
if is_list_like(func):
result._attrs = {}
else:
result._attrs = copy.deepcopy(initial_attrs)
return result

def insert(
@@ -7336,6 +7435,10 @@ def concat(
raise ValueError(
f"Indexes have overlapping values. Few of them are: {overlap}. Please run df1.index.intersection(df2.index) to see complete list"
)
# If each input's `attrs` was identical and not empty, then copy it to the output.
# Otherwise, leave `attrs` empty.
if len(self._attrs) > 0 and all(self._attrs == o._attrs for o in other):
qc._attrs = copy.deepcopy(self._attrs)
return qc

def cumsum(
22 changes: 20 additions & 2 deletions src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
@@ -12,6 +12,7 @@
"""
from __future__ import annotations

import copy
import pickle as pkl
import warnings
from collections.abc import Sequence
@@ -105,12 +106,16 @@ def decorator(base_method: Any):
series_method = getattr(pd.Series, method_name, None)
if isinstance(series_method, property):
series_method = series_method.fget
if series_method is None or series_method is parent_method:
if (
series_method is None
or series_method is parent_method
or parent_method is None
):
register_series_accessor(method_name)(base_method)
df_method = getattr(pd.DataFrame, method_name, None)
if isinstance(df_method, property):
df_method = df_method.fget
if df_method is None or df_method is parent_method:
if df_method is None or df_method is parent_method or parent_method is None:
register_dataframe_accessor(method_name)(base_method)
# Replace base method
setattr(BasePandasDataset, method_name, base_method)
@@ -864,6 +869,19 @@ def var(
)


def _set_attrs(self, value: dict) -> None: # noqa: RT01, D200
# Use a field on the query compiler instead of self to avoid any possible ambiguity with
# a column named "_attrs"
self._query_compiler._attrs = copy.deepcopy(value)


def _get_attrs(self) -> dict: # noqa: RT01, D200
return self._query_compiler._attrs


register_base_override("attrs")(property(_get_attrs, _set_attrs))


# Modin does not provide `MultiIndex` support and will default to pandas when `level` is specified,
# and allows binary ops against native pandas objects that Snowpark pandas prohibits.
@register_base_override("_binary_op")
@@ -381,13 +381,6 @@ def __delitem__(self, key):
pass # pragma: no cover


@register_dataframe_accessor("attrs")
@dataframe_not_implemented()
@property
def attrs(self): # noqa: RT01, D200
pass # pragma: no cover


@register_dataframe_accessor("style")
@dataframe_not_implemented()
@property
@@ -784,6 +784,7 @@ def _get_names_wrapper(list_of_objs, names, prefix):

table = table.rename_axis(index=rownames_mapper, axis=0)
table = table.rename_axis(columns=colnames_mapper, axis=1)
table.attrs = {} # native pandas crosstab does not propagate attrs from the input

return table

9 changes: 8 additions & 1 deletion tests/integ/modin/frame/test_aggregate.py
@@ -158,7 +158,14 @@ def test_corr_negative(numeric_native_df, method):
@sql_count_checker(query_count=1)
def test_string_sum(data, numeric_only_kwargs):
eval_snowpark_pandas_result(
*create_test_dfs(data), lambda df: df.sum(**numeric_only_kwargs)
*create_test_dfs(data),
lambda df: df.sum(**numeric_only_kwargs),
# pandas doesn't propagate attrs if the frame is empty after type filtering,
# which happens if numeric_only=True and all columns are strings, but Snowpark pandas does.
test_attrs=not (
numeric_only_kwargs.get("numeric_only", False)
and isinstance(data["col1"][0], str)
),
)


2 changes: 2 additions & 0 deletions tests/integ/modin/frame/test_apply_axis_0.py
@@ -323,6 +323,8 @@ def test_groupby_apply_constant_output():
snow_df,
native_df,
lambda df: df.groupby(by=["fg"], axis=0).apply(lambda x: [1, 2]),
# Some calls to the native pandas function propagate attrs while some do not, depending on the values of its arguments.
test_attrs=False,
)

