Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patch: group by n_unique #917

Merged
merged 13 commits into from
Sep 6, 2024
27 changes: 22 additions & 5 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import IntoArrowExpr

# Map Polars aggregation names to their PyArrow compute-function equivalents.
POLARS_TO_ARROW_AGGREGATIONS = dict(
    n_unique="count_distinct",
    std="stddev",
    # `var` is currently unused: narwhals does not expose `var` yet.
    var="variance",
)


class ArrowGroupBy:
def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
Expand Down Expand Up @@ -112,16 +118,27 @@ def agg_arrow(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
if function_name != "len":
if function_name == "len":
simple_aggregations[output_name] = (
(root_name, function_name),
f"{root_name}_{function_name}",
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
)
elif function_name == "count_distinct":
simple_aggregations[output_name] = (
(root_name, "count_distinct", pc.CountOptions(mode="all")),
f"{root_name}_count_distinct",
)
elif function_name == "stddev":
simple_aggregations[output_name] = (
(root_name, "stddev", pc.VarianceOptions(ddof=1)),
f"{root_name}_stddev",
)
else:
simple_aggregations[output_name] = (
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
(root_name, function_name),
f"{root_name}_{function_name}",
)

aggs: list[Any] = []
Expand Down
35 changes: 29 additions & 6 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,33 @@
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
import dask.dataframe as dd
import pandas as pd

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.typing import IntoDaskExpr

POLARS_TO_PANDAS_AGGREGATIONS = {

def n_unique() -> dd.Aggregation:
    """Build a dask ``Aggregation`` counting distinct values, missing ones included.

    dask is imported inside the function so that merely importing this module
    does not require dask to be installed.
    """
    import dask.dataframe as dd  # ignore-banned-import

    def _distinct_per_chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
        # Per-partition distinct count; dropna=False so NaN/None count as a value.
        return s.nunique(dropna=False)  # type: ignore[no-any-return]

    def _combine_chunks(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
        # NOTE(review): summing per-partition counts can overcount values that
        # appear in more than one partition — confirm against dask semantics.
        return s0.sum()  # type: ignore[no-any-return]

    return dd.Aggregation(
        name="nunique", chunk=_distinct_per_chunk, agg=_combine_chunks
    )


# Map Polars aggregation names onto their dask counterparts. Callable values
# are invoked lazily at use sites so dask is only imported when actually needed.
POLARS_TO_DASK_AGGREGATIONS = {
    "len": "size",
    "n_unique": n_unique,
}


Expand Down Expand Up @@ -85,15 +106,15 @@ def agg_dask(
break

if all_simple_aggs:
simple_aggregations: dict[str, tuple[str, str]] = {}
simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
for expr in exprs:
if expr._depth == 0:
# e.g. agg(nw.len()) # noqa: ERA001
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)

function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
for output_name in expr._output_names:
Expand All @@ -108,9 +129,11 @@ def agg_dask(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_PANDAS_AGGREGATIONS.get(
function_name, function_name
)
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)

# deal with n_unique case in a "lazy" mode to not depend on dask globally
function_name = function_name() if callable(function_name) else function_name

for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
try:
Expand Down
69 changes: 54 additions & 15 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# Map Polars aggregation names onto the pandas GroupBy method names.
POLARS_TO_PANDAS_AGGREGATIONS = dict(
    len="size",
    n_unique="nunique",
)


Expand Down Expand Up @@ -103,7 +104,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
yield from ((key, self._from_native_frame(sub_df)) for (key, sub_df) in iterator)


def agg_pandas(
def agg_pandas( # noqa: PLR0915
grouped: Any,
exprs: list[PandasLikeExpr],
keys: list[str],
Expand All @@ -120,13 +121,18 @@ def agg_pandas(
- https://github.com/rapidsai/cudf/issues/15118
- https://github.com/rapidsai/cudf/issues/15084
"""
all_simple_aggs = True
all_aggs_are_simple = True
for expr in exprs:
if not is_simple_aggregation(expr):
all_simple_aggs = False
all_aggs_are_simple = False
break

if all_simple_aggs:
# dict of {output_name: root_name} that we count n_unique on
# We need to do this separately from the rest so that we
# can pass the `dropna` kwargs.
nunique_aggs: dict[str, str] = {}

if all_aggs_are_simple:
simple_aggregations: dict[str, tuple[str, str]] = {}
for expr in exprs:
if expr._depth == 0:
Expand Down Expand Up @@ -154,21 +160,54 @@ def agg_pandas(
function_name, function_name
)
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
if function_name == "nunique":
nunique_aggs[output_name] = root_name
else:
simple_aggregations[output_name] = (root_name, function_name)

aggs = collections.defaultdict(list)
simple_aggs = collections.defaultdict(list)
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
aggs[named_agg[0]].append(named_agg[1])
simple_aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
try:
result_simple = grouped.agg(aggs)
except AttributeError as exc:
msg = "Failed to aggregated - does your aggregation function return a scalar?"
raise RuntimeError(msg) from exc
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
return from_dataframe(result_simple.loc[:, output_names])
if simple_aggs:
try:
result_simple_aggs = grouped.agg(simple_aggs)
except AttributeError as exc:
msg = "Failed to aggregated - does your aggregation function return a scalar?"
raise RuntimeError(msg) from exc
result_simple_aggs.columns = [
f"{a}_{b}" for a, b in result_simple_aggs.columns
]
result_simple_aggs = result_simple_aggs.rename(
columns=name_mapping
).reset_index()
if nunique_aggs:
result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique(
dropna=False
)
result_nunique_aggs.columns = list(nunique_aggs.keys())
result_nunique_aggs = result_nunique_aggs.reset_index()
if simple_aggs and nunique_aggs:
if (
set(result_simple_aggs.columns)
.difference(keys)
.intersection(result_nunique_aggs.columns)
):
msg = (
"Got two aggregations with the same output name. Please make sure "
"that aggregations have unique output names."
)
raise ValueError(msg)
result_aggs = result_simple_aggs.merge(result_nunique_aggs, on=keys)
elif nunique_aggs and not simple_aggs:
result_aggs = result_nunique_aggs
elif simple_aggs and not nunique_aggs:
result_aggs = result_simple_aggs
else: # pragma: no cover
msg = "Congrats, you entered unreachable code. Please report a bug to https://github.com/narwhals-dev/narwhals/issues."
raise RuntimeError(msg)
return from_dataframe(result_aggs.loc[:, output_names])

if dataframe_is_empty:
# Don't even attempt this, it's way too inconsistent across pandas versions.
Expand Down
51 changes: 51 additions & 0 deletions tests/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,57 @@ def test_group_by_len(constructor: Any) -> None:
compare_dicts(result, expected)


def test_group_by_n_unique(constructor: Any) -> None:
    """``n_unique`` per group: each group in the shared fixture has one distinct "b"."""
    df = nw.from_native(constructor(data))
    result = df.group_by("a").agg(nw.col("b").n_unique()).sort("a")
    compare_dicts(result, {"a": [1, 3], "b": [1, 1]})


def test_group_by_std(constructor: Any) -> None:
    """Sample std (ddof=1) per group; two consecutive ints give 1/sqrt(2) ~= 0.707107."""
    data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]}
    df = nw.from_native(constructor(data))
    result = df.group_by("a").agg(nw.col("b").std()).sort("a")
    expected = {"a": [1, 2], "b": [0.707107, 0.707107]}
    compare_dicts(result, expected)


def test_group_by_n_unique_w_missing(constructor: Any) -> None:
    """``n_unique`` counts missing values as one distinct value, across backends."""
    data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]}
    frame = nw.from_native(constructor(data))
    result = (
        frame.group_by("a")
        .agg(
            nw.col("b").n_unique(),
            c_n_unique=nw.col("c").n_unique(),
            c_n_min=nw.col("b").min(),
            d_n_unique=nw.col("d").n_unique(),
        )
        .sort("a")
    )
    # Group a=1 has b values {4, None} -> 2 distinct; None alone counts as 1.
    expected = {
        "a": [1, 2],
        "b": [2, 1],
        "c_n_unique": [1, 1],
        "c_n_min": [4, 5],
        "d_n_unique": [1, 1],
    }
    compare_dicts(result, expected)


def test_group_by_same_name_twice() -> None:
    """Two aggregations writing to the same output column must raise a ValueError."""
    import pandas as pd

    frame = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}))
    with pytest.raises(ValueError, match="two aggregations with the same"):
        frame.group_by("a").agg(nw.col("b").sum(), nw.col("b").n_unique())


def test_group_by_empty_result_pandas() -> None:
df_any = pd.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})
df = nw.from_native(df_any, eager_only=True)
Expand Down
Loading