From 0551b50a32418cfa9f384c27c421e0a7a479a76c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 12:49:10 +0400 Subject: [PATCH 01/10] track "changes length" and "aggregates --- narwhals/_expression_parsing.py | 41 ++++++ narwhals/expr.py | 228 +++++++++++++++++++++++++++++--- narwhals/expr_cat.py | 2 + narwhals/expr_dt.py | 42 ++++++ narwhals/expr_list.py | 2 + narwhals/expr_name.py | 12 ++ narwhals/expr_str.py | 26 ++++ narwhals/functions.py | 75 +++++++++-- narwhals/selectors.py | 40 +++++- narwhals/stable/v1/__init__.py | 30 ++++- utils/check_api_reference.py | 24 +++- 11 files changed, 480 insertions(+), 42 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index dc52da002..fed77fe44 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -342,3 +342,44 @@ def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, # neither of which is order-dependent, so we default to `False`. return any(getattr(x, "_is_order_dependent", False) for x in args) + + +def operation_changes_length(*args: IntoExpr | Any) -> bool: + """Track whether operation changes length. + + n-ary operations between expressions which change length are not + allowed. This is because the output might be non-relational. For + example: + df = pl.LazyFrame({'a': [1,2,None], 'b': [4,None,6]}) + df.select(pl.col('a', 'b').drop_nulls()) + Polars does allow this, but in the result we end up with the + tuple (2, 6) which wasn't part of the original data. + + Rules are: + - in an n-ary operation, if any one of them changes length, then + it must be the only expression present + - in a comparison between a changes-length expression and a + scalar, the output changes length + """ + from narwhals.expr import Expr + + n_length_changing_expressions = len( + [x for x in args if isinstance(x, Expr) and x._changes_length] + ) + if n_length_changing_expressions > 1: + msg = ( + "Found multiple expressions which change length. You can only use one " + "length-changing expression at a time, unless it is followed by an aggregation." + ) + # TODO(marco): custom error class + raise ValueError(msg) + return n_length_changing_expressions > 0 + + +def operation_aggregates(*args: IntoExpr | Any) -> bool: + # If an arg is an Expr, we look at `_aggregates`. If it isn't, + # it means that it was a scalar (e.g. nw.col('a').sum() + 1), + # which is already length-1, so we default to `True`. If any + # expression does not aggregate, then broadcasting will take + # place and the result will not be an aggregate. + return all(getattr(x, "_aggregates", True) for x in args) diff --git a/narwhals/expr.py b/narwhals/expr.py index 69c2e7dcc..e6d34f73c 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -9,6 +9,8 @@ from typing import Sequence from narwhals._expression_parsing import extract_compliant +from narwhals._expression_parsing import operation_aggregates +from narwhals._expression_parsing import operation_changes_length from narwhals._expression_parsing import operation_is_order_dependent from narwhals.dtypes import _validate_dtype from narwhals.expr_cat import ExprCatNamespace @@ -34,16 +36,23 @@ def __init__( self, to_compliant_expr: Callable[[Any], Any], is_order_dependent: bool, # noqa: FBT001 + changes_length: bool, # noqa: FBT001 + aggregates: bool, # noqa: FBT001 ) -> None: # callable from CompliantNamespace to CompliantExpr self._to_compliant_expr = to_compliant_expr self._is_order_dependent = is_order_dependent + self._changes_length = changes_length + self._aggregates = aggregates def _taxicab_norm(self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs().sum(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).abs().sum(), + self._is_order_dependent, + self._changes_length, + self._aggregates, ) # --- convert --- @@ -103,6 +112,8 @@ def alias(self, name: str) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).alias(name), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def pipe(self, function: Callable[[Any], Self], *args: Any, **kwargs: Any) -> Self: @@ -225,6 +236,8 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cast(dtype), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- binary --- @@ -234,6 +247,8 @@ def __eq__(self, other: object) -> Self: # type: ignore[override] extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ne__(self, other: object) -> Self: # type: ignore[override] @@ -242,6 +257,8 @@ def __ne__(self, other: object) -> Self: # type: ignore[override] extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __and__(self, other: Any) -> Self: @@ -250,6 +267,8 @@ def __and__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rand__(self, other: Any) -> Self: @@ -261,6 +280,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __or__(self, other: Any) -> Self: @@ -269,6 +290,8 @@ def __or__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ror__(self, other: Any) -> Self: @@ -280,6 +303,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __add__(self, other: Any) -> Self: @@ -288,6 +313,8 @@ def __add__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __radd__(self, other: Any) -> Self: @@ -299,6 +326,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __sub__(self, other: Any) -> Self: @@ -307,6 +336,8 @@ def __sub__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rsub__(self, other: Any) -> Self: @@ -318,6 +349,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __truediv__(self, other: Any) -> Self: @@ -326,6 +359,8 @@ def __truediv__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rtruediv__(self, other: Any) -> Self: @@ -337,6 +372,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __mul__(self, other: Any) -> Self: @@ -345,6 +382,8 @@ def __mul__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rmul__(self, other: Any) -> Self: @@ -356,6 +395,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __le__(self, other: Any) -> Self: @@ -364,6 +405,8 @@ def __le__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __lt__(self, other: Any) -> Self: @@ -372,6 +415,8 @@ def __lt__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __gt__(self, other: Any) -> Self: @@ -380,6 +425,8 @@ def __gt__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __ge__(self, other: Any) -> Self: @@ -388,6 +435,8 @@ def __ge__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __pow__(self, other: Any) -> Self: @@ -396,6 +445,8 @@ def __pow__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rpow__(self, other: Any) -> Self: @@ -407,6 +458,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __floordiv__(self, other: Any) -> Self: @@ -415,6 +468,8 @@ def __floordiv__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rfloordiv__(self, other: Any) -> Self: @@ -426,6 +481,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __mod__(self, other: Any) -> Self: @@ -434,6 +491,8 @@ def __mod__(self, other: Any) -> Self: extract_compliant(plx, other) ), is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) def __rmod__(self, other: Any) -> Self: @@ -445,6 +504,8 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: return self.__class__( func, is_order_dependent=operation_is_order_dependent(self, other), + changes_length=operation_changes_length(self, other), + aggregates=operation_aggregates(self, other), ) # --- unary --- @@ -452,6 +513,8 @@ def __invert__(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__invert__(), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def any(self) -> Self: @@ -506,6 +569,8 @@ def any(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).any(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def all(self) -> Self: @@ -560,6 +625,8 @@ def all(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).all(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def ewm_mean( @@ -663,6 +730,8 @@ def ewm_mean( ignore_nulls=ignore_nulls, ), is_order_dependent=self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def mean(self) -> Self: @@ -717,6 +786,8 @@ def mean(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).mean(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def median(self) -> Self: @@ -774,6 +845,8 @@ def median(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).median(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def std(self, *, ddof: int = 1) -> Self: @@ -831,6 +904,8 @@ def std(self, *, ddof: int = 1) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def var(self, *, ddof: int = 1) -> Self: @@ -889,6 +964,8 @@ def var(self, *, ddof: int = 1) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def map_batches( @@ -964,7 +1041,10 @@ def map_batches( lambda plx: self._to_compliant_expr(plx).map_batches( function=function, return_dtype=return_dtype ), - is_order_dependent=True, # safest assumption + # safest assumptions + is_order_dependent=True, + changes_length=True, + aggregates=False, ) def skew(self: Self) -> Self: @@ -1019,6 +1099,8 @@ def skew(self: Self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).skew(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def sum(self) -> Expr: @@ -1071,6 +1153,8 @@ def sum(self) -> Expr: return self.__class__( lambda plx: self._to_compliant_expr(plx).sum(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def min(self) -> Self: @@ -1125,6 +1209,8 @@ def min(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).min(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def max(self) -> Self: @@ -1179,6 +1265,8 @@ def max(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).max(), is_order_dependent=self._is_order_dependent, + changes_length=False, + aggregates=True, ) def arg_min(self) -> Self: @@ -1233,7 +1321,10 @@ def arg_min(self) -> Self: b_arg_min: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_min(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_min(), + is_order_dependent=True, + changes_length=False, + aggregates=True, ) def arg_max(self) -> Self: @@ -1288,7 +1379,10 @@ def arg_max(self) -> Self: b_arg_max: [[0]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_max(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_max(), + is_order_dependent=True, + changes_length=False, + aggregates=True, ) def count(self) -> Self: @@ -1341,7 +1435,10 @@ def count(self) -> Self: b: [[2]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).count(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).count(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def n_unique(self) -> Self: @@ -1392,7 +1489,10 @@ def n_unique(self) -> Self: b: [[3]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).n_unique(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).n_unique(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def unique(self, *, maintain_order: bool = False) -> Self: @@ -1458,6 +1558,8 @@ def unique(self, *, maintain_order: bool = False) -> Self: maintain_order=maintain_order ), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def abs(self) -> Self: @@ -1512,7 +1614,10 @@ def abs(self) -> Self: b: [[3,4]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).abs(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_sum(self: Self, *, reverse: bool = False) -> Self: @@ -1576,6 +1681,8 @@ def cum_sum(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def diff(self) -> Self: @@ -1643,7 +1750,10 @@ def diff(self) -> Self: a_diff: [[null,0,2,2,0]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).diff(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).diff(), + is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def shift(self, n: int) -> Self: @@ -1714,7 +1824,10 @@ def shift(self, n: int) -> Self: a_shift: [[null,1,1,3,5]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).shift(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).shift(n), + is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def replace_strict( @@ -1808,6 +1921,8 @@ def replace_strict( old, new, return_dtype=return_dtype ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -1839,6 +1954,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: descending=descending, nulls_last=nulls_last ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- transform --- @@ -1916,6 +2033,8 @@ def is_between( is_order_dependent=operation_is_order_dependent( self, lower_bound, upper_bound ), + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_in(self, other: Any) -> Self: @@ -1982,6 +2101,8 @@ def is_in(self, other: Any) -> Self: extract_compliant(plx, other) ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) else: msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." @@ -2052,6 +2173,8 @@ def filter(self, *predicates: Any) -> Self: *[extract_compliant(plx, pred) for pred in flat_predicates], ), is_order_dependent=operation_is_order_dependent(*flat_predicates), + changes_length=True, + aggregates=self._aggregates, ) def is_null(self) -> Self: @@ -2131,7 +2254,10 @@ def is_null(self) -> Self: b_is_null: [[false,false,true,false,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_null(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_null(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_nan(self) -> Self: @@ -2198,7 +2324,10 @@ def is_nan(self) -> Self: divided_is_nan: [[true,null,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_nan(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_nan(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def arg_true(self) -> Self: @@ -2251,7 +2380,10 @@ def arg_true(self) -> Self: a: [[1,2]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_true(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_true(), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def fill_null( @@ -2396,6 +2528,8 @@ def fill_null( value=value, strategy=strategy, limit=limit ), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) # --- partial reduction --- @@ -2459,6 +2593,8 @@ def drop_nulls(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).drop_nulls(), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def sample( @@ -2500,6 +2636,8 @@ def sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def over(self, *keys: str | Iterable[str]) -> Self: @@ -2592,6 +2730,8 @@ def over(self, *keys: str | Iterable[str]) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).over(flatten(keys)), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_duplicated(self) -> Self: @@ -2652,6 +2792,8 @@ def is_duplicated(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_duplicated(), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_unique(self) -> Self: @@ -2710,7 +2852,10 @@ def is_unique(self) -> Self: b: [[false,false,true,true]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_unique(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_unique(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def null_count(self) -> Self: @@ -2770,6 +2915,8 @@ def null_count(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).null_count(), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_first_distinct(self) -> Self: @@ -2830,6 +2977,8 @@ def is_first_distinct(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_first_distinct(), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def is_last_distinct(self) -> Self: @@ -2890,6 +3039,8 @@ def is_last_distinct(self) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).is_last_distinct(), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def quantile( @@ -2961,6 +3112,8 @@ def quantile( return self.__class__( lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation), self._is_order_dependent, + changes_length=False, + aggregates=True, ) def head(self, n: int = 10) -> Self: @@ -2987,7 +3140,10 @@ def head(self, n: int = 10) -> Self: ) issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( - lambda plx: self._to_compliant_expr(plx).head(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).head(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def tail(self, n: int = 10) -> Self: @@ -3014,7 +3170,10 @@ def tail(self, n: int = 10) -> Self: ) issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( - lambda plx: self._to_compliant_expr(plx).tail(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).tail(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def round(self, decimals: int = 0) -> Self: @@ -3083,6 +3242,8 @@ def round(self, decimals: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).round(decimals), self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def len(self) -> Self: @@ -3141,7 +3302,10 @@ def len(self) -> Self: a2: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).len(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).len(), + self._is_order_dependent, + changes_length=False, + aggregates=True, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -3171,6 +3335,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) # need to allow numeric typing @@ -3321,6 +3487,8 @@ def clip( is_order_dependent=operation_is_order_dependent( self, lower_bound, upper_bound ), + changes_length=self._changes_length, + aggregates=self._aggregates, ) def mode(self: Self) -> Self: @@ -3376,7 +3544,10 @@ def mode(self: Self) -> Self: a: [[1]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).mode(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).mode(), + self._is_order_dependent, + changes_length=True, + aggregates=self._aggregates, ) def is_finite(self: Self) -> Self: @@ -3438,7 +3609,10 @@ def is_finite(self: Self) -> Self: a: [[false,false,true,null]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_finite(), self._is_order_dependent + lambda plx: self._to_compliant_expr(plx).is_finite(), + self._is_order_dependent, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_count(self: Self, *, reverse: bool = False) -> Self: @@ -3507,6 +3681,8 @@ def cum_count(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_min(self: Self, *, reverse: bool = False) -> Self: @@ -3575,6 +3751,8 @@ def cum_min(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_max(self: Self, *, reverse: bool = False) -> Self: @@ -3643,6 +3821,8 @@ def cum_max(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def cum_prod(self: Self, *, reverse: bool = False) -> Self: @@ -3711,6 +3891,8 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_sum( @@ -3806,6 +3988,8 @@ def rolling_sum( center=center, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_mean( @@ -3901,6 +4085,8 @@ def rolling_mean( center=center, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_var( @@ -3996,6 +4182,8 @@ def rolling_var( window_size=window_size, min_periods=min_periods, center=center, ddof=ddof ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rolling_std( @@ -4094,6 +4282,8 @@ def rolling_std( ddof=ddof, ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def rank( @@ -4192,6 +4382,8 @@ def rank( method=method, descending=descending ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) @property diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index baf467df3..80b047a2e 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -64,4 +64,6 @@ def get_categories(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 6b981315d..6ea1fbbdd 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -73,6 +73,8 @@ def date(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.date(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def year(self: Self) -> ExprT: @@ -142,6 +144,8 @@ def year(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.year(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def month(self: Self) -> ExprT: @@ -211,6 +215,8 @@ def month(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.month(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def day(self: Self) -> ExprT: @@ -280,6 +286,8 @@ def day(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.day(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def hour(self: Self) -> ExprT: @@ -349,6 +357,8 @@ def hour(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.hour(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def minute(self: Self) -> ExprT: @@ -418,6 +428,8 @@ def minute(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.minute(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def second(self: Self) -> ExprT: @@ -485,6 +497,8 @@ def second(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.second(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def millisecond(self: Self) -> ExprT: @@ -552,6 +566,8 @@ def millisecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def microsecond(self: Self) -> ExprT: @@ -619,6 +635,8 @@ def microsecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def nanosecond(self: Self) -> ExprT: @@ -686,6 +704,8 @@ def nanosecond(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def ordinal_day(self: Self) -> ExprT: @@ -745,6 +765,8 @@ def ordinal_day(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def weekday(self: Self) -> ExprT: @@ -802,6 +824,8 @@ def weekday(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.weekday(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_minutes(self: Self) -> ExprT: @@ -866,6 +890,8 @@ def total_minutes(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_seconds(self: Self) -> ExprT: @@ -930,6 +956,8 @@ def total_seconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_milliseconds(self: Self) -> ExprT: @@ -999,6 +1027,8 @@ def total_milliseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_microseconds(self: Self) -> ExprT: @@ -1068,6 +1098,8 @@ def total_microseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def total_nanoseconds(self: Self) -> ExprT: @@ -1124,6 +1156,8 @@ def total_nanoseconds(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_string(self: Self, format: str) -> ExprT: # noqa: A002 @@ -1223,6 +1257,8 @@ def to_string(self: Self, format: str) -> ExprT: # noqa: A002 return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: @@ -1290,6 +1326,8 @@ def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: time_zone ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def convert_time_zone(self: Self, time_zone: str) -> ExprT: @@ -1363,6 +1401,8 @@ def convert_time_zone(self: Self, time_zone: str) -> ExprT: time_zone ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: @@ -1437,4 +1477,6 @@ def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index c64defda8..0532db5fe 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -75,4 +75,6 @@ def len(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).list.len(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index 0d428cea1..706f9427d 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -61,6 +61,8 @@ def keep(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.keep(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def map(self: Self, function: Callable[[str], str]) -> ExprT: @@ -111,6 +113,8 @@ def map(self: Self, function: Callable[[str], str]) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.map(function), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def prefix(self: Self, prefix: str) -> ExprT: @@ -160,6 +164,8 @@ def prefix(self: Self, prefix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def suffix(self: Self, suffix: str) -> ExprT: @@ -209,6 +215,8 @@ def suffix(self: Self, suffix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_lowercase(self: Self) -> ExprT: @@ -255,6 +263,8 @@ def to_lowercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_uppercase(self: Self) -> ExprT: @@ -301,4 +311,6 @@ def to_uppercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 0ea89ceb4..90283930a 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -78,6 +78,8 @@ def len_chars(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.len_chars(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace( @@ -145,6 +147,8 @@ def replace( pattern, value, literal=literal, n=n ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def replace_all( @@ -211,6 +215,8 @@ def replace_all( pattern, value, literal=literal ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def strip_chars(self: Self, characters: str | None = None) -> ExprT: @@ -260,6 +266,8 @@ def strip_chars(self: Self, characters: str | None = None) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def starts_with(self: Self, prefix: str) -> ExprT: @@ -323,6 +331,8 @@ def starts_with(self: Self, prefix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.starts_with(prefix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def ends_with(self: Self, suffix: str) -> ExprT: @@ -386,6 +396,8 @@ def ends_with(self: Self, suffix: str) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: @@ -465,6 +477,8 @@ def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: pattern, literal=literal ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def slice(self: Self, offset: int, length: int | None = None) -> ExprT: @@ -568,6 +582,8 @@ def slice(self: Self, offset: int, length: int | None = None) -> ExprT: offset=offset, length=length ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def head(self: Self, n: int = 5) -> ExprT: @@ -636,6 +652,8 @@ def head(self: Self, n: int = 5) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def tail(self: Self, n: int = 5) -> ExprT: @@ -706,6 +724,8 @@ def tail(self: Self, n: int = 5) -> ExprT: offset=-n, length=None ), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 @@ -776,6 +796,8 @@ def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_uppercase(self: Self) -> ExprT: @@ -841,6 +863,8 @@ def to_uppercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) def to_lowercase(self: Self) -> ExprT: @@ -901,4 +925,6 @@ def to_lowercase(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase(), self._expr._is_order_dependent, + changes_length=self._expr._changes_length, + aggregates=self._expr._aggregates, ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 9b3a0609a..1845edf38 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -13,6 +13,8 @@ from typing import overload from narwhals._expression_parsing import extract_compliant +from narwhals._expression_parsing import operation_aggregates +from narwhals._expression_parsing import operation_changes_length from narwhals._expression_parsing import operation_is_order_dependent from narwhals._pandas_like.utils import broadcast_align_and_extract_native from narwhals.dataframe import DataFrame @@ -1393,7 +1395,7 @@ def col(*names: str | Iterable[str]) -> Expr: def func(plx: Any) -> Any: return plx.col(*flatten(names)) - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1455,7 +1457,7 @@ def nth(*indices: int | Sequence[int]) -> Expr: def func(plx: Any) -> Any: return plx.nth(*flatten(indices)) - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) # Add underscore so it doesn't conflict with builtin `all` @@ -1512,7 +1514,12 @@ def all_() -> Expr: a: [[2,4,6]] b: [[8,10,12]] """ - return Expr(lambda plx: plx.all(), is_order_dependent=False) + return Expr( + lambda plx: plx.all(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) # Add underscore so it doesn't conflict with builtin `len` @@ -1565,7 +1572,7 @@ def len_() -> Expr: def func(plx: Any) -> Any: return plx.len() - return Expr(func, is_order_dependent=False) + return Expr(func, is_order_dependent=False, changes_length=False, aggregates=True) def sum(*columns: str) -> Expr: @@ -1621,7 +1628,12 @@ def sum(*columns: str) -> Expr: ---- a: [[3]] """ - return Expr(lambda plx: plx.col(*columns).sum(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).sum(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def mean(*columns: str) -> Expr: @@ -1677,7 +1689,12 @@ def mean(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.col(*columns).mean(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).mean(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def median(*columns: str) -> Expr: @@ -1735,7 +1752,12 @@ def median(*columns: str) -> Expr: ---- a: [[4]] """ - return Expr(lambda plx: plx.col(*columns).median(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).median(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def min(*columns: str) -> Expr: @@ -1791,7 +1813,12 @@ def min(*columns: str) -> Expr: ---- b: [[5]] """ - return Expr(lambda plx: plx.col(*columns).min(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).min(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def max(*columns: str) -> Expr: @@ -1847,7 +1874,12 @@ def max(*columns: str) -> Expr: ---- a: [[2]] """ - return Expr(lambda plx: plx.col(*columns).max(), is_order_dependent=False) + return Expr( + lambda plx: plx.col(*columns).max(), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -1914,6 +1946,8 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -1984,6 +2018,8 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2054,6 +2090,8 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2073,6 +2111,8 @@ def then(self, value: IntoExpr | Any) -> Then: extract_compliant(plx, value) ), is_order_dependent=operation_is_order_dependent(*self._predicates, value), + changes_length=operation_changes_length(*self._predicates, value), + aggregates=operation_aggregates(*self._predicates, value), ) @@ -2083,6 +2123,8 @@ def otherwise(self, value: IntoExpr | Any) -> Expr: extract_compliant(plx, value) ), is_order_dependent=operation_is_order_dependent(self, value), + changes_length=operation_changes_length(self, value), + aggregates=operation_aggregates(self, value), ) @@ -2236,6 +2278,8 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2306,7 +2350,12 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: msg = f"Nested datatypes are not supported yet. Got {value}" raise NotImplementedError(msg) - return Expr(lambda plx: plx.lit(value, dtype), is_order_dependent=False) + return Expr( + lambda plx: plx.lit(value, dtype), + is_order_dependent=False, + changes_length=False, + aggregates=True, + ) def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: @@ -2384,6 +2433,8 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2454,6 +2505,8 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), is_order_dependent=operation_is_order_dependent(*flat_exprs), + changes_length=operation_changes_length(*flat_exprs), + aggregates=operation_aggregates(*flat_exprs), ) @@ -2544,4 +2597,6 @@ def concat_str( ignore_nulls=ignore_nulls, ), is_order_dependent=operation_is_order_dependent(*flat_exprs, *more_exprs), + changes_length=operation_changes_length(*flat_exprs, *more_exprs), + aggregates=operation_aggregates(*flat_exprs, *more_exprs), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index 664fdc9ca..e67424281 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -53,7 +53,10 @@ def by_dtype(*dtypes: Any) -> Expr: └─────┴─────┘ """ return Selector( - lambda plx: plx.selectors.by_dtype(flatten(dtypes)), is_order_dependent=False + lambda plx: plx.selectors.by_dtype(flatten(dtypes)), + is_order_dependent=False, + changes_length=False, + aggregates=False, ) @@ -97,7 +100,12 @@ def numeric() -> Expr: │ 4 ┆ 4.6 │ └─────┴─────┘ """ - return Selector(lambda plx: plx.selectors.numeric(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.numeric(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def boolean() -> Expr: @@ -140,7 +148,12 @@ def boolean() -> Expr: │ true │ └───────┘ """ - return Selector(lambda plx: plx.selectors.boolean(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.boolean(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def string() -> Expr: @@ -183,7 +196,12 @@ def string() -> Expr: │ y │ └─────┘ """ - return Selector(lambda plx: plx.selectors.string(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.string(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def categorical() -> Expr: @@ -226,7 +244,12 @@ def categorical() -> Expr: │ y │ └─────┘ """ - return Selector(lambda plx: plx.selectors.categorical(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.categorical(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) def all() -> Expr: @@ -269,7 +292,12 @@ def all() -> Expr: │ 2 ┆ y ┆ true │ └─────┴─────┴───────┘ """ - return Selector(lambda plx: plx.selectors.all(), is_order_dependent=False) + return Selector( + lambda plx: plx.selectors.all(), + is_order_dependent=False, + changes_length=False, + aggregates=False, + ) __all__ = [ diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 941f16377..b37171380 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -872,7 +872,10 @@ def head(self, n: int = 10) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).head(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).head(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def tail(self, n: int = 10) -> Self: @@ -885,7 +888,10 @@ def tail(self, n: int = 10) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).tail(n), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).tail(n), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -901,6 +907,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -918,6 +926,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Self: descending=descending, nulls_last=nulls_last ), is_order_dependent=True, + changes_length=self._changes_length, + aggregates=self._aggregates, ) def sample( @@ -952,6 +962,8 @@ def sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) @@ -996,7 +1008,12 @@ def _stableify( level=obj._level, ) if isinstance(obj, NwExpr): - return Expr(obj._to_compliant_expr, is_order_dependent=obj._is_order_dependent) + return Expr( + obj._to_compliant_expr, + is_order_dependent=obj._is_order_dependent, + changes_length=obj._changes_length, + aggregates=obj._aggregates, + ) return obj @@ -1954,7 +1971,12 @@ def then(self, value: Any) -> Then: class Then(NwThen, Expr): @classmethod def from_then(cls, then: NwThen) -> Self: - return cls(then._to_compliant_expr, is_order_dependent=then._is_order_dependent) + return cls( + then._to_compliant_expr, + is_order_dependent=then._is_order_dependent, + changes_length=then._changes_length, + aggregates=then._aggregates, + ) def otherwise(self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 8df16e88b..e0ebf97a9 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -161,7 +161,9 @@ # Expr methods expr_methods = [ i - for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__() + for i in nw.Expr( + lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False + ).__dir__() if not i[0].isupper() and i[0] != "_" ] with open("docs/api-reference/expr.md") as fd: @@ -185,7 +187,13 @@ expr_methods = [ i for i in getattr( - nw.Expr(lambda: 0, is_order_dependent=False), namespace + nw.Expr( + lambda: 0, + is_order_dependent=False, + changes_length=False, + aggregates=False, + ), + namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" ] @@ -228,7 +236,9 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr(lambda: 0, is_order_dependent=False).__dir__() + for i in nw.Expr( + lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False + ).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -250,7 +260,13 @@ expr_internal = [ i for i in getattr( - nw.Expr(lambda: 0, is_order_dependent=False), namespace + nw.Expr( + lambda: 0, + is_order_dependent=False, + changes_length=False, + aggregates=False, + ), + namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" ] From bba0c775e2e2d4f6e61d864657abe2a1d7e31e1d Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 13:15:07 +0400 Subject: [PATCH 02/10] fixup --- narwhals/_expression_parsing.py | 11 ++++++----- narwhals/dataframe.py | 11 +++++++++++ narwhals/exceptions.py | 8 ++++++++ narwhals/expr.py | 8 ++++++++ narwhals/stable/v1/__init__.py | 4 ++-- 5 files changed, 35 insertions(+), 7 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index fed77fe44..b536dfe19 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -363,17 +363,18 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: """ from narwhals.expr import Expr - n_length_changing_expressions = len( - [x for x in args if isinstance(x, Expr) and x._changes_length] + n_exprs = len([x for x in args if isinstance(x, Expr)]) + changes_length = any( + isinstance(x, Expr) and x._changes_length and not x._aggregates for x in args ) - if n_length_changing_expressions > 1: + if n_exprs > 1 and changes_length: msg = ( - "Found multiple expressions which change length. You can only use one " + "Found multiple expressions at least one of which changes length. You can only use one " "length-changing expression at a time, unless it is followed by an aggregation." ) # TODO(marco): custom error class raise ValueError(msg) - return n_length_changing_expressions > 0 + return changes_length def operation_aggregates(*args: IntoExpr | Any) -> bool: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index b055fed12..1633a3c22 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -16,6 +16,7 @@ from narwhals.dependencies import get_polars from narwhals.dependencies import is_numpy_array +from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import OrderDependentExprError from narwhals.schema import Schema from narwhals.translate import to_native @@ -3648,6 +3649,16 @@ def _extract_compliant(self, arg: Any) -> Any: " they will be supported." ) raise OrderDependentExprError(msg) + if arg._changes_length: + msg = ( + "Length-changing expressions are not supported for use in LazyFrame, unless\n" + "followed by an aggregation.\n\n" + "Hints:\n" + "- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n" + "- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n" + " use `lf.select(nw.col('a').drop_nulls().sum())\n" + ) + raise LengthChangingExprError(msg) return arg._to_compliant_expr(self.__narwhals_namespace__()) if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover msg = ( diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 9b05e3ba8..6a553fa44 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -91,6 +91,14 @@ def __init__(self, message: str) -> None: super().__init__(self.message) +class LengthChangingExprError(ValueError): + """Exception raised when trying to use an expression which changes length with LazyFrames.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + class UnsupportedDTypeError(ValueError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" diff --git a/narwhals/expr.py b/narwhals/expr.py index e6d34f73c..6904d4b6f 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -45,6 +45,14 @@ def __init__( self._changes_length = changes_length self._aggregates = aggregates + def __repr__(self) -> str: + return ( + "Narwhals Expr\n" + f"is_order_dependent: {self._is_order_dependent}\n" + f"changes_length: {self._changes_length}\n" + f"aggregates: {self._aggregates}" + ) + def _taxicab_norm(self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index b37171380..896f1303b 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -240,7 +240,7 @@ def _dataframe(self) -> type[DataFrame[Any]]: return DataFrame def _extract_compliant(self, arg: Any) -> Any: - # After v1, we raise when passing order-dependent + # After v1, we raise when passing order-dependent or length-changing # expressions to LazyFrame from narwhals.dataframe import BaseFrame from narwhals.expr import Expr @@ -252,7 +252,7 @@ def _extract_compliant(self, arg: Any) -> Any: msg = "Mixing Series with LazyFrame is not supported." raise TypeError(msg) if isinstance(arg, Expr): - # After stable.v1, we raise if arg._is_order_dependent + # After stable.v1, we raise if arg._is_order_dependent or arg._changes_length return arg._to_compliant_expr(self.__narwhals_namespace__()) if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover msg = ( From 774f8f0f8e8157e59a3d645d92fa7388cc260001 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 13:26:29 +0400 Subject: [PATCH 03/10] dask Expr.unique --- narwhals/_dask/expr.py | 56 ++++++------------------ tests/expr_and_series/drop_nulls_test.py | 15 +++++++ tests/expr_and_series/unique_test.py | 8 ++++ 3 files changed, 36 insertions(+), 43 deletions(-) diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 09a3adf8f..1319f1001 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -4,7 +4,6 @@ from typing import Any from typing import Callable from typing import Literal -from typing import NoReturn from typing import Sequence from narwhals._dask.expr_dt import DaskExprDateTimeNamespace @@ -448,34 +447,20 @@ def round(self, decimals: int) -> Self: returns_scalar=self._returns_scalar, ) - def ewm_mean( - self: Self, - *, - com: float | None = None, - span: float | None = None, - half_life: float | None = None, - alpha: float | None = None, - adjust: bool = True, - min_periods: int = 1, - ignore_nulls: bool = False, - ) -> NoReturn: - msg = "`Expr.ewm_mean` is not supported for the Dask backend" - raise NotImplementedError(msg) - - def unique(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead." - raise NotImplementedError(msg) - - def drop_nulls(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead." - raise NotImplementedError(msg) + def unique(self, *, maintain_order: bool) -> Self: + # TODO(marco): maintain_order has no effect and will be deprecated + return self._from_call( + lambda _input: _input.unique(), + "unique", + returns_scalar=self._returns_scalar, + ) - def head(self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead." - raise NotImplementedError(msg) + def drop_nulls(self) -> Self: + return self._from_call( + lambda _input: _input.dropna(), + "drop_nulls", + returns_scalar=self._returns_scalar, + ) def replace_strict( self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None @@ -483,11 +468,6 @@ def replace_strict( msg = "`replace_strict` is not yet supported for Dask expressions" raise NotImplementedError(msg) - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead." - raise NotImplementedError(msg) - def abs(self) -> Self: return self._from_call( lambda _input: _input.abs(), "abs", returns_scalar=self._returns_scalar @@ -678,16 +658,6 @@ def null_count(self: Self) -> Self: returns_scalar=True, ) - def tail(self: Self) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.tail` is not supported for the Dask backend. Please use `LazyFrame.tail` instead." - raise NotImplementedError(msg) - - def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn: - # We can't (yet?) allow methods which modify the index - msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead." - raise NotImplementedError(msg) - def over(self: Self, keys: list[str]) -> Self: def func(df: DaskLazyFrame) -> list[Any]: if self._output_names is None: diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index 0584674e6..25fecd297 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -30,6 +31,20 @@ def test_drop_nulls(constructor_eager: ConstructorEager) -> None: assert_equal_data(result_d, expected_d) +def test_drop_nulls_agg(constructor: Constructor) -> None: + data = { + "A": [1, 2, None, 4], + "B": [5, 6, 7, 8], + "C": [None, None, None, None], + "D": [9, 10, 11, 12], + } + + df = nw.from_native(constructor(data)) + result = df.select(nw.all().drop_nulls().len()) + expected = {"A": [3], "B": [4], "C": [0], "D": [4]} + assert_equal_data(result, expected) + + def test_drop_nulls_series(constructor_eager: ConstructorEager) -> None: data = { "A": [1, 2, None, 4], diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index 989a577c2..b4ca2e9e7 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import narwhals.stable.v1 as nw +from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -15,6 +16,13 @@ def test_unique_expr(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) +def test_unique_expr_agg(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + result = df.select(nw.col("a").unique().sum()) + expected = {"a": [3]} + assert_equal_data(result, expected) + + def test_unique_series(constructor_eager: ConstructorEager) -> None: series = nw.from_native(constructor_eager(data_str), eager_only=True)["a"] result = series.unique(maintain_order=True) From d66577f24e2810068f3b0459439efeecfe148f70 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 13:30:45 +0400 Subject: [PATCH 04/10] coverage --- tests/expr_and_series/unique_test.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index b4ca2e9e7..47fca2b9a 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -1,6 +1,11 @@ from __future__ import annotations -import narwhals.stable.v1 as nw +from contextlib import nullcontext as does_not_raise + +import pytest + +import narwhals as nw +from narwhals.exceptions import LengthChangingExprError from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -9,11 +14,17 @@ data_str = {"a": ["x", "x", "y"]} -def test_unique_expr(constructor_eager: ConstructorEager) -> None: - df = nw.from_native(constructor_eager(data)) - result = df.select(nw.col("a").unique()) - expected = {"a": [1, 2]} - assert_equal_data(result, expected) +def test_unique_expr(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + context = ( + pytest.raises(LengthChangingExprError) + if isinstance(df, nw.LazyFrame) + else does_not_raise() + ) + with context: + result = df.select(nw.col("a").unique()) + expected = {"a": [1, 2]} + assert_equal_data(result, expected) def test_unique_expr_agg(constructor: Constructor) -> None: From 98b325cdd839b55c103ddeb40f11801bfd07d664 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 19:24:54 +0400 Subject: [PATCH 05/10] xfail --- tests/expr_and_series/drop_nulls_test.py | 6 +++++- tests/expr_and_series/unique_test.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/expr_and_series/drop_nulls_test.py b/tests/expr_and_series/drop_nulls_test.py index 25fecd297..b4e1f7c6b 100644 --- a/tests/expr_and_series/drop_nulls_test.py +++ b/tests/expr_and_series/drop_nulls_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -31,7 +33,9 @@ def test_drop_nulls(constructor_eager: ConstructorEager) -> None: assert_equal_data(result_d, expected_d) -def test_drop_nulls_agg(constructor: Constructor) -> None: +def test_drop_nulls_agg(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) data = { "A": [1, 2, None, 4], "B": [5, 6, 7, 8], diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index 47fca2b9a..41300997b 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -27,7 +27,11 @@ def test_unique_expr(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_unique_expr_agg(constructor: Constructor) -> None: +def test_unique_expr_agg( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if any(x in str(constructor) for x in ("duckdb", "pyspark")): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.col("a").unique().sum()) expected = {"a": [3]} From 27c792ce7772d055fb76eaf2fb6999f362ea0740 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 19:32:51 +0400 Subject: [PATCH 06/10] coverage --- narwhals/_expression_parsing.py | 4 ++-- tests/expr_and_series/unique_test.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index b536dfe19..819a4f8bc 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -14,6 +14,7 @@ from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError +from narwhals.exceptions import LengthChangingExprError from narwhals.utils import Implementation if TYPE_CHECKING: @@ -372,8 +373,7 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: "Found multiple expressions at least one of which changes length. You can only use one " "length-changing expression at a time, unless it is followed by an aggregation." ) - # TODO(marco): custom error class - raise ValueError(msg) + raise LengthChangingExprError(msg) return changes_length diff --git a/tests/expr_and_series/unique_test.py b/tests/expr_and_series/unique_test.py index 41300997b..b2877c646 100644 --- a/tests/expr_and_series/unique_test.py +++ b/tests/expr_and_series/unique_test.py @@ -38,6 +38,12 @@ def test_unique_expr_agg( assert_equal_data(result, expected) +def test_unique_illegal_combination(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + with pytest.raises(LengthChangingExprError): + df.select((nw.col("a").unique() + nw.col("b").unique()).sum()) + + def test_unique_series(constructor_eager: ConstructorEager) -> None: series = nw.from_native(constructor_eager(data_str), eager_only=True)["a"] result = series.unique(maintain_order=True) From 11e219d92c0fa7238ad033709ab05dfee52d8fbf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 19 Jan 2025 20:48:32 +0400 Subject: [PATCH 07/10] fix get_categories and improve message --- narwhals/_expression_parsing.py | 5 +++-- narwhals/expr_cat.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 819a4f8bc..28e51781f 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -370,8 +370,9 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: ) if n_exprs > 1 and changes_length: msg = ( - "Found multiple expressions at least one of which changes length. You can only use one " - "length-changing expression at a time, unless it is followed by an aggregation." + "Found multiple expressions at least one of which changes length.\n" + "Any length-changing expression can only be used in isolation, unless\n" + "it is followed by an aggregation." ) raise LengthChangingExprError(msg) return changes_length diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 80b047a2e..16dbb3929 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -64,6 +64,6 @@ def get_categories(self: Self) -> ExprT: return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories(), self._expr._is_order_dependent, - changes_length=self._expr._changes_length, + changes_length=True, aggregates=self._expr._aggregates, ) From 2537d9859e45cb75c4a26d204bd5892a5147fc44 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 20 Jan 2025 13:26:03 +0400 Subject: [PATCH 08/10] simplify --- narwhals/_expression_parsing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 28e51781f..a11db68eb 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -365,9 +365,7 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: from narwhals.expr import Expr n_exprs = len([x for x in args if isinstance(x, Expr)]) - changes_length = any( - isinstance(x, Expr) and x._changes_length and not x._aggregates for x in args - ) + changes_length = any(isinstance(x, Expr) and x._changes_length for x in args) if n_exprs > 1 and changes_length: msg = ( "Found multiple expressions at least one of which changes length.\n" From 28f32568cecf813fb808904dc7417048f65ffe8a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 20 Jan 2025 13:34:13 +0400 Subject: [PATCH 09/10] fixup --- narwhals/stable/v1/__init__.py | 5 ++++- tests/frame/filter_test.py | 16 ++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index bf579b308..6e02b08ff 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -937,7 +937,10 @@ def arg_true(self) -> Self: A new expression. """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).arg_true(), is_order_dependent=True + lambda plx: self._to_compliant_expr(plx).arg_true(), + is_order_dependent=True, + changes_length=True, + aggregates=self._aggregates, ) def sample( diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index 1b691c5c1..b785e664b 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -3,9 +3,9 @@ from contextlib import nullcontext as does_not_raise import pytest -from polars.exceptions import ShapeError as PlShapeError import narwhals as nw +from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import ShapeError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -54,17 +54,5 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None: def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} df = nw.from_native(constructor(data)) - - context = ( - pytest.raises( - ShapeError, match="filter's length: 2 differs from that of the series: 3" - ) - if any(x in str(constructor) for x in ("pandas", "pyarrow", "modin")) - else pytest.raises( - PlShapeError, match="filter's length: 2 differs from that of the series: 3" - ) - if "polars" in str(constructor) - else pytest.raises(Exception) # type: ignore[arg-type] # noqa: PT011 - ) - with context: + with pytest.raises(LengthChangingExprError): df.filter(nw.col("b").unique() > 2).lazy().collect() From ba2ca28462b4460abfea756c3fa0840fb302e0e8 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 20 Jan 2025 14:11:30 +0400 Subject: [PATCH 10/10] actually fix --- tests/frame/filter_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/frame/filter_test.py b/tests/frame/filter_test.py index b785e664b..96aa2b44e 100644 --- a/tests/frame/filter_test.py +++ b/tests/frame/filter_test.py @@ -3,6 +3,7 @@ from contextlib import nullcontext as does_not_raise import pytest +from polars.exceptions import ShapeError as PlShapeError import narwhals as nw from narwhals.exceptions import LengthChangingExprError @@ -54,5 +55,5 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None: def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None: data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} df = nw.from_native(constructor(data)) - with pytest.raises(LengthChangingExprError): + with pytest.raises((LengthChangingExprError, ShapeError, PlShapeError)): df.filter(nw.col("b").unique() > 2).lazy().collect()