patch: group by n_unique #917

Merged (13 commits) on Sep 6, 2024
7 changes: 7 additions & 0 deletions narwhals/_arrow/group_by.py
@@ -15,6 +15,12 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import IntoArrowExpr

POLARS_TO_ARROW_AGGREGATIONS = {
"n_unique": "count_distinct",
"std": "stddev",
"var": "variance",

Member Author:
Found these other two cases while reading the pyarrow docs on grouped table aggregations (https://arrow.apache.org/docs/python/compute.html#py-grouped-aggrs); see the sketch after this file's diff.

}


class ArrowGroupBy:
def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
@@ -112,6 +118,7 @@ def agg_arrow(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
if function_name != "len":
simple_aggregations[output_name] = (
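
For anyone following along, a minimal sketch (with made-up data, outside narwhals) of the pyarrow aggregation names this mapping targets:

import pyarrow as pa

# Made-up table, purely to illustrate the aggregation names that
# POLARS_TO_ARROW_AGGREGATIONS maps n_unique/std/var onto.
tbl = pa.table({"a": [1, 1, 3, 3], "b": [4, 4, 6, 7]})

result = tbl.group_by("a").aggregate(
    [
        ("b", "count_distinct"),  # Polars n_unique
        ("b", "stddev"),          # Polars std
        ("b", "variance"),        # Polars var
    ]
)
# Produces columns named b_count_distinct, b_stddev and b_variance
# alongside the group key "a".
print(result.to_pydict())
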
23 changes: 18 additions & 5 deletions narwhals/_dask/group_by.py
@@ -10,12 +10,27 @@
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
import dask.dataframe as dd

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.typing import IntoDaskExpr

POLARS_TO_PANDAS_AGGREGATIONS = {

def n_unique() -> dd.Aggregation:
import dask.dataframe as dd # ignore-banned-import

return dd.Aggregation(
name="nunique",
chunk=lambda s: s.apply(lambda x: list(set(x))),
agg=lambda s0: s0.obj.groupby(level=list(range(s0.obj.index.nlevels))).sum(),
finalize=lambda s1: s1.apply(lambda final: len(set(final))),
)

Member Author:
This comes from the dask documentation itself.
For some reason, the nunique keyword only works in

df_dd.groupby("a").b.nunique()

but not in the agg context (see the sketch after this thread).

Member:
Thanks for digging this up in the dask docs!

I think this is way more complex than it needs to be; I've pushed something much simpler which seems to work.
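
To illustrate the distinction above, a small self-contained sketch (with made-up data) of how the chunk/agg/finalize recipe is consumed; it mirrors the diff as shown here, not the simpler version pushed later:

import pandas as pd
import dask.dataframe as dd  # ignore-banned-import

# Same chunk/agg/finalize recipe as in the diff above.
n_unique_agg = dd.Aggregation(
    name="nunique",
    # per partition: collect the distinct values of each group into a list
    chunk=lambda s: s.apply(lambda x: list(set(x))),
    # across partitions: concatenate those lists group-wise
    agg=lambda s0: s0.obj.groupby(level=list(range(s0.obj.index.nlevels))).sum(),
    # finally: count the distinct values in the concatenated lists
    finalize=lambda s1: s1.apply(lambda final: len(set(final))),
)

ddf = dd.from_pandas(pd.DataFrame({"a": [1, 1, 3], "b": [4, 4, 6]}), npartitions=2)

# The plain nunique keyword works on the series-level groupby...
print(ddf.groupby("a").b.nunique().compute())

# ...but, per the comment above, not inside .agg(), where a
# dd.Aggregation object is passed instead:
print(ddf.groupby("a").agg({"b": n_unique_agg}).compute())
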



POLARS_TO_DASK_AGGREGATIONS = {
"len": "size",
"n_unique": n_unique(),
}


@@ -93,7 +108,7 @@ def agg_dask(
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
for output_name in expr._output_names:
@@ -108,9 +123,7 @@
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name, function_name
)
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
try:
1 change: 1 addition & 0 deletions narwhals/_pandas_like/group_by.py
@@ -21,6 +21,7 @@

POLARS_TO_PANDAS_AGGREGATIONS = {
"len": "size",
"n_unique": "nunique",
}


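
As a point of reference, a minimal sketch (with made-up data) of the pandas call that the "nunique" string resolves to:

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 3], "b": [4, 4, 6]})

# The Polars-style n_unique aggregation becomes pandas' "nunique" string:
print(df.groupby("a").agg({"b": "nunique"}))
# Each group of "a" has exactly one distinct value of "b" here.
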
11 changes: 11 additions & 0 deletions tests/test_group_by.py
@@ -102,6 +102,17 @@ def test_group_by_len(constructor: Any) -> None:
compare_dicts(result, expected)


def test_group_by_n_unique(constructor: Any) -> None:
result = (
nw.from_native(constructor(data))
.group_by("a")
.agg(nw.col("b").n_unique())
.sort("a")
)
expected = {"a": [1, 3], "b": [1, 1]}
compare_dicts(result, expected)


def test_group_by_empty_result_pandas() -> None:
df_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})
df = nw.from_native(df_any, eager_only=True)
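
As a closing usage sketch, the public API exercised by the test above, run against a plain pandas frame; the module-level data fixture is not shown in this diff, so the frame below is a hypothetical stand-in consistent with expected:

import pandas as pd
import narwhals as nw

# Hypothetical data: two groups of "a", each with one distinct "b".
df = pd.DataFrame({"a": [1, 1, 3], "b": [4, 4, 6]})

result = (
    nw.from_native(df, eager_only=True)
    .group_by("a")
    .agg(nw.col("b").n_unique())
    .sort("a")
)
# Expected to match {"a": [1, 3], "b": [1, 1]}.
print(result.to_native())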