23
23
from narwhals .utils import import_dtypes_module
24
24
25
25
if TYPE_CHECKING :
26
- import dask_expr
26
+ try :
27
+ import dask .dataframe .dask_expr as dx
28
+ except ModuleNotFoundError :
29
+ import dask_expr as dx
30
+
27
31
from typing_extensions import Self
28
32
29
33
from narwhals ._dask .dataframe import DaskLazyFrame
32
36
from narwhals .utils import Version
33
37
34
38
35
- class DaskExpr (CompliantExpr ["dask_expr .Series" ]):
39
+ class DaskExpr (CompliantExpr ["dx .Series" ]):
36
40
_implementation : Implementation = Implementation .DASK
37
41
38
42
def __init__ (
39
43
self ,
40
- call : Callable [[DaskLazyFrame ], Sequence [dask_expr .Series ]],
44
+ call : Callable [[DaskLazyFrame ], Sequence [dx .Series ]],
41
45
* ,
42
46
depth : int ,
43
47
function_name : str ,
@@ -60,7 +64,7 @@ def __init__(
60
64
self ._version = version
61
65
self ._kwargs = kwargs
62
66
63
- def __call__ (self , df : DaskLazyFrame ) -> Sequence [dask_expr .Series ]:
67
+ def __call__ (self , df : DaskLazyFrame ) -> Sequence [dx .Series ]:
64
68
return self ._call (df )
65
69
66
70
def __narwhals_expr__ (self ) -> None : ...
@@ -78,7 +82,7 @@ def from_column_names(
78
82
backend_version : tuple [int , ...],
79
83
version : Version ,
80
84
) -> Self :
81
- def func (df : DaskLazyFrame ) -> list [dask_expr .Series ]:
85
+ def func (df : DaskLazyFrame ) -> list [dx .Series ]:
82
86
try :
83
87
return [df ._native_frame [column_name ] for column_name in column_names ]
84
88
except KeyError as e :
@@ -107,7 +111,7 @@ def from_column_indices(
107
111
backend_version : tuple [int , ...],
108
112
version : Version ,
109
113
) -> Self :
110
- def func (df : DaskLazyFrame ) -> list [dask_expr .Series ]:
114
+ def func (df : DaskLazyFrame ) -> list [dx .Series ]:
111
115
return [
112
116
df ._native_frame .iloc [:, column_index ] for column_index in column_indices
113
117
]
@@ -126,14 +130,14 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
126
130
127
131
def _from_call (
128
132
self ,
129
- # First argument to `call` should be `dask_expr .Series`
130
- call : Callable [..., dask_expr .Series ],
133
+ # First argument to `call` should be `dx .Series`
134
+ call : Callable [..., dx .Series ],
131
135
expr_name : str ,
132
136
* ,
133
137
returns_scalar : bool ,
134
138
** kwargs : Any ,
135
139
) -> Self :
136
- def func (df : DaskLazyFrame ) -> list [dask_expr .Series ]:
140
+ def func (df : DaskLazyFrame ) -> list [dx .Series ]:
137
141
results = []
138
142
inputs = self ._call (df )
139
143
_kwargs = {key : maybe_evaluate (df , value ) for key , value in kwargs .items ()}
@@ -163,7 +167,7 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
163
167
)
164
168
165
169
def alias (self , name : str ) -> Self :
166
- def func (df : DaskLazyFrame ) -> list [dask_expr .Series ]:
170
+ def func (df : DaskLazyFrame ) -> list [dx .Series ]:
167
171
inputs = self ._call (df )
168
172
return [_input .rename (name ) for _input in inputs ]
169
173
@@ -312,7 +316,7 @@ def mean(self) -> Self:
312
316
def median (self ) -> Self :
313
317
from narwhals .exceptions import InvalidOperationError
314
318
315
- def func (s : dask_expr .Series ) -> dask_expr .Series :
319
+ def func (s : dx .Series ) -> dx .Series :
316
320
dtype = native_to_narwhals_dtype (s , self ._version , Implementation .DASK )
317
321
if not dtype .is_numeric ():
318
322
msg = "`median` operation not supported for non-numeric input type."
@@ -511,11 +515,11 @@ def fill_null(
511
515
limit : int | None = None ,
512
516
) -> DaskExpr :
513
517
def func (
514
- _input : dask_expr .Series ,
518
+ _input : dx .Series ,
515
519
value : Any | None ,
516
520
strategy : str | None ,
517
521
limit : int | None ,
518
- ) -> dask_expr .Series :
522
+ ) -> dx .Series :
519
523
if value is not None :
520
524
res_ser = _input .fillna (value )
521
525
else :
@@ -566,7 +570,7 @@ def is_null(self: Self) -> Self:
566
570
)
567
571
568
572
def is_nan (self : Self ) -> Self :
569
- def func (_input : dask_expr .Series ) -> dask_expr .Series :
573
+ def func (_input : dx .Series ) -> dx .Series :
570
574
dtype = native_to_narwhals_dtype (_input , self ._version , self ._implementation )
571
575
if dtype .is_numeric ():
572
576
return _input != _input # noqa: PLR0124
@@ -585,7 +589,7 @@ def quantile(
585
589
) -> Self :
586
590
if interpolation == "linear" :
587
591
588
- def func (_input : dask_expr .Series , quantile : float ) -> dask_expr .Series :
592
+ def func (_input : dx .Series , quantile : float ) -> dx .Series :
589
593
if _input .npartitions > 1 :
590
594
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
591
595
raise NotImplementedError (msg )
@@ -599,7 +603,7 @@ def func(_input: dask_expr.Series, quantile: float) -> dask_expr.Series:
599
603
raise NotImplementedError (msg )
600
604
601
605
def is_first_distinct (self : Self ) -> Self :
602
- def func (_input : dask_expr .Series ) -> dask_expr .Series :
606
+ def func (_input : dx .Series ) -> dx .Series :
603
607
_name = _input .name
604
608
col_token = generate_temporary_column_name (n_bytes = 8 , columns = [_name ])
605
609
_input = add_row_index (
@@ -618,7 +622,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
618
622
)
619
623
620
624
def is_last_distinct (self : Self ) -> Self :
621
- def func (_input : dask_expr .Series ) -> dask_expr .Series :
625
+ def func (_input : dx .Series ) -> dx .Series :
622
626
_name = _input .name
623
627
col_token = generate_temporary_column_name (n_bytes = 8 , columns = [_name ])
624
628
_input = add_row_index (
@@ -635,7 +639,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
635
639
)
636
640
637
641
def is_duplicated (self : Self ) -> Self :
638
- def func (_input : dask_expr .Series ) -> dask_expr .Series :
642
+ def func (_input : dx .Series ) -> dx .Series :
639
643
_name = _input .name
640
644
return (
641
645
_input .to_frame ()
@@ -647,7 +651,7 @@ def func(_input: dask_expr.Series) -> dask_expr.Series:
647
651
return self ._from_call (func , "is_duplicated" , returns_scalar = self ._returns_scalar )
648
652
649
653
def is_unique (self : Self ) -> Self :
650
- def func (_input : dask_expr .Series ) -> dask_expr .Series :
654
+ def func (_input : dx .Series ) -> dx .Series :
651
655
_name = _input .name
652
656
return (
653
657
_input .to_frame ()
@@ -967,7 +971,7 @@ def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
967
971
)
968
972
969
973
def convert_time_zone (self , time_zone : str ) -> DaskExpr :
970
- def func (s : dask_expr .Series , time_zone : str ) -> dask_expr .Series :
974
+ def func (s : dx .Series , time_zone : str ) -> dx .Series :
971
975
dtype = native_to_narwhals_dtype (
972
976
s , self ._compliant_expr ._version , Implementation .DASK
973
977
)
@@ -984,9 +988,7 @@ def func(s: dask_expr.Series, time_zone: str) -> dask_expr.Series:
984
988
)
985
989
986
990
def timestamp (self , time_unit : Literal ["ns" , "us" , "ms" ] = "us" ) -> DaskExpr :
987
- def func (
988
- s : dask_expr .Series , time_unit : Literal ["ns" , "us" , "ms" ] = "us"
989
- ) -> dask_expr .Series :
991
+ def func (s : dx .Series , time_unit : Literal ["ns" , "us" , "ms" ] = "us" ) -> dx .Series :
990
992
dtype = native_to_narwhals_dtype (
991
993
s , self ._compliant_expr ._version , Implementation .DASK
992
994
)
0 commit comments