diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py
index 27c7ff3688..78b241c9b0 100644
--- a/narwhals/_arrow/group_by.py
+++ b/narwhals/_arrow/group_by.py
@@ -15,6 +15,12 @@
     from narwhals._arrow.expr import ArrowExpr
     from narwhals._arrow.typing import IntoArrowExpr
 
+POLARS_TO_ARROW_AGGREGATIONS = {
+    "n_unique": "count_distinct",
+    "std": "stddev",
+    "var": "variance",  # currently unused, we don't have `var` yet
+}
+
 
 class ArrowGroupBy:
     def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
@@ -112,16 +118,27 @@ def agg_arrow(
                 raise AssertionError(msg)
 
             function_name = remove_prefix(expr._function_name, "col->")
+            function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
             for root_name, output_name in zip(expr._root_names, expr._output_names):
-                if function_name != "len":
+                if function_name == "len":
                     simple_aggregations[output_name] = (
-                        (root_name, function_name),
-                        f"{root_name}_{function_name}",
+                        (root_name, "count", pc.CountOptions(mode="all")),
+                        f"{root_name}_count",
+                    )
+                elif function_name == "count_distinct":
+                    simple_aggregations[output_name] = (
+                        (root_name, "count_distinct", pc.CountOptions(mode="all")),
+                        f"{root_name}_count_distinct",
+                    )
+                elif function_name == "stddev":
+                    simple_aggregations[output_name] = (
+                        (root_name, "stddev", pc.VarianceOptions(ddof=1)),
+                        f"{root_name}_stddev",
                     )
                 else:
                     simple_aggregations[output_name] = (
-                        (root_name, "count", pc.CountOptions(mode="all")),
-                        f"{root_name}_count",
+                        (root_name, function_name),
+                        f"{root_name}_{function_name}",
                     )
 
         aggs: list[Any] = []
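For reference, the `(column, aggregation, options)` tuples built above are the shape that `pyarrow.Table.group_by(...).aggregate(...)` consumes, and PyArrow names the output columns `{column}_{aggregation}`, which is why the patch precomputes those names. A minimal sketch of that API (the table and values are illustrative, not from the patch):

```python
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]})
result = tbl.group_by("a").aggregate(
    [
        # mode="all" counts null as a distinct value, like Polars' n_unique
        ("b", "count_distinct", pc.CountOptions(mode="all")),
        # ddof=1 is the sample standard deviation, matching Polars' std default
        ("b", "stddev", pc.VarianceOptions(ddof=1)),
    ]
)
# Output columns come back as "b_count_distinct" and "b_stddev".
```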
diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py
index 8538c62d22..463d6fc582 100644
--- a/narwhals/_dask/group_by.py
+++ b/narwhals/_dask/group_by.py
@@ -10,12 +10,33 @@
 from narwhals.utils import remove_prefix
 
 if TYPE_CHECKING:
+    import dask.dataframe as dd
+    import pandas as pd
+
     from narwhals._dask.dataframe import DaskLazyFrame
     from narwhals._dask.expr import DaskExpr
     from narwhals._dask.typing import IntoDaskExpr
 
-POLARS_TO_PANDAS_AGGREGATIONS = {
+
+def n_unique() -> dd.Aggregation:
+    import dask.dataframe as dd  # ignore-banned-import
+
+    def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
+        return s.nunique(dropna=False)  # type: ignore[no-any-return]
+
+    def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
+        return s0.sum()  # type: ignore[no-any-return]
+
+    return dd.Aggregation(
+        name="nunique",
+        chunk=chunk,
+        agg=agg,
+    )
+
+
+POLARS_TO_DASK_AGGREGATIONS = {
     "len": "size",
+    "n_unique": n_unique,
 }
 
 
@@ -85,7 +106,7 @@ def agg_dask(
             break
 
     if all_simple_aggs:
-        simple_aggregations: dict[str, tuple[str, str]] = {}
+        simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
         for expr in exprs:
             if expr._depth == 0:
                 # e.g. agg(nw.len())  # noqa: ERA001
@@ -93,7 +114,7 @@ def agg_dask(
                     msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                     raise AssertionError(msg)
-                function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
+                function_name = POLARS_TO_DASK_AGGREGATIONS.get(
                     expr._function_name, expr._function_name
                 )
                 for output_name in expr._output_names:
                     simple_aggregations[output_name] = (keys[0], function_name)
@@ -108,9 +129,11 @@ def agg_dask(
                 raise AssertionError(msg)
 
             function_name = remove_prefix(expr._function_name, "col->")
-            function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
-                function_name, function_name
-            )
+            function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
+
+            # deal with n_unique case in a "lazy" mode to not depend on dask globally
+            function_name = function_name() if callable(function_name) else function_name
+
             for root_name, output_name in zip(expr._root_names, expr._output_names):
                 simple_aggregations[output_name] = (root_name, function_name)
    try:
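The custom aggregation above works in two stages: `chunk` computes a distinct count per group within each partition, and `agg` sums those partial counts across partitions (with a single partition, the sum of partials is exactly the overall distinct count). A rough, self-contained sketch of how such an aggregation plugs into a Dask groupby, with illustrative data:

```python
import dask.dataframe as dd
import pandas as pd

nunique = dd.Aggregation(
    name="nunique",
    # per partition: distinct count per group, counting nulls like Polars does
    chunk=lambda s: s.nunique(dropna=False),
    # across partitions: sum the partial counts
    agg=lambda s0: s0.sum(),
)

pdf = pd.DataFrame({"a": [1, 1, 2], "b": [4, None, 5]})
ddf = dd.from_pandas(pdf, npartitions=1)
print(ddf.groupby("a").agg({"b": nunique}).compute())
```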
diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py
index 11abc85c8b..97a477dc47 100644
--- a/narwhals/_pandas_like/group_by.py
+++ b/narwhals/_pandas_like/group_by.py
@@ -21,6 +21,7 @@
 
 POLARS_TO_PANDAS_AGGREGATIONS = {
     "len": "size",
+    "n_unique": "nunique",
 }
 
 
@@ -103,7 +104,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
         yield from ((key, self._from_native_frame(sub_df)) for (key, sub_df) in iterator)
 
 
-def agg_pandas(
+def agg_pandas(  # noqa: PLR0915
     grouped: Any,
     exprs: list[PandasLikeExpr],
     keys: list[str],
@@ -120,13 +121,18 @@ def agg_pandas(
     - https://github.com/rapidsai/cudf/issues/15118
     - https://github.com/rapidsai/cudf/issues/15084
     """
-    all_simple_aggs = True
+    all_aggs_are_simple = True
     for expr in exprs:
         if not is_simple_aggregation(expr):
-            all_simple_aggs = False
+            all_aggs_are_simple = False
             break
 
-    if all_simple_aggs:
+    # dict of {output_name: root_name} that we count n_unique on
+    # We need to do this separately from the rest so that we
+    # can pass the `dropna` kwargs.
+    nunique_aggs: dict[str, str] = {}
+
+    if all_aggs_are_simple:
         simple_aggregations: dict[str, tuple[str, str]] = {}
         for expr in exprs:
             if expr._depth == 0:
@@ -154,21 +160,54 @@ def agg_pandas(
                 function_name, function_name
             )
             for root_name, output_name in zip(expr._root_names, expr._output_names):
-                simple_aggregations[output_name] = (root_name, function_name)
+                if function_name == "nunique":
+                    nunique_aggs[output_name] = root_name
+                else:
+                    simple_aggregations[output_name] = (root_name, function_name)
 
-        aggs = collections.defaultdict(list)
+        simple_aggs = collections.defaultdict(list)
         name_mapping = {}
         for output_name, named_agg in simple_aggregations.items():
-            aggs[named_agg[0]].append(named_agg[1])
+            simple_aggs[named_agg[0]].append(named_agg[1])
             name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
-        try:
-            result_simple = grouped.agg(aggs)
-        except AttributeError as exc:
-            msg = "Failed to aggregated - does your aggregation function return a scalar?"
-            raise RuntimeError(msg) from exc
-        result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
-        result_simple = result_simple.rename(columns=name_mapping).reset_index()
-        return from_dataframe(result_simple.loc[:, output_names])
+        if simple_aggs:
+            try:
+                result_simple_aggs = grouped.agg(simple_aggs)
+            except AttributeError as exc:
+                msg = "Failed to aggregate - does your aggregation function return a scalar?"
+                raise RuntimeError(msg) from exc
+            result_simple_aggs.columns = [
+                f"{a}_{b}" for a, b in result_simple_aggs.columns
+            ]
+            result_simple_aggs = result_simple_aggs.rename(
+                columns=name_mapping
+            ).reset_index()
+        if nunique_aggs:
+            result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
+                dropna=False
+            )
+            result_nunique_aggs.columns = list(nunique_aggs.keys())
+            result_nunique_aggs = result_nunique_aggs.reset_index()
+        if simple_aggs and nunique_aggs:
+            if (
+                set(result_simple_aggs.columns)
+                .difference(keys)
+                .intersection(result_nunique_aggs.columns)
+            ):
+                msg = (
+                    "Got two aggregations with the same output name. Please make sure "
+                    "that aggregations have unique output names."
+                )
+                raise ValueError(msg)
+            result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys)
+        elif nunique_aggs and not simple_aggs:
+            result_aggs = result_nunique_aggs
+        elif simple_aggs and not nunique_aggs:
+            result_aggs = result_simple_aggs
+        else:  # pragma: no cover
+            msg = "Congrats, you entered unreachable code. Please report a bug to https://github.com/narwhals-dev/narwhals/issues."
+            raise RuntimeError(msg)
+        return from_dataframe(result_aggs.loc[:, output_names])
 
     if dataframe_is_empty:
         # Don't even attempt this, it's way too inconsistent across pandas versions.
diff --git a/tests/test_group_by.py b/tests/test_group_by.py
index 2bb8d435b4..4bd3427a53 100644
--- a/tests/test_group_by.py
+++ b/tests/test_group_by.py
@@ -102,6 +102,57 @@ def test_group_by_len(constructor: Any) -> None:
     compare_dicts(result, expected)
 
 
+def test_group_by_n_unique(constructor: Any) -> None:
+    result = (
+        nw.from_native(constructor(data))
+        .group_by("a")
+        .agg(nw.col("b").n_unique())
+        .sort("a")
+    )
+    expected = {"a": [1, 3], "b": [1, 1]}
+    compare_dicts(result, expected)
+
+
+def test_group_by_std(constructor: Any) -> None:
+    data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
+    result = (
+        nw.from_native(constructor(data)).group_by("a").agg(nw.col("b").std()).sort("a")
+    )
+    expected = {"a": [1, 2], "b": [0.707107] * 2}
+    compare_dicts(result, expected)
+
+
+def test_group_by_n_unique_w_missing(constructor: Any) -> None:
+    data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
+    result = (
+        nw.from_native(constructor(data))
+        .group_by("a")
+        .agg(
+            nw.col("b").n_unique(),
+            c_n_unique=nw.col("c").n_unique(),
+            c_n_min=nw.col("b").min(),
+            d_n_unique=nw.col("d").n_unique(),
+        )
+        .sort("a")
+    )
+    expected = {
+        "a": [1, 2],
+        "b": [2, 1],
+        "c_n_unique": [1, 1],
+        "c_n_min": [4, 5],
+        "d_n_unique": [1, 1],
+    }
+    compare_dicts(result, expected)
+
+
+def test_group_by_same_name_twice() -> None:
+    import pandas as pd
+
+    df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
+    with pytest.raises(ValueError, match="two aggregations with the same"):
+        nw.from_native(df).group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique())
+
+
 def test_group_by_empty_result_pandas() -> None:
     df_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})
     df = nw.from_native(df_any, eager_only=True)
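Taken together, these changes make `n_unique` usable inside `group_by(...).agg(...)` across the pandas-like, Dask, and PyArrow backends (and add `std` support for PyArrow). An end-to-end sketch of the new behaviour, mirroring `test_group_by_n_unique` above (pandas backend, illustrative data):

```python
import pandas as pd
import narwhals as nw

df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 4, 6]})
result = (
    nw.from_native(df, eager_only=True)
    .group_by("a")
    .agg(nw.col("b").n_unique(), b_std=nw.col("b").std())
    .sort("a")
)
print(nw.to_native(result))
# One row per group: "b" holds the distinct counts ([1, 1] here), "b_std" the
# sample standard deviations (0.0 for [4, 4]; NaN for the single-element group).
```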