Skip to content

Commit 5b07223

Browse files
committed
dask
1 parent 1eeaa1c commit 5b07223

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

narwhals/_dask/group_by.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,27 @@
1010
from narwhals.utils import remove_prefix
1111

1212
if TYPE_CHECKING:
13+
import dask.dataframe as dd
14+
1315
from narwhals._dask.dataframe import DaskLazyFrame
1416
from narwhals._dask.expr import DaskExpr
1517
from narwhals._dask.typing import IntoDaskExpr
1618

17-
POLARS_TO_PANDAS_AGGREGATIONS = {
19+
20+
def n_unique() -> dd.Aggregation:
    """Build the custom Dask aggregation used to count distinct values.

    The aggregation is assembled from three stages:

    * ``chunk``: reduce each group's values to a list of distinct items.
    * ``agg``: regroup the chunk results on every index level and sum
      them, which joins the per-chunk lists (list ``+`` concatenates).
    * ``finalize``: de-duplicate the joined list and take its length.

    Returns:
        A ``dd.Aggregation`` registered under the name ``"nunique"``.
    """
    # Imported lazily inside the function, as in the rest of this module.
    import dask.dataframe as dd  # ignore-banned-import

    def _distinct_in_chunk(grouped):
        # presumably a grouped Series for one partition -- TODO confirm
        return grouped.apply(lambda values: list(set(values)))

    def _combine_chunks(grouped):
        # Regroup on *all* index levels so multi-key group-bys are handled.
        every_level = list(range(grouped.obj.index.nlevels))
        return grouped.obj.groupby(level=every_level).sum()

    def _count_distinct(combined):
        # The same value may appear in several chunk lists, so de-duplicate
        # once more before counting.
        # NOTE(review): set-based de-duplication may count NaN more than
        # once (NaN != NaN) -- confirm against the expected null semantics.
        return combined.apply(lambda items: len(set(items)))

    return dd.Aggregation(
        name="nunique",
        chunk=_distinct_in_chunk,
        agg=_combine_chunks,
        finalize=_count_distinct,
    )
29+
30+
31+
POLARS_TO_DASK_AGGREGATIONS = {  # expression function name -> aggregation handed to the group-by; names not listed are used as-is via .get(name, name)
1832
"len": "size",
19-
"n_unique": "nunique",
33+
"n_unique": n_unique(),  # evaluated once when this module is imported; the call itself imports dask.dataframe
2034
}
2135

2236

@@ -94,7 +108,7 @@ def agg_dask(
94108
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
95109
raise AssertionError(msg)
96110

97-
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
111+
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
98112
expr._function_name, expr._function_name
99113
)
100114
for output_name in expr._output_names:
@@ -109,9 +123,7 @@ def agg_dask(
109123
raise AssertionError(msg)
110124

111125
function_name = remove_prefix(expr._function_name, "col->")
112-
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
113-
function_name, function_name
114-
)
126+
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
115127
for root_name, output_name in zip(expr._root_names, expr._output_names):
116128
simple_aggregations[output_name] = (root_name, function_name)
117129
try:

0 commit comments

Comments
 (0)