-
Notifications
You must be signed in to change notification settings - Fork 121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
patch: group by n_unique
#917
Changes from 2 commits
1eeaa1c
5b07223
055334a
3906c12
b874516
8a320d3
d37d2a3
35a6226
3774fe0
250f949
0c7d071
963b5df
76bffde
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,12 +10,27 @@ | |
from narwhals.utils import remove_prefix | ||
|
||
if TYPE_CHECKING: | ||
import dask.dataframe as dd | ||
|
||
from narwhals._dask.dataframe import DaskLazyFrame | ||
from narwhals._dask.expr import DaskExpr | ||
from narwhals._dask.typing import IntoDaskExpr | ||
|
||
POLARS_TO_PANDAS_AGGREGATIONS = { | ||
|
||
def n_unique() -> dd.Aggregation: | ||
import dask.dataframe as dd # ignore-banned-import | ||
|
||
return dd.Aggregation( | ||
name="nunique", | ||
chunk=lambda s: s.apply(lambda x: list(set(x))), | ||
agg=lambda s0: s0.obj.groupby(level=list(range(s0.obj.index.nlevels))).sum(), | ||
finalize=lambda s1: s1.apply(lambda final: len(set(final))), | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comes from the dask documentation itself. df_dd.groupby("a").b.nunique() but not in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for digging this up in the dask docs! I think this is way more complex than it needs to be - i've pushed something much simpler which seems to work |
||
|
||
|
||
POLARS_TO_DASK_AGGREGATIONS = { | ||
"len": "size", | ||
"n_unique": n_unique(), | ||
} | ||
|
||
|
||
|
@@ -93,7 +108,7 @@ def agg_dask( | |
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" | ||
raise AssertionError(msg) | ||
|
||
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( | ||
function_name = POLARS_TO_DASK_AGGREGATIONS.get( | ||
expr._function_name, expr._function_name | ||
) | ||
for output_name in expr._output_names: | ||
|
@@ -108,9 +123,7 @@ def agg_dask( | |
raise AssertionError(msg) | ||
|
||
function_name = remove_prefix(expr._function_name, "col->") | ||
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( | ||
function_name, function_name | ||
) | ||
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name) | ||
for root_name, output_name in zip(expr._root_names, expr._output_names): | ||
simple_aggregations[output_name] = (root_name, function_name) | ||
try: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
|
||
POLARS_TO_PANDAS_AGGREGATIONS = { | ||
"len": "size", | ||
"n_unique": "nunique", | ||
} | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Found these other two cases while reading pyarrow grouped table aggregations (https://arrow.apache.org/docs/python/compute.html#py-grouped-aggrs)