Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: track whether expressions change length but don't aggregate, and only allow length-changing expressions if they're followed by aggregations in the lazy API #1828

Merged
merged 11 commits into from
Jan 20, 2025
56 changes: 13 additions & 43 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import Sequence

from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
Expand Down Expand Up @@ -448,46 +447,27 @@ def round(self, decimals: int) -> Self:
returns_scalar=self._returns_scalar,
)

def ewm_mean(
self: Self,
*,
com: float | None = None,
span: float | None = None,
half_life: float | None = None,
alpha: float | None = None,
adjust: bool = True,
min_periods: int = 1,
ignore_nulls: bool = False,
) -> NoReturn:
msg = "`Expr.ewm_mean` is not supported for the Dask backend"
raise NotImplementedError(msg)

def unique(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
raise NotImplementedError(msg)

def drop_nulls(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead."
raise NotImplementedError(msg)
def unique(self, *, maintain_order: bool) -> Self:
# TODO(marco): maintain_order has no effect and will be deprecated
return self._from_call(
lambda _input: _input.unique(),
"unique",
returns_scalar=self._returns_scalar,
)

def head(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.head` is not supported for the Dask backend. Please use `LazyFrame.head` instead."
raise NotImplementedError(msg)
def drop_nulls(self) -> Self:
return self._from_call(
lambda _input: _input.dropna(),
"drop_nulls",
returns_scalar=self._returns_scalar,
)

def replace_strict(
self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
) -> Self:
msg = "`replace_strict` is not yet supported for Dask expressions"
raise NotImplementedError(msg)

def sort(self, *, descending: bool = False, nulls_last: bool = False) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.sort` is not supported for the Dask backend. Please use `LazyFrame.sort` instead."
raise NotImplementedError(msg)

def abs(self) -> Self:
return self._from_call(
lambda _input: _input.abs(), "abs", returns_scalar=self._returns_scalar
Expand Down Expand Up @@ -678,16 +658,6 @@ def null_count(self: Self) -> Self:
returns_scalar=True,
)

def tail(self: Self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.tail` is not supported for the Dask backend. Please use `LazyFrame.tail` instead."
raise NotImplementedError(msg)

def gather_every(self: Self, n: int, offset: int = 0) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.gather_every` is not supported for the Dask backend. Please use `LazyFrame.gather_every` instead."
raise NotImplementedError(msg)

def over(self: Self, keys: list[str]) -> Self:
def func(df: DaskLazyFrame) -> list[Any]:
if self._output_names is None:
Expand Down
41 changes: 41 additions & 0 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidIntoExprError
from narwhals.exceptions import LengthChangingExprError
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -342,3 +343,43 @@ def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# it means that it was a scalar (e.g. nw.col('a') + 1) or a column name,
# neither of which is order-dependent, so we default to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)


def operation_changes_length(*args: IntoExpr | Any) -> bool:
"""Track whether operation changes length.

n-ary operations between expressions which change length are not
allowed. This is because the output might be non-relational. For
example:
df = pl.LazyFrame({'a': [1,2,None], 'b': [4,None,6]})
df.select(pl.col('a', 'b').drop_nulls())
Polars does allow this, but in the result we end up with the
tuple (2, 6) which wasn't part of the original data.

Rules are:
- in an n-ary operation, if any one of them changes length, then
it must be the only expression present
- in a comparison between a changes-length expression and a
scalar, the output changes length
"""
from narwhals.expr import Expr

n_exprs = len([x for x in args if isinstance(x, Expr)])
changes_length = any(isinstance(x, Expr) and x._changes_length for x in args)
if n_exprs > 1 and changes_length:
msg = (
"Found multiple expressions at least one of which changes length.\n"
"Any length-changing expression can only be used in isolation, unless\n"
"it is followed by an aggregation."
)
raise LengthChangingExprError(msg)
return changes_length


def operation_aggregates(*args: IntoExpr | Any) -> bool:
# If an arg is an Expr, we look at `_aggregates`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a').sum() + 1),
# which is already length-1, so we default to `True`. If any
# expression does not aggregate, then broadcasting will take
# place and the result will not be an aggregate.
return all(getattr(x, "_aggregates", True) for x in args)
11 changes: 11 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from narwhals.dependencies import get_polars
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import OrderDependentExprError
from narwhals.schema import Schema
from narwhals.translate import to_native
Expand Down Expand Up @@ -3648,6 +3649,16 @@ def _extract_compliant(self, arg: Any) -> Any:
" they will be supported."
)
raise OrderDependentExprError(msg)
if arg._changes_length:
msg = (
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
"followed by an aggregation.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
)
raise LengthChangingExprError(msg)
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover
msg = (
Expand Down
8 changes: 8 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def __init__(self, message: str) -> None:
super().__init__(self.message)


class LengthChangingExprError(ValueError):
"""Exception raised when trying to use an expression which changes length with LazyFrames."""

def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)


class UnsupportedDTypeError(ValueError):
"""Exception raised when trying to convert to a DType which is not supported by the given backend."""

Expand Down
Loading
Loading