Skip to content

Commit c095144

Browse files
committed
keep simplifying
1 parent 24861c7 commit c095144

File tree

9 files changed

+39
-65
lines changed

9 files changed

+39
-65
lines changed

narwhals/_arrow/group_by.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import pyarrow.compute as pc
1212

1313
from narwhals._expression_parsing import is_simple_aggregation
14-
from narwhals._expression_parsing import parse_into_exprs
1514
from narwhals.exceptions import AnonymousExprError
1615
from narwhals.utils import generate_temporary_column_name
1716
from narwhals.utils import remove_prefix
@@ -20,8 +19,8 @@
2019
from typing_extensions import Self
2120

2221
from narwhals._arrow.dataframe import ArrowDataFrame
22+
from narwhals._arrow.expr import ArrowExpr
2323
from narwhals._arrow.series import ArrowSeries
24-
from narwhals._arrow.typing import IntoArrowExpr
2524
from narwhals.typing import CompliantExpr
2625

2726
POLARS_TO_ARROW_AGGREGATIONS = {
@@ -51,14 +50,8 @@ def __init__(
5150

5251
def agg(
5352
self: Self,
54-
*aggs: IntoArrowExpr,
55-
**named_aggs: IntoArrowExpr,
53+
*exprs: ArrowExpr,
5654
) -> ArrowDataFrame:
57-
exprs = parse_into_exprs(
58-
*aggs,
59-
namespace=self._df.__narwhals_namespace__(),
60-
**named_aggs,
61-
)
6255
for expr in exprs:
6356
if expr._output_names is None:
6457
msg = "group_by.agg"

narwhals/_arrow/namespace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,7 @@ def concat_str(
365365
separator: str,
366366
ignore_nulls: bool,
367367
) -> ArrowExpr:
368-
parsed_exprs = [
369-
*parse_into_exprs(*exprs, namespace=self),
370-
]
368+
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
371369
dtypes = import_dtypes_module(self._version)
372370

373371
def func(df: ArrowDataFrame) -> list[ArrowSeries]:

narwhals/_dask/group_by.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,8 @@ def __init__(
8888

8989
def agg(
9090
self: Self,
91-
*aggs: DaskExpr,
92-
**named_aggs: DaskExpr,
91+
*exprs: DaskExpr,
9392
) -> DaskLazyFrame:
94-
exprs: list[DaskExpr] = [
95-
*aggs,
96-
*(val.alias(key) for key, val in named_aggs.items()),
97-
]
9893
output_names: list[str] = copy(self._keys)
9994
for expr in exprs:
10095
if expr._output_names is None:

narwhals/_duckdb/group_by.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,8 @@ def __init__(
2727

2828
def agg(
2929
self: Self,
30-
*aggs: IntoDuckDBExpr,
31-
**named_aggs: IntoDuckDBExpr,
30+
*exprs: IntoDuckDBExpr,
3231
) -> DuckDBLazyFrame:
33-
exprs = tuple(
34-
*(x for x in aggs), *(val.alias(key) for key, val in named_aggs.items())
35-
)
3632
output_names: list[str] = copy(self._keys)
3733
for expr in exprs:
3834
if expr._output_names is None: # pragma: no cover

narwhals/_pandas_like/group_by.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Sequence
1111

1212
from narwhals._expression_parsing import is_simple_aggregation
13-
from narwhals._expression_parsing import parse_into_exprs
1413
from narwhals._pandas_like.utils import horizontal_concat
1514
from narwhals._pandas_like.utils import native_series_from_iterable
1615
from narwhals._pandas_like.utils import select_columns_by_name
@@ -24,8 +23,8 @@
2423
from typing_extensions import Self
2524

2625
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
26+
from narwhals._pandas_like.expr import PandasLikeExpr
2727
from narwhals._pandas_like.series import PandasLikeSeries
28-
from narwhals._pandas_like.typing import IntoPandasLikeExpr
2928
from narwhals.typing import CompliantExpr
3029

3130
POLARS_TO_PANDAS_AGGREGATIONS = {
@@ -83,14 +82,8 @@ def __init__(
8382

8483
def agg(
8584
self: Self,
86-
*aggs: IntoPandasLikeExpr,
87-
**named_aggs: IntoPandasLikeExpr,
85+
*exprs: PandasLikeExpr,
8886
) -> PandasLikeDataFrame:
89-
exprs = parse_into_exprs(
90-
*aggs,
91-
namespace=self._df.__narwhals_namespace__(),
92-
**named_aggs,
93-
)
9487
implementation: Implementation = self._df._implementation
9588
output_names: list[str] = copy(self._keys)
9689
for expr in exprs:

narwhals/_pandas_like/namespace.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,7 @@ def concat_str(
381381
separator: str,
382382
ignore_nulls: bool,
383383
) -> PandasLikeExpr:
384-
parsed_exprs = [
385-
*parse_into_exprs(*exprs, namespace=self),
386-
]
384+
parsed_exprs = parse_into_exprs(*exprs, namespace=self)
387385
dtypes = import_dtypes_module(self._version)
388386

389387
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:

narwhals/_polars/group_by.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ def __init__(
2424
else:
2525
self._grouped = df._native_frame.group_by(keys)
2626

27-
def agg(self: Self, *aggs: PolarsExpr, **named_aggs: PolarsExpr) -> PolarsDataFrame:
28-
aggs, named_aggs = extract_args_kwargs(aggs, named_aggs) # type: ignore[assignment]
29-
return self._compliant_frame._from_native_frame(
30-
self._grouped.agg(*aggs, **named_aggs),
31-
)
27+
def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame:
28+
aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment]
29+
return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs))
3230

3331
def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
3432
for key, df in self._grouped:
@@ -46,8 +44,6 @@ def __init__(
4644
else:
4745
self._grouped = df._native_frame.group_by(keys)
4846

49-
def agg(self: Self, *aggs: PolarsExpr, **named_aggs: PolarsExpr) -> PolarsLazyFrame:
50-
aggs, named_aggs = extract_args_kwargs(aggs, named_aggs) # type: ignore[assignment]
51-
return self._compliant_frame._from_native_frame(
52-
self._grouped.agg(*aggs, **named_aggs),
53-
)
47+
def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame:
48+
aggs, _ = extract_args_kwargs(aggs, {}) # type: ignore[assignment]
49+
return self._compliant_frame._from_native_frame(self._grouped.agg(*aggs))

narwhals/_spark_like/group_by.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,8 @@ def __init__(
4444

4545
def agg(
4646
self: Self,
47-
*aggs: SparkLikeExpr,
48-
**named_aggs: SparkLikeExpr,
47+
*exprs: SparkLikeExpr,
4948
) -> SparkLikeLazyFrame:
50-
exprs = tuple(
51-
*(x for x in aggs), *(val.alias(key) for key, val in named_aggs.items())
52-
)
5349
output_names: list[str] = copy(self._keys)
5450
for expr in exprs:
5551
if expr._output_names is None: # pragma: no cover

narwhals/group_by.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
from narwhals.dataframe import DataFrame
1212
from narwhals.dataframe import LazyFrame
1313
from narwhals.exceptions import InvalidOperationError
14+
from narwhals.utils import flatten
1415
from narwhals.utils import tupleify
1516

1617
if TYPE_CHECKING:
1718
from typing_extensions import Self
1819

19-
from narwhals.typing import IntoExpr
20+
from narwhals.expr import Expr
2021

2122
DataFrameT = TypeVar("DataFrameT")
2223
LazyFrameT = TypeVar("LazyFrameT")
@@ -30,9 +31,7 @@ def __init__(self: Self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> No
3031
*self._keys, drop_null_keys=drop_null_keys
3132
)
3233

33-
def agg(
34-
self: Self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr
35-
) -> DataFrameT:
34+
def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
3635
"""Compute aggregations for each group of a group by operation.
3736
3837
Arguments:
@@ -112,7 +111,8 @@ def agg(
112111
│ c ┆ 3 ┆ 1 │
113112
└─────┴─────┴─────┘
114113
"""
115-
if not all(getattr(x, "_aggregates", True) for x in aggs) and all(
114+
flat_aggs = tuple(flatten(aggs))
115+
if not all(getattr(x, "_aggregates", True) for x in flat_aggs) and all(
116116
getattr(x, "_aggregates", True) for x in named_aggs.values()
117117
):
118118
msg = (
@@ -122,11 +122,16 @@ def agg(
122122
"but `df.group_by('a').agg(nw.col('b'))` is not."
123123
)
124124
raise InvalidOperationError(msg)
125-
compliant_aggs, compliant_named_aggs = self._df._flatten_and_extract(
126-
*aggs, **named_aggs
125+
plx = self._df.__narwhals_namespace__()
126+
compliant_aggs = (
127+
*(x._to_compliant_expr(plx) for x in flat_aggs),
128+
*(
129+
value._to_compliant_expr(plx).alias(key)
130+
for key, value in named_aggs.items()
131+
),
127132
)
128133
return self._df._from_compliant_dataframe( # type: ignore[return-value]
129-
self._grouped.agg(*compliant_aggs, **compliant_named_aggs),
134+
self._grouped.agg(*compliant_aggs),
130135
)
131136

132137
def __iter__(self: Self) -> Iterator[tuple[Any, DataFrameT]]:
@@ -144,9 +149,7 @@ def __init__(self: Self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> No
144149
*self._keys, drop_null_keys=drop_null_keys
145150
)
146151

147-
def agg(
148-
self: Self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr
149-
) -> LazyFrameT:
152+
def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
150153
"""Compute aggregations for each group of a group by operation.
151154
152155
If a library does not support lazy execution, then this is a no-op.
@@ -210,7 +213,8 @@ def agg(
210213
│ c ┆ 3 ┆ 1 │
211214
└─────┴─────┴─────┘
212215
"""
213-
if not all(getattr(x, "_aggregates", True) for x in aggs) and all(
216+
flat_aggs = tuple(flatten(aggs))
217+
if not all(getattr(x, "_aggregates", True) for x in flat_aggs) and all(
214218
getattr(x, "_aggregates", True) for x in named_aggs.values()
215219
):
216220
msg = (
@@ -220,9 +224,14 @@ def agg(
220224
"but `df.group_by('a').agg(nw.col('b'))` is not."
221225
)
222226
raise InvalidOperationError(msg)
223-
compliant_aggs, compliant_named_aggs = self._df._flatten_and_extract(
224-
*aggs, **named_aggs
227+
plx = self._df.__narwhals_namespace__()
228+
compliant_aggs = (
229+
*(x._to_compliant_expr(plx) for x in flat_aggs),
230+
*(
231+
value._to_compliant_expr(plx).alias(key)
232+
for key, value in named_aggs.items()
233+
),
225234
)
226235
return self._df._from_compliant_dataframe( # type: ignore[return-value]
227-
self._grouped.agg(*compliant_aggs, **compliant_named_aggs),
236+
self._grouped.agg(*compliant_aggs),
228237
)

0 commit comments

Comments
 (0)