From b26358b569bc8f9001dd263183c7aeeb32e720b8 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Mon, 20 Jan 2025 10:29:41 +0000 Subject: [PATCH] feat: track whether expressions change length but don't aggregate, and only allow length-changing expressions if they're followed by aggregations in the lazy API (#1828) --- narwhals/_dask/expr.py | 56 ++---- narwhals/_expression_parsing.py | 41 ++++ narwhals/dataframe.py | 11 ++ narwhals/exceptions.py | 8 + narwhals/expr.py | 236 +++++++++++++++++++++-- narwhals/expr_cat.py | 2 + narwhals/expr_dt.py | 42 ++++ narwhals/expr_list.py | 2 + narwhals/expr_name.py | 12 ++ narwhals/expr_str.py | 26 +++ narwhals/functions.py | 75 ++++++- narwhals/selectors.py | 40 +++- narwhals/stable/v1/__init__.py | 39 +++- tests/expr_and_series/drop_nulls_test.py | 19 ++ tests/expr_and_series/unique_test.py | 39 +++- tests/frame/filter_test.py | 15 +- utils/check_api_reference.py | 24 ++- 17 files changed, 581 insertions(+), 106 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 09a3adf8fb..1319f1001f 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -4,7 +4,6 @@ from typing import Any from typing import Callable from typing import Literal -from typing import NoReturn from typing import Sequence from narwhals._dask.expr_dt import DaskExprDateTimeNamespace @@ -448,34 +447,20 @@ def round(self, decimals: int) -> Self: returns_scalar=self._returns_scalar, ) - def ewm_mean( - self: Self, - *, - com: float | None = None, - span: float | None = None, - half_life: float | None = None, - alpha: float | None = None, - adjust: bool = True, - min_periods: int = 1, - ignore_nulls: bool = False, - ) -> NoReturn: - msg = "`Expr.ewm_mean` is not supported for the Dask backend" - raise NotImplementedError(msg) - - def unique(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead." - raise NotImplementedError(msg) - - def drop_nulls(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead." - raise NotImplementedError(msg) + def unique(self, *, maintain_order: bool) -> Self: + # TODO(marco): maintain_order has no effect and will be deprecated + return self._from_call( + lambda _input: _input.unique(), + "unique", + returns_scalar=self._returns_scalar, + ) - def head(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead." - raise NotImplementedError(msg) + def drop_nulls(self) -> Self: + return self._from_call( + lambda _input: _input.dropna(), + "drop_nulls", + returns_scalar=self._returns_scalar, + ) def replace_strict( self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None @@ -483,11 +468,6 @@ def replace_strict( msg = "`replace_strict` is not yet supported for Dask expressions" raise NotImplementedError(msg) - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead." - raise NotImplementedError(msg) - def abs(self) -> Self: return self._from_call( lambda _input: _input.abs(), "abs", returns_scalar=self._returns_scalar @@ -678,16 +658,6 @@ def null_count(self: Self) -> Self: returns_scalar=True, ) - def tail(self: Self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.tail` is not supported for the Dask backend. Please use `LazyFrame.tail` instead." - raise NotImplementedError(msg) - - def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." - raise NotImplementedError(msg) - def over(self: Self, keys: list[str]) -> Self: def func(df: DaskLazyFrame) -> list[Any]: if self._output_names is None: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index dc52da002b..a11db68eb8 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -14,6 +14,7 @@ from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError +from narwhals.exceptions import LengthChangingExprError from narwhals.utils import Implementation if TYPE_CHECKING: @@ -342,3 +343,43 @@ def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, # neither of which is order-dependent, so we default to `False`. return any(getattr(x, "_is_order_dependent", False) for x in args) + + +def operation_changes_length(*args: IntoExpr | Any) -> bool: + """Track whether operation changes length. + + n-ary operations between expressions which change length are not + allowed. This is because the output might be non-relational. For + example: + df = pl.LazyFrame({'a': [1,2,None], 'b': [4,None,6]}) + df.select(pl.col('a', 'b').drop_nulls()) + Polars does allow this, but in the result we end up with the + tuple (2, 6) which wasn't part of the original data. + + Rules are: + - in an n-ary operation, if any one of them changes length, then + it must be the only expression present + - in a comparison between a changes-length expression and a + scalar, the output changes length + """ + from narwhals.expr import Expr + + n_exprs = len([x for x in args if isinstance(x, Expr)]) + changes_length = any(isinstance(x, Expr) and x._changes_length for x in args) + if n_exprs > 1 and changes_length: + msg = ( + "Found multiple expressions at least one of which changes length.\n" + "Any length-changing expression can only be used in isolation, unless\n" + "it is followed by an aggregation." + ) + raise LengthChangingExprError(msg) + return changes_length + + +def operation_aggregates(*args: IntoExpr | Any) -> bool: + # If an arg is an Expr, we look at `_aggregates`. If it isn't, + # it means that it was a scalar (e.g. nw.col('a').sum() + 1), + # which is already length-1, so we default to `True`. If any + # expression does not aggregate, then broadcasting will take + # place and the result will not be an aggregate. + return all(getattr(x, "_aggregates", True) for x in args) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index b055fed12c..1633a3c226 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -16,6 +16,7 @@ from narwhals.dependencies import get_polars from narwhals.dependencies import is_numpy_array +from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import OrderDependentExprError from narwhals.schema import Schema from narwhals.translate import to_native @@ -3648,6 +3649,16 @@ def _extract_compliant(self, arg: Any) -> Any: " they will be supported." ) raise OrderDependentExprError(msg) + if arg._changes_length: + msg = ( + "Length-changing expressions are not supported for use in LazyFrame, unless\n" + "followed by an aggregation.\n\n" + "Hints:\n" + "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" + "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" + " use `lf.select(nw.col('a').drop_nulls().sum())\n" + ) + raise LengthChangingExprError(msg) return arg._to_compliant_expr(self.__narwhals_namespace__()) if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover msg = ( diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 9b05e3ba88..6a553fa448 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -91,6 +91,14 @@ def __init__(self, message: str) -> None: super().__init__(self.message) +class LengthChangingExprError(ValueError): + """Exception raised when trying to use an expression which changes length with LazyFrames.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + class UnsupportedDTypeError(ValueError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" diff --git a/narwhals/expr.py b/narwhals/expr.py index 4d17552afa..ccb33ead25 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -9,6 +9,8 @@ from typing import Sequence from narwhals._expression_parsing import extract_compliant +from narwhals._expression_parsing import operation_aggregates +from narwhals._expression_parsing import operation_changes_length from narwhals._expression_parsing import operation_is_order_dependent from narwhals.dtypes import _validate_dtype from narwhals.expr_cat import ExprCatNamespace @@ -34,16 +36,31 @@ def __init__( self, to_compliant_expr: Callable[[Any], Any], is_order_dependent: bool, # noqa: FBT001 + changes_length: bool, # noqa: FBT001 + aggregates: bool, # noqa: FBT001 ) -> None: # callable from CompliantNamespace to CompliantExpr self._to_compliant_expr = to_compliant_expr self._is_order_dependent = is_order_dependent + self._changes_length = changes_length + self._aggregates = aggregates + + def __repr__(self) -> str: + return ( + "Narwhals Expr\n" + f"is_order_dependent: {self._is_order_dependent}\n" + f"changes_length: {self._changes_length}\n" + f"aggregates: {self._aggregates}" + ) def _taxicab_norm(self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs().sum(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).abs().sum(), + self._is_order_dependent, + self._changes_length, + self._aggregates, ) # --- convert --- @@ -103,6 +120,8 @@ def alias(self, name: str) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).alias(name), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Self: @@ -225,6 +244,8 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cast(dtype), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- binary --- @@ -234,6 +255,8 @@ def __eq__(self, other: object) -> Self: # type: ignore[override] extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ne__(self, other: object) -> Self: # type: ignore[override] @@ -242,6 +265,8 @@ def __ne__(self, other: object) -> Self: # type: ignore[override] extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __and__(self, other: Any) -> Self: @@ -250,6 +275,8 @@ def __and__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rand__(self, other: Any) -> Self: @@ -261,6 +288,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __or__(self, other: Any) -> Self: @@ -269,6 +298,8 @@ def __or__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ror__(self, other: Any) -> Self: @@ -280,6 +311,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __add__(self, other: Any) -> Self: @@ -288,6 +321,8 @@ def __add__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __radd__(self, other: Any) -> Self: @@ -299,6 +334,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __sub__(self, other: Any) -> Self: @@ -307,6 +344,8 @@ def __sub__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rsub__(self, other: Any) -> Self: @@ -318,6 +357,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __truediv__(self, other: Any) -> Self: @@ -326,6 +367,8 @@ def __truediv__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rtruediv__(self, other: Any) -> Self: @@ -337,6 +380,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __mul__(self, other: Any) -> Self: @@ -345,6 +390,8 @@ def __mul__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rmul__(self, other: Any) -> Self: @@ -356,6 +403,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __le__(self, other: Any) -> Self: @@ -364,6 +413,8 @@ def __le__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __lt__(self, other: Any) -> Self: @@ -372,6 +423,8 @@ def __lt__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __gt__(self, other: Any) -> Self: @@ -380,6 +433,8 @@ def __gt__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ge__(self, other: Any) -> Self: @@ -388,6 +443,8 @@ def __ge__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __pow__(self, other: Any) -> Self: @@ -396,6 +453,8 @@ def __pow__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rpow__(self, other: Any) -> Self: @@ -407,6 +466,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __floordiv__(self, other: Any) -> Self: @@ -415,6 +476,8 @@ def __floordiv__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rfloordiv__(self, other: Any) -> Self: @@ -426,6 +489,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __mod__(self, other: Any) -> Self: @@ -434,6 +499,8 @@ def __mod__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rmod__(self, other: Any) -> Self: @@ -445,6 +512,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) # --- unary --- @@ -452,6 +521,8 @@ def __invert__(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__invert__(), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def any(self) -> Self: @@ -506,6 +577,8 @@ def any(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).any(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def all(self) -> Self: @@ -560,6 +633,8 @@ def all(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).all(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def ewm_mean( @@ -663,6 +738,8 @@ def ewm_mean( ignore_nulls=ignore_nulls, ), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def mean(self) -> Self: @@ -717,6 +794,8 @@ def mean(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).mean(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def median(self) -> Self: @@ -774,6 +853,8 @@ def median(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).median(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def std(self, *, ddof: int = 1) -> Self: @@ -831,6 +912,8 @@ def std(self, *, ddof: int = 1) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def var(self, *, ddof: int = 1) -> Self: @@ -889,6 +972,8 @@ def var(self, *, ddof: int = 1) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def map_batches( @@ -964,7 +1049,10 @@ def map_batches( lambda plx: self._to_compliant_expr(plx).map_batches( function=function, return_dtype=return_dtype ), - is_order_dependent=True, # safest assumption + # safest assumptions + is_order_dependent=True, + changes_length=True, + aggregates=False, ) def skew(self: Self) -> Self: @@ -1019,6 +1107,8 @@ def skew(self: Self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).skew(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def sum(self) -> Expr: @@ -1071,6 +1161,8 @@ def sum(self) -> Expr: return self.__class__( lambda plx: self._to_compliant_expr(plx).sum(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def min(self) -> Self: @@ -1125,6 +1217,8 @@ def min(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).min(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def max(self) -> Self: @@ -1179,6 +1273,8 @@ def max(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).max(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def arg_min(self) -> Self: @@ -1233,7 +1329,10 @@ def arg_min(self) -> Self: b_arg_min: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_min(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_min(), + is_order_dependent=True, + changes_length=False, + aggregates=True, ) def arg_max(self) -> Self: @@ -1288,7 +1387,10 @@ def arg_max(self) -> Self: b_arg_max: [[0]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_max(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_max(), + is_order_dependent=True, + changes_length=False, + aggregates=True, ) def count(self) -> Self: @@ -1341,7 +1443,10 @@ def count(self) -> Self: b: [[2]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).count(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).count(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def n_unique(self) -> Self: @@ -1392,7 +1497,10 @@ def n_unique(self) -> Self: b: [[3]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).n_unique(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).n_unique(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def unique(self, *, maintain_order: bool = False) -> Self: @@ -1458,6 +1566,8 @@ def unique(self, *, maintain_order: bool = False) -> Self: maintain_order=maintain_order ), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def abs(self) -> Self: @@ -1512,7 +1622,10 @@ def abs(self) -> Self: b: [[3,4]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).abs(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_sum(self: Self, *, reverse: bool = False) -> Self: @@ -1576,6 +1689,8 @@ def cum_sum(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def diff(self) -> Self: @@ -1643,7 +1758,10 @@ def diff(self) -> Self: a_diff: [[null,0,2,2,0]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).diff(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).diff(), + is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def shift(self, n: int) -> Self: @@ -1714,7 +1832,10 @@ def shift(self, n: int) -> Self: a_shift: [[null,1,1,3,5]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).shift(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).shift(n), + is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def replace_strict( @@ -1808,6 +1929,8 @@ def replace_strict( old, new, return_dtype=return_dtype ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -1839,6 +1962,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: descending=descending, nulls_last=nulls_last ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- transform --- @@ -1916,6 +2041,8 @@ def is_between( is_order_dependent=operation_is_order_dependent( self, lower_bound, upper_bound ), + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_in(self, other: Any) -> Self: @@ -1982,6 +2109,8 @@ def is_in(self, other: Any) -> Self: extract_compliant(plx, other) ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) else: msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -2052,6 +2181,8 @@ def filter(self, *predicates: Any) -> Self: *[extract_compliant(plx, pred) for pred in flat_predicates], ), is_order_dependent=operation_is_order_dependent(*flat_predicates), + changes_length=True, + aggregates=self._aggregates, ) def is_null(self) -> Self: @@ -2131,7 +2262,10 @@ def is_null(self) -> Self: b_is_null: [[false,false,true,false,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_null(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_null(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_nan(self) -> Self: @@ -2198,7 +2332,10 @@ def is_nan(self) -> Self: divided_is_nan: [[true,null,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_nan(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_nan(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def arg_true(self) -> Self: @@ -2214,7 +2351,10 @@ def arg_true(self) -> Self: ) issue_deprecation_warning(msg, _version="1.23.0") return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_true(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_true(), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def fill_null( @@ -2359,6 +2499,8 @@ def fill_null( value=value, strategy=strategy, limit=limit ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- partial reduction --- @@ -2422,6 +2564,8 @@ def drop_nulls(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).drop_nulls(), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def sample( @@ -2463,6 +2607,8 @@ def sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def over(self, *keys: str | Iterable[str]) -> Self: @@ -2555,6 +2701,8 @@ def over(self, *keys: str | Iterable[str]) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).over(flatten(keys)), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_duplicated(self) -> Self: @@ -2615,6 +2763,8 @@ def is_duplicated(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_duplicated(), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_unique(self) -> Self: @@ -2673,7 +2823,10 @@ def is_unique(self) -> Self: b: [[false,false,true,true]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_unique(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_unique(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def null_count(self) -> Self: @@ -2733,6 +2886,8 @@ def null_count(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).null_count(), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_first_distinct(self) -> Self: @@ -2793,6 +2948,8 @@ def is_first_distinct(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_first_distinct(), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_last_distinct(self) -> Self: @@ -2853,6 +3010,8 @@ def is_last_distinct(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_last_distinct(), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def quantile( @@ -2924,6 +3083,8 @@ def quantile( return self.__class__( lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation), self._is_order_dependent, + changes_length=False, + aggregates=True, ) def head(self, n: int = 10) -> Self: @@ -2950,7 +3111,10 @@ def head(self, n: int = 10) -> Self: ) issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( - lambda plx: self._to_compliant_expr(plx).head(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).head(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def tail(self, n: int = 10) -> Self: @@ -2977,7 +3141,10 @@ def tail(self, n: int = 10) -> Self: ) issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( - lambda plx: self._to_compliant_expr(plx).tail(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).tail(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def round(self, decimals: int = 0) -> Self: @@ -3046,6 +3213,8 @@ def round(self, decimals: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).round(decimals), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def len(self) -> Self: @@ -3104,7 +3273,10 @@ def len(self) -> Self: a2: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).len(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).len(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -3134,6 +3306,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) # need to allow numeric typing @@ -3284,6 +3458,8 @@ def clip( is_order_dependent=operation_is_order_dependent( self, lower_bound, upper_bound ), + changes_length=self._changes_length, + aggregates=self._aggregates, ) def mode(self: Self) -> Self: @@ -3339,7 +3515,10 @@ def mode(self: Self) -> Self: a: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).mode(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).mode(), + self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def is_finite(self: Self) -> Self: @@ -3401,7 +3580,10 @@ def is_finite(self: Self) -> Self: a: [[false,false,true,null]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_finite(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_finite(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_count(self: Self, *, reverse: bool = False) -> Self: @@ -3470,6 +3652,8 @@ def cum_count(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_min(self: Self, *, reverse: bool = False) -> Self: @@ -3538,6 +3722,8 @@ def cum_min(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_max(self: Self, *, reverse: bool = False) -> Self: @@ -3606,6 +3792,8 @@ def cum_max(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_prod(self: Self, *, reverse: bool = False) -> Self: @@ -3674,6 +3862,8 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_sum( @@ -3769,6 +3959,8 @@ def rolling_sum( center=center, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_mean( @@ -3864,6 +4056,8 @@ def rolling_mean( center=center, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_var( @@ -3959,6 +4153,8 @@ def rolling_var( window_size=window_size, min_periods=min_periods, center=center, ddof=ddof ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_std( @@ -4057,6 +4253,8 @@ def rolling_std( ddof=ddof, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rank( @@ -4155,6 +4353,8 @@ def rank( method=method, descending=descending ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) @property diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index baf467df3f..16dbb39298 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -64,4 +64,6 @@ def get_categories(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories(), self._expr._is_order_dependent, + changes_length=True, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 6b981315d8..6ea1fbbdd2 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -73,6 +73,8 @@ def date(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.date(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def year(self: Self) -> ExprT: @@ -142,6 +144,8 @@ def year(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.year(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def month(self: Self) -> ExprT: @@ -211,6 +215,8 @@ def month(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.month(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def day(self: Self) -> ExprT: @@ -280,6 +286,8 @@ def day(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.day(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def hour(self: Self) -> ExprT: @@ -349,6 +357,8 @@ def hour(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.hour(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def minute(self: Self) -> ExprT: @@ -418,6 +428,8 @@ def minute(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.minute(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def second(self: Self) -> ExprT: @@ -485,6 +497,8 @@ def second(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.second(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def millisecond(self: Self) -> ExprT: @@ -552,6 +566,8 @@ def millisecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def microsecond(self: Self) -> ExprT: @@ -619,6 +635,8 @@ def microsecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def nanosecond(self: Self) -> ExprT: @@ -686,6 +704,8 @@ def nanosecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def ordinal_day(self: Self) -> ExprT: @@ -745,6 +765,8 @@ def ordinal_day(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def weekday(self: Self) -> ExprT: @@ -802,6 +824,8 @@ def weekday(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.weekday(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_minutes(self: Self) -> ExprT: @@ -866,6 +890,8 @@ def total_minutes(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_seconds(self: Self) -> ExprT: @@ -930,6 +956,8 @@ def total_seconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_milliseconds(self: Self) -> ExprT: @@ -999,6 +1027,8 @@ def total_milliseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_microseconds(self: Self) -> ExprT: @@ -1068,6 +1098,8 @@ def total_microseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_nanoseconds(self: Self) -> ExprT: @@ -1124,6 +1156,8 @@ def total_nanoseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_string(self: Self, format: str) -> ExprT: # noqa: A002 @@ -1223,6 +1257,8 @@ def to_string(self: Self, format: str) -> ExprT: # noqa: A002 return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: @@ -1290,6 +1326,8 @@ def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: time_zone ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def convert_time_zone(self: Self, time_zone: str) -> ExprT: @@ -1363,6 +1401,8 @@ def convert_time_zone(self: Self, time_zone: str) -> ExprT: time_zone ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: @@ -1437,4 +1477,6 @@ def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index c64defda87..0532db5fe8 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -75,4 +75,6 @@ def len(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).list.len(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index 0d428cea1e..706f9427da 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -61,6 +61,8 @@ def keep(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.keep(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def map(self: Self, function: Callable[[str], str]) -> ExprT: @@ -111,6 +113,8 @@ def map(self: Self, function: Callable[[str], str]) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.map(function), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def prefix(self: Self, prefix: str) -> ExprT: @@ -160,6 +164,8 @@ def prefix(self: Self, prefix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def suffix(self: Self, suffix: str) -> ExprT: @@ -209,6 +215,8 @@ def suffix(self: Self, suffix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_lowercase(self: Self) -> ExprT: @@ -255,6 +263,8 @@ def to_lowercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_uppercase(self: Self) -> ExprT: @@ -301,4 +311,6 @@ def to_uppercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 0ea89ceb4c..90283930a7 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -78,6 +78,8 @@ def len_chars(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.len_chars(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace( @@ -145,6 +147,8 @@ def replace( pattern, value, literal=literal, n=n ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace_all( @@ -211,6 +215,8 @@ def replace_all( pattern, value, literal=literal ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def strip_chars(self: Self, characters: str | None = None) -> ExprT: @@ -260,6 +266,8 @@ def strip_chars(self: Self, characters: str | None = None) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def starts_with(self: Self, prefix: str) -> ExprT: @@ -323,6 +331,8 @@ def starts_with(self: Self, prefix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.starts_with(prefix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def ends_with(self: Self, suffix: str) -> ExprT: @@ -386,6 +396,8 @@ def ends_with(self: Self, suffix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: @@ -465,6 +477,8 @@ def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: pattern, literal=literal ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def slice(self: Self, offset: int, length: int | None = None) -> ExprT: @@ -568,6 +582,8 @@ def slice(self: Self, offset: int, length: int | None = None) -> ExprT: offset=offset, length=length ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def head(self: Self, n: int = 5) -> ExprT: @@ -636,6 +652,8 @@ def head(self: Self, n: int = 5) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def tail(self: Self, n: int = 5) -> ExprT: @@ -706,6 +724,8 @@ def tail(self: Self, n: int = 5) -> ExprT: offset=-n, length=None ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 @@ -776,6 +796,8 @@ def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_uppercase(self: Self) -> ExprT: @@ -841,6 +863,8 @@ def to_uppercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_lowercase(self: Self) -> ExprT: @@ -901,4 +925,6 @@ def to_lowercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 9b3a0609ac..1845edf384 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -13,6 +13,8 @@ from typing import overload from narwhals._expression_parsing import extract_compliant +from narwhals._expression_parsing import operation_aggregates +from narwhals._expression_parsing import operation_changes_length from narwhals._expression_parsing import operation_is_order_dependent from narwhals._pandas_like.utils import broadcast_align_and_extract_native from narwhals.dataframe import DataFrame @@ -1393,7 +1395,7 @@ def col(*names: str | Iterable[str]) -> Expr: def func(plx: Any) -> Any: return plx.col(*flatten(names)) - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1455,7 +1457,7 @@ def nth(*indices: int | Sequence[int]) -> Expr: def func(plx: Any) -> Any: return plx.nth(*flatten(indices)) - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) # Add underscore so it doesn't conflict with builtin `all` @@ -1512,7 +1514,12 @@ def all_() -> Expr: a: [[2,4,6]] b: [[8,10,12]] """ - return Expr(lambda plx: plx.all(), is_order_dependent=False) + return Expr( + lambda plx: plx.all(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) # Add underscore so it doesn't conflict with builtin `len` @@ -1565,7 +1572,7 @@ def len_() -> Expr: def func(plx: Any) -> Any: return plx.len() - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=True) def sum(*columns: str) -> Expr: @@ -1621,7 +1628,12 @@ def sum(*columns: str) -> Expr: ---- a: [[3]] """ - return Expr(lambda plx: plx.col(*columns).sum(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).sum(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def mean(*columns: str) -> Expr: @@ -1677,7 +1689,12 @@ def mean(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.col(*columns).mean(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).mean(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def median(*columns: str) -> Expr: @@ -1735,7 +1752,12 @@ def median(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.col(*columns).median(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).median(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def min(*columns: str) -> Expr: @@ -1791,7 +1813,12 @@ def min(*columns: str) -> Expr: ---- b: [[5]] """ - return Expr(lambda plx: plx.col(*columns).min(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).min(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def max(*columns: str) -> Expr: @@ -1847,7 +1874,12 @@ def max(*columns: str) -> Expr: ---- a: [[2]] """ - return Expr(lambda plx: plx.col(*columns).max(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).max(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1914,6 +1946,8 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -1984,6 +2018,8 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2054,6 +2090,8 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2073,6 +2111,8 @@ def then(self, value: IntoExpr | Any) -> Then: extract_compliant(plx, value) ), is_order_dependent=operation_is_order_dependent(*self._predicates, value), + changes_length=operation_changes_length(*self._predicates, value), + aggregates=operation_aggregates(*self._predicates, value), ) @@ -2083,6 +2123,8 @@ def otherwise(self, value: IntoExpr | Any) -> Expr: extract_compliant(plx, value) ), is_order_dependent=operation_is_order_dependent(self, value), + changes_length=operation_changes_length(self, value), + aggregates=operation_aggregates(self, value), ) @@ -2236,6 +2278,8 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2306,7 +2350,12 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: msg = f"Nested datatypes are not supported yet. Got {value}" raise NotImplementedError(msg) - return Expr(lambda plx: plx.lit(value, dtype), is_order_dependent=False) + return Expr( + lambda plx: plx.lit(value, dtype), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -2384,6 +2433,8 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2454,6 +2505,8 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2544,4 +2597,6 @@ def concat_str( ignore_nulls=ignore_nulls, ), is_order_dependent=operation_is_order_dependent(*flat_exprs, *more_exprs), + changes_length=operation_changes_length(*flat_exprs, *more_exprs), + aggregates=operation_aggregates(*flat_exprs, *more_exprs), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 664fdc9cac..e67424281e 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -53,7 +53,10 @@ def by_dtype(*dtypes: Any) -> Expr: └─────┴─────┘ """ return Selector( - lambda plx: plx.selectors.by_dtype(flatten(dtypes)), is_order_dependent=False + lambda plx: plx.selectors.by_dtype(flatten(dtypes)), + is_order_dependent=False, + changes_length=False, + aggregates=False, ) @@ -97,7 +100,12 @@ def numeric() -> Expr: │ 4 ┆ 4.6 │ └─────┴─────┘ """ - return Selector(lambda plx: plx.selectors.numeric(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.numeric(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def boolean() -> Expr: @@ -140,7 +148,12 @@ def boolean() -> Expr: │ true │ └───────┘ """ - return Selector(lambda plx: plx.selectors.boolean(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.boolean(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def string() -> Expr: @@ -183,7 +196,12 @@ def string() -> Expr: │ y │ └─────┘ """ - return Selector(lambda plx: plx.selectors.string(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.string(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def categorical() -> Expr: @@ -226,7 +244,12 @@ def categorical() -> Expr: │ y │ └─────┘ """ - return Selector(lambda plx: plx.selectors.categorical(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.categorical(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def all() -> Expr: @@ -269,7 +292,12 @@ def all() -> Expr: │ 2 ┆ y ┆ true │ └─────┴─────┴───────┘ """ - return Selector(lambda plx: plx.selectors.all(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.all(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) __all__ = [ diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 76dee2e491..6e02b08ff0 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -240,7 +240,7 @@ def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame def _extract_compliant(self, arg: Any) -> Any: - # After v1, we raise when passing order-dependent + # After v1, we raise when passing order-dependent or length-changing # expressions to LazyFrame from narwhals.dataframe import BaseFrame from narwhals.expr import Expr @@ -252,7 +252,7 @@ def _extract_compliant(self, arg: Any) -> Any: msg = "Mixing Series with LazyFrame is not supported." raise TypeError(msg) if isinstance(arg, Expr): - # After stable.v1, we raise if arg._is_order_dependent + # After stable.v1, we raise if arg._is_order_dependent or arg._changes_length return arg._to_compliant_expr(self.__narwhals_namespace__()) if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover msg = ( @@ -872,7 +872,10 @@ def head(self, n: int = 10) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).head(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).head(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def tail(self, n: int = 10) -> Self: @@ -885,7 +888,10 @@ def tail(self, n: int = 10) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).tail(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).tail(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -901,6 +907,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -918,6 +926,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: descending=descending, nulls_last=nulls_last ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def arg_true(self) -> Self: @@ -927,7 +937,10 @@ def arg_true(self) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_true(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_true(), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def sample( @@ -962,6 +975,8 @@ def sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) @@ -1006,7 +1021,12 @@ def _stableify( level=obj._level, ) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, is_order_dependent=obj._is_order_dependent) + return Expr( + obj._to_compliant_expr, + is_order_dependent=obj._is_order_dependent, + changes_length=obj._changes_length, + aggregates=obj._aggregates, + ) return obj @@ -1964,7 +1984,12 @@ def then(self, value: Any) -> Then: class Then(NwThen, Expr): @classmethod def from_then(cls, then: NwThen) -> Self: - return cls(then._to_compliant_expr, is_order_dependent=then._is_order_dependent) + return cls( + then._to_compliant_expr, + is_order_dependent=then._is_order_dependent, + changes_length=then._changes_length, + aggregates=then._aggregates, + ) def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index 0584674e6c..b4e1f7c6b0 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -1,6 +1,9 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -30,6 +33,22 @@ def test_drop_nulls(constructor_eager: ConstructorEager) -> None: assert_equal_data(result_d, expected_d) +def test_drop_nulls_agg(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) + data = { + "A": [1, 2, None, 4], + "B": [5, 6, 7, 8], + "C": [None, None, None, None], + "D": [9, 10, 11, 12], + } + + df = nw.from_native(constructor(data)) + result = df.select(nw.all().drop_nulls().len()) + expected = {"A": [3], "B": [4], "C": [0], "D": [4]} + assert_equal_data(result, expected) + + def test_drop_nulls_series(constructor_eager: ConstructorEager) -> None: data = { "A": [1, 2, None, 4], diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index 989a577c24..b2877c646d 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -1,6 +1,12 @@ from __future__ import annotations -import narwhals.stable.v1 as nw +from contextlib import nullcontext as does_not_raise + +import pytest + +import narwhals as nw +from narwhals.exceptions import LengthChangingExprError +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -8,13 +14,36 @@ data_str = {"a": ["x", "x", "y"]} -def test_unique_expr(constructor_eager: ConstructorEager) -> None: - df = nw.from_native(constructor_eager(data)) - result = df.select(nw.col("a").unique()) - expected = {"a": [1, 2]} +def test_unique_expr(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + context = ( + pytest.raises(LengthChangingExprError) + if isinstance(df, nw.LazyFrame) + else does_not_raise() + ) + with context: + result = df.select(nw.col("a").unique()) + expected = {"a": [1, 2]} + assert_equal_data(result, expected) + + +def test_unique_expr_agg( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").unique().sum()) + expected = {"a": [3]} assert_equal_data(result, expected) +def test_unique_illegal_combination(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + with pytest.raises(LengthChangingExprError): + df.select((nw.col("a").unique() + nw.col("b").unique()).sum()) + + def test_unique_series(constructor_eager: ConstructorEager) -> None: series = nw.from_native(constructor_eager(data_str), eager_only=True)["a"] result = series.unique(maintain_order=True) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index 1b691c5c1e..96aa2b44e6 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -6,6 +6,7 @@ from polars.exceptions import ShapeError as PlShapeError import narwhals as nw +from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import ShapeError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -54,17 +55,5 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None: def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} df = nw.from_native(constructor(data)) - - context = ( - pytest.raises( - ShapeError, match="filter's length: 2 differs from that of the series: 3" - ) - if any(x in str(constructor) for x in ("pandas", "pyarrow", "modin")) - else pytest.raises( - PlShapeError, match="filter's length: 2 differs from that of the series: 3" - ) - if "polars" in str(constructor) - else pytest.raises(Exception) # type: ignore[arg-type] # noqa: PT011 - ) - with context: + with pytest.raises((LengthChangingExprError, ShapeError, PlShapeError)): df.filter(nw.col("b").unique() > 2).lazy().collect() diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 8df16e88b0..e0ebf97a9a 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -161,7 +161,9 @@ # Expr methods expr_methods = [ i - for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__() + for i in nw.Expr( + lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False + ).__dir__() if not i[0].isupper() and i[0] != "_" ] with open("docs/api-reference/expr.md") as fd: @@ -185,7 +187,13 @@ expr_methods = [ i for i in getattr( - nw.Expr(lambda: 0, is_order_dependent=False), namespace + nw.Expr( + lambda: 0, + is_order_dependent=False, + changes_length=False, + aggregates=False, + ), + namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" ] @@ -228,7 +236,9 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__() + for i in nw.Expr( + lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False + ).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -250,7 +260,13 @@ expr_internal = [ i for i in getattr( - nw.Expr(lambda: 0, is_order_dependent=False), namespace + nw.Expr( + lambda: 0, + is_order_dependent=False, + changes_length=False, + aggregates=False, + ), + namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" ]