Skip to content

Commit c4c31cb

Browse files
committed
chore: dask nightly
1 parent 5dca2a9 commit c4c31cb

File tree

6 files changed

+81
-52
lines changed

6 files changed

+81
-52
lines changed

narwhals/_dask/expr.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
from narwhals.utils import import_dtypes_module
2424

2525
if TYPE_CHECKING:
26-
import dask_expr
26+
try:
27+
import dask.dataframe.dask_expr as dx
28+
except ModuleNotFoundError:
29+
import dask_expr as dx
30+
2731
from typing_extensions import Self
2832

2933
from narwhals._dask.dataframe import DaskLazyFrame
@@ -32,12 +36,12 @@
3236
from narwhals.utils import Version
3337

3438

35-
class DaskExpr(CompliantExpr["dask_expr.Series"]):
39+
class DaskExpr(CompliantExpr["dx.Series"]):
3640
_implementation: Implementation = Implementation.DASK
3741

3842
def __init__(
3943
self,
40-
call: Callable[[DaskLazyFrame], Sequence[dask_expr.Series]],
44+
call: Callable[[DaskLazyFrame], Sequence[dx.Series]],
4145
*,
4246
depth: int,
4347
function_name: str,
@@ -60,7 +64,7 @@ def __init__(
6064
self._version = version
6165
self._kwargs = kwargs
6266

63-
def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
67+
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
6468
return self._call(df)
6569

6670
def __narwhals_expr__(self) -> None: ...
@@ -78,7 +82,7 @@ def from_column_names(
7882
backend_version: tuple[int, ...],
7983
version: Version,
8084
) -> Self:
81-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
85+
def func(df: DaskLazyFrame) -> list[dx.Series]:
8286
try:
8387
return [df._native_frame[column_name] for column_name in column_names]
8488
except KeyError as e:
@@ -107,7 +111,7 @@ def from_column_indices(
107111
backend_version: tuple[int, ...],
108112
version: Version,
109113
) -> Self:
110-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
114+
def func(df: DaskLazyFrame) -> list[dx.Series]:
111115
return [
112116
df._native_frame.iloc[:, column_index] for column_index in column_indices
113117
]
@@ -126,14 +130,14 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
126130

127131
def _from_call(
128132
self,
129-
# First argument to `call` should be `dask_expr.Series`
130-
call: Callable[..., dask_expr.Series],
133+
# First argument to `call` should be `dx.Series`
134+
call: Callable[..., dx.Series],
131135
expr_name: str,
132136
*,
133137
returns_scalar: bool,
134138
**kwargs: Any,
135139
) -> Self:
136-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
140+
def func(df: DaskLazyFrame) -> list[dx.Series]:
137141
results = []
138142
inputs = self._call(df)
139143
_kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()}
@@ -163,7 +167,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
163167
)
164168

165169
def alias(self, name: str) -> Self:
166-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
170+
def func(df: DaskLazyFrame) -> list[dx.Series]:
167171
inputs = self._call(df)
168172
return [_input.rename(name) for _input in inputs]
169173

@@ -312,7 +316,7 @@ def mean(self) -> Self:
312316
def median(self) -> Self:
313317
from narwhals.exceptions import InvalidOperationError
314318

315-
def func(s: dask_expr.Series) -> dask_expr.Series:
319+
def func(s: dx.Series) -> dx.Series:
316320
dtype = native_to_narwhals_dtype(s, self._version, Implementation.DASK)
317321
if not dtype.is_numeric():
318322
msg = "`median` operation not supported for non-numeric input type."
@@ -511,11 +515,11 @@ def fill_null(
511515
limit: int | None = None,
512516
) -> DaskExpr:
513517
def func(
514-
_input: dask_expr.Series,
518+
_input: dx.Series,
515519
value: Any | None,
516520
strategy: str | None,
517521
limit: int | None,
518-
) -> dask_expr.Series:
522+
) -> dx.Series:
519523
if value is not None:
520524
res_ser = _input.fillna(value)
521525
else:
@@ -566,7 +570,7 @@ def is_null(self: Self) -> Self:
566570
)
567571

568572
def is_nan(self: Self) -> Self:
569-
def func(_input: dask_expr.Series) -> dask_expr.Series:
573+
def func(_input: dx.Series) -> dx.Series:
570574
dtype = native_to_narwhals_dtype(_input, self._version, self._implementation)
571575
if dtype.is_numeric():
572576
return _input != _input # noqa: PLR0124
@@ -585,7 +589,7 @@ def quantile(
585589
) -> Self:
586590
if interpolation == "linear":
587591

588-
def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series:
592+
def func(_input: dx.Series, quantile: float) -> dx.Series:
589593
if _input.npartitions > 1:
590594
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
591595
raise NotImplementedError(msg)
@@ -599,7 +603,7 @@ def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series:
599603
raise NotImplementedError(msg)
600604

601605
def is_first_distinct(self: Self) -> Self:
602-
def func(_input: dask_expr.Series) -> dask_expr.Series:
606+
def func(_input: dx.Series) -> dx.Series:
603607
_name = _input.name
604608
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
605609
_input = add_row_index(
@@ -618,7 +622,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
618622
)
619623

620624
def is_last_distinct(self: Self) -> Self:
621-
def func(_input: dask_expr.Series) -> dask_expr.Series:
625+
def func(_input: dx.Series) -> dx.Series:
622626
_name = _input.name
623627
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
624628
_input = add_row_index(
@@ -635,7 +639,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
635639
)
636640

637641
def is_duplicated(self: Self) -> Self:
638-
def func(_input: dask_expr.Series) -> dask_expr.Series:
642+
def func(_input: dx.Series) -> dx.Series:
639643
_name = _input.name
640644
return (
641645
_input.to_frame()
@@ -647,7 +651,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
647651
return self._from_call(func, "is_duplicated", returns_scalar=self._returns_scalar)
648652

649653
def is_unique(self: Self) -> Self:
650-
def func(_input: dask_expr.Series) -> dask_expr.Series:
654+
def func(_input: dx.Series) -> dx.Series:
651655
_name = _input.name
652656
return (
653657
_input.to_frame()
@@ -967,7 +971,7 @@ def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
967971
)
968972

969973
def convert_time_zone(self, time_zone: str) -> DaskExpr:
970-
def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series:
974+
def func(s: dx.Series, time_zone: str) -> dx.Series:
971975
dtype = native_to_narwhals_dtype(
972976
s, self._compliant_expr._version, Implementation.DASK
973977
)
@@ -984,9 +988,7 @@ def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series:
984988
)
985989

986990
def timestamp(self, time_unit: Literal["ns", "us", "ms"] = "us") -> DaskExpr:
987-
def func(
988-
s: dask_expr.Series, time_unit: Literal["ns", "us", "ms"] = "us"
989-
) -> dask_expr.Series:
991+
def func(s: dx.Series, time_unit: Literal["ns", "us", "ms"] = "us") -> dx.Series:
990992
dtype = native_to_narwhals_dtype(
991993
s, self._compliant_expr._version, Implementation.DASK
992994
)

narwhals/_dask/group_by.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
if TYPE_CHECKING:
1414
import dask.dataframe as dd
15-
import dask_expr
15+
16+
try:
17+
import dask.dataframe.dask_expr as dx
18+
except ModuleNotFoundError:
19+
import dask_expr as dx
20+
1621
import pandas as pd
1722

1823
from narwhals._dask.dataframe import DaskLazyFrame
@@ -43,7 +48,10 @@ def var(
4348
]:
4449
from functools import partial
4550

46-
import dask_expr as dx
51+
try:
52+
import dask.dataframe.dask_expr as dx
53+
except ModuleNotFoundError:
54+
import dask_expr as dx
4755

4856
return partial(dx._groupby.GroupBy.var, ddof=ddof)
4957

@@ -55,7 +63,10 @@ def std(
5563
]:
5664
from functools import partial
5765

58-
import dask_expr as dx
66+
try:
67+
import dask.dataframe.dask_expr as dx
68+
except ModuleNotFoundError:
69+
import dask_expr as dx
5970

6071
return partial(dx._groupby.GroupBy.std, ddof=ddof)
6172

@@ -127,7 +138,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:
127138
def agg_dask(
128139
df: DaskLazyFrame,
129140
grouped: Any,
130-
exprs: Sequence[CompliantExpr[dask_expr.Series]],
141+
exprs: Sequence[CompliantExpr[dx.Series]],
131142
keys: list[str],
132143
from_dataframe: Callable[[Any], DaskLazyFrame],
133144
) -> DaskLazyFrame:

narwhals/_dask/namespace.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,17 @@
2121
from narwhals.typing import CompliantNamespace
2222

2323
if TYPE_CHECKING:
24-
import dask_expr
24+
try:
25+
import dask.dataframe.dask_expr as dx
26+
except ModuleNotFoundError:
27+
import dask_expr as dx
2528

2629
from narwhals._dask.typing import IntoDaskExpr
2730
from narwhals.dtypes import DType
2831
from narwhals.utils import Version
2932

3033

31-
class DaskNamespace(CompliantNamespace["dask_expr.Series"]):
34+
class DaskNamespace(CompliantNamespace["dx.Series"]):
3235
@property
3336
def selectors(self) -> DaskSelectorNamespace:
3437
return DaskSelectorNamespace(
@@ -40,7 +43,7 @@ def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> Non
4043
self._version = version
4144

4245
def all(self) -> DaskExpr:
43-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
46+
def func(df: DaskLazyFrame) -> list[dx.Series]:
4447
return [df._native_frame[column_name] for column_name in df.columns]
4548

4649
return DaskExpr(
@@ -69,7 +72,7 @@ def lit(self, value: Any, dtype: DType | None) -> DaskExpr:
6972
import dask.dataframe as dd
7073
import pandas as pd
7174

72-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
75+
def func(df: DaskLazyFrame) -> list[dx.Series]:
7376
return [
7477
dd.from_pandas(
7578
pd.Series(
@@ -99,7 +102,7 @@ def len(self) -> DaskExpr:
99102
import dask.dataframe as dd
100103
import pandas as pd
101104

102-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
105+
def func(df: DaskLazyFrame) -> list[dx.Series]:
103106
if not df.columns:
104107
return [
105108
dd.from_pandas(
@@ -125,7 +128,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
125128
def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
126129
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
127130

128-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
131+
def func(df: DaskLazyFrame) -> list[dx.Series]:
129132
series = [s for _expr in parsed_exprs for s in _expr(df)]
130133
return [reduce(lambda x, y: x & y, series).rename(series[0].name)]
131134

@@ -144,7 +147,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
144147
def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
145148
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
146149

147-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
150+
def func(df: DaskLazyFrame) -> list[dx.Series]:
148151
series = [s for _expr in parsed_exprs for s in _expr(df)]
149152
return [reduce(lambda x, y: x | y, series).rename(series[0].name)]
150153

@@ -163,7 +166,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
163166
def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
164167
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
165168

166-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
169+
def func(df: DaskLazyFrame) -> list[dx.Series]:
167170
series = [s.fillna(0) for _expr in parsed_exprs for s in _expr(df)]
168171
return [reduce(lambda x, y: x + y, series).rename(series[0].name)]
169172

@@ -239,7 +242,7 @@ def concat(
239242
def mean_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
240243
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
241244

242-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
245+
def func(df: DaskLazyFrame) -> list[dx.Series]:
243246
series = (s.fillna(0) for _expr in parsed_exprs for s in _expr(df))
244247
non_na = (1 - s.isna() for _expr in parsed_exprs for s in _expr(df))
245248
return [
@@ -266,7 +269,7 @@ def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
266269

267270
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
268271

269-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
272+
def func(df: DaskLazyFrame) -> list[dx.Series]:
270273
series = [s for _expr in parsed_exprs for s in _expr(df)]
271274

272275
return [dd.concat(series, axis=1).min(axis=1).rename(series[0].name)]
@@ -288,7 +291,7 @@ def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
288291

289292
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
290293

291-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
294+
def func(df: DaskLazyFrame) -> list[dx.Series]:
292295
series = [s for _expr in parsed_exprs for s in _expr(df)]
293296

294297
return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)]
@@ -327,7 +330,7 @@ def concat_str(
327330
*parse_into_exprs(*more_exprs, namespace=self),
328331
]
329332

330-
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
333+
def func(df: DaskLazyFrame) -> list[dx.Series]:
331334
series = (s.astype(str) for _expr in parsed_exprs for s in _expr(df))
332335
null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()(df)]
333336

@@ -389,20 +392,20 @@ def __init__(
389392
self._returns_scalar = returns_scalar
390393
self._version = version
391394

392-
def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
395+
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
393396
from narwhals._expression_parsing import parse_into_expr
394397

395398
plx = df.__narwhals_namespace__()
396399
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
397-
condition = cast("dask_expr.Series", condition)
400+
condition = cast("dx.Series", condition)
398401
try:
399402
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
400403
except TypeError:
401404
# `self._otherwise_value` is a scalar and can't be converted to an expression
402405
_df = condition.to_frame("a")
403406
_df["tmp"] = self._then_value
404407
value_series = _df["tmp"]
405-
value_series = cast("dask_expr.Series", value_series)
408+
value_series = cast("dx.Series", value_series)
406409
validate_comparand(condition, value_series)
407410

408411
if self._otherwise_value is None:

narwhals/_dask/selectors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from narwhals.utils import import_dtypes_module
99

1010
if TYPE_CHECKING:
11-
import dask_expr
11+
try:
12+
import dask.dataframe.dask_expr as dx
13+
except ModuleNotFoundError:
14+
import dask_expr as dx
1215
from typing_extensions import Self
1316

1417
from narwhals._dask.dataframe import DaskLazyFrame
@@ -135,7 +138,7 @@ def call(df: DaskLazyFrame) -> list[Any]:
135138
def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any:
136139
if isinstance(other, DaskSelector):
137140

138-
def call(df: DaskLazyFrame) -> list[dask_expr.Series]:
141+
def call(df: DaskLazyFrame) -> list[dx.Series]:
139142
lhs = self._call(df)
140143
rhs = other._call(df)
141144
return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs]

0 commit comments

Comments
 (0)