10
10
from narwhals .utils import remove_prefix
11
11
12
12
if TYPE_CHECKING :
13
+ import dask .dataframe as dd
14
+
13
15
from narwhals ._dask .dataframe import DaskLazyFrame
14
16
from narwhals ._dask .expr import DaskExpr
15
17
from narwhals ._dask .typing import IntoDaskExpr
16
18
17
- POLARS_TO_PANDAS_AGGREGATIONS = {
19
+
20
+ def n_unique () -> dd .Aggregation :
21
+ import dask .dataframe as dd # ignore-banned-import
22
+
23
+ return dd .Aggregation (
24
+ name = "nunique" ,
25
+ chunk = lambda s : s .apply (lambda x : list (set (x ))),
26
+ agg = lambda s0 : s0 .obj .groupby (level = list (range (s0 .obj .index .nlevels ))).sum (),
27
+ finalize = lambda s1 : s1 .apply (lambda final : len (set (final ))),
28
+ )
29
+
30
+
31
+ POLARS_TO_DASK_AGGREGATIONS = {
18
32
"len" : "size" ,
19
- "n_unique" : "nunique" ,
33
+ "n_unique" : n_unique () ,
20
34
}
21
35
22
36
@@ -94,7 +108,7 @@ def agg_dask(
94
108
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
95
109
raise AssertionError (msg )
96
110
97
- function_name = POLARS_TO_PANDAS_AGGREGATIONS .get (
111
+ function_name = POLARS_TO_DASK_AGGREGATIONS .get (
98
112
expr ._function_name , expr ._function_name
99
113
)
100
114
for output_name in expr ._output_names :
@@ -109,9 +123,7 @@ def agg_dask(
109
123
raise AssertionError (msg )
110
124
111
125
function_name = remove_prefix (expr ._function_name , "col->" )
112
- function_name = POLARS_TO_PANDAS_AGGREGATIONS .get (
113
- function_name , function_name
114
- )
126
+ function_name = POLARS_TO_DASK_AGGREGATIONS .get (function_name , function_name )
115
127
for root_name , output_name in zip (expr ._root_names , expr ._output_names ):
116
128
simple_aggregations [output_name ] = (root_name , function_name )
117
129
try :
0 commit comments