Skip to content

Commit

Permalink
feat: support when-then-otherwise for Dask (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 26, 2024
1 parent cad2284 commit 9fdef26
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 36 deletions.
107 changes: 107 additions & 0 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
from typing import Any
from typing import Callable
from typing import NoReturn
from typing import cast

from narwhals import dtypes
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import validate_comparand
from narwhals._expression_parsing import parse_into_exprs

if TYPE_CHECKING:
import dask_expr

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.typing import IntoDaskExpr

Expand Down Expand Up @@ -171,3 +175,106 @@ def _create_expr_from_callable( # pragma: no cover
"`_create_expr_from_callable` for DaskNamespace exists only for compatibility"
)
raise NotImplementedError(msg)

def when(
self,
*predicates: IntoDaskExpr,
) -> DaskWhen:
plx = self.__class__(backend_version=self._backend_version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

return DaskWhen(condition, self._backend_version, returns_scalar=False)


class DaskWhen:
def __init__(
self,
condition: DaskExpr,
backend_version: tuple[int, ...],
then_value: Any = None,
otherwise_value: Any = None,
*,
returns_scalar: bool,
) -> None:
self._backend_version = backend_version
self._condition = condition
self._then_value = then_value
self._otherwise_value = otherwise_value
self._returns_scalar = returns_scalar

def __call__(self, df: DaskLazyFrame) -> list[Any]:
from narwhals._dask.namespace import DaskNamespace
from narwhals._expression_parsing import parse_into_expr

plx = DaskNamespace(backend_version=self._backend_version)

condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] # type: ignore[arg-type]
condition = cast("dask_expr.Series", condition)
try:
value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
_df = condition.to_frame("a")
_df["tmp"] = self._then_value
value_series = _df["tmp"]
value_series = cast("dask_expr.Series", value_series)
validate_comparand(condition, value_series)

if self._otherwise_value is None:
return [value_series.where(condition)]
try:
otherwise_series = parse_into_expr(
self._otherwise_value, namespace=plx
)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [value_series.where(condition, self._otherwise_value)]
validate_comparand(condition, otherwise_series)
return [value_series.zip_with(condition, otherwise_series)]

def then(self, value: DaskExpr | Any) -> DaskThen:
self._then_value = value

return DaskThen(
self,
depth=0,
function_name="whenthen",
root_names=None,
output_names=None,
returns_scalar=self._returns_scalar,
backend_version=self._backend_version,
)


class DaskThen(DaskExpr):
def __init__(
self,
call: DaskWhen,
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
returns_scalar: bool,
backend_version: tuple[int, ...],
) -> None:
self._backend_version = backend_version

self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._output_names = output_names
self._returns_scalar = returns_scalar

def otherwise(self, value: DaskExpr | Any) -> DaskExpr:
# type ignore because we are setting the `_call` attribute to a
# callable object of type `DaskWhen`, base class has the attribute as
# only a `Callable`
self._call._otherwise_value = value # type: ignore[attr-defined]
self._function_name = "whenotherwise"
return self
37 changes: 21 additions & 16 deletions narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from typing import TYPE_CHECKING
from typing import Any

from narwhals.dependencies import get_dask_expr
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import isinstance_or_issubclass
from narwhals.utils import parse_version

if TYPE_CHECKING:
import dask_expr

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals.dtypes import DType

Expand All @@ -23,21 +24,7 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions not supported in this context"
raise NotImplementedError(msg)
result = results[0]
if not get_dask_expr()._expr.are_co_aligned(
df._native_frame._expr, result._expr
): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
# If someone only operates on a Dask DataFrame via expressions, then this
# should always be the case: expression outputs (by definition) all come from the
# same input dataframe, and Dask Series does not have any operations which
# change the index. Nonetheless, we perform this safety check anyway.

# However, we still need to carefully vet which methods we support for Dask, to
# avoid issues where `are_co_aligned` doesn't do what we want it to do:
# https://github.com/dask/dask-expr/issues/1112.
msg = "Implicit index alignment is not support for Dask DataFrame in Narwhals"
raise NotImplementedError(msg)
validate_comparand(df._native_frame, result)
if obj._returns_scalar:
# Return scalar, let Dask do its broadcasting
return result[0]
Expand Down Expand Up @@ -80,6 +67,24 @@ def add_row_index(frame: Any, name: str) -> Any:
return frame.assign(**{name: frame[name].cumsum(method="blelloch") - 1})


def validate_comparand(lhs: dask_expr.Series, rhs: dask_expr.Series) -> None:
import dask_expr # ignore-banned-import

if not dask_expr._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
# If someone only operates on a Dask DataFrame via expressions, then this
# should always be the case: expression outputs (by definition) all come from the
# same input dataframe, and Dask Series does not have any operations which
# change the index. Nonetheless, we perform this safety check anyway.

# However, we still need to carefully vet which methods we support for Dask, to
# avoid issues where `are_co_aligned` doesn't do what we want it to do:
# https://github.com/dask/dask-expr/issues/1112.
msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
raise RuntimeError(msg)


def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
from narwhals import dtypes

Expand Down
25 changes: 5 additions & 20 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
}


def test_when(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_when(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when"))
expected = {
Expand All @@ -29,10 +26,7 @@ def test_when(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


def test_when_otherwise(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_when_otherwise(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when"))
expected = {
Expand All @@ -41,10 +35,7 @@ def test_when_otherwise(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


def test_multiple_conditions(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_multiple_conditions(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select(
nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when")
Expand All @@ -55,10 +46,7 @@ def test_multiple_conditions(request: Any, constructor: Any) -> None:
compare_dicts(result, expected)


def test_no_arg_when_fail(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_no_arg_when_fail(constructor: Any) -> None:
df = nw.from_native(constructor(data))
with pytest.raises((TypeError, ValueError)):
df.select(nw.when().then(value=3).alias("a_when"))
Expand Down Expand Up @@ -92,10 +80,7 @@ def test_value_series(constructor_eager: Any) -> None:
compare_dicts(result, expected)


def test_value_expression(request: Any, constructor: Any) -> None:
if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_value_expression(constructor: Any) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when"))
expected = {
Expand Down

0 comments on commit 9fdef26

Please sign in to comment.