Skip to content

feat: mean_horizontal #843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Here are the top-level functions available in Narwhals.
- maybe_convert_dtypes
- maybe_set_index
- mean
- mean_horizontal
- min
- narwhalify
- new_series
Expand Down
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from narwhals.expr import lit
from narwhals.expr import max
from narwhals.expr import mean
from narwhals.expr import mean_horizontal
from narwhals.expr import min
from narwhals.expr import sum
from narwhals.expr import sum_horizontal
Expand Down Expand Up @@ -73,6 +74,7 @@
"min",
"max",
"mean",
"mean_horizontal",
"sum",
"sum_horizontal",
"DataFrame",
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def any_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
def sum_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
return reduce(lambda x, y: x + y, parse_into_exprs(*exprs, namespace=self))

def mean_horizontal(self, *exprs: IntoArrowExpr) -> IntoArrowExpr:
arrow_exprs = parse_into_exprs(*exprs, namespace=self)
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in arrow_exprs))
n_non_zero = reduce(
lambda x, y: x + y,
((1 - e.is_null().cast(self.Int64())) for e in arrow_exprs),
)
return total / n_non_zero

def concat(
self,
items: Iterable[ArrowDataFrame],
Expand Down
6 changes: 6 additions & 0 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
return reduce(lambda x, y: x + y, parse_into_exprs(*exprs, namespace=self))

def mean_horizontal(self, *exprs: IntoDaskExpr) -> IntoDaskExpr:
dask_exprs = parse_into_exprs(*exprs, namespace=self)
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in dask_exprs))
n_non_zero = reduce(lambda x, y: x + y, ((1 - e.is_null()) for e in dask_exprs))
return total / n_non_zero

def _create_expr_from_series(self, _: Any) -> NoReturn:
msg = "`_create_expr_from_series` for DaskNamespace exists only for compatibility"
raise NotImplementedError(msg)
Expand Down
32 changes: 25 additions & 7 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,27 @@
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.typing import IntoPandasLikeExpr
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.typing import IntoPolarsExpr

CompliantNamespace = Union[PandasLikeNamespace, ArrowNamespace, DaskNamespace]
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr]
IntoCompliantExpr = Union[IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr]
CompliantNamespace = Union[
PandasLikeNamespace, ArrowNamespace, DaskNamespace, PolarsNamespace
]
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr]
IntoCompliantExpr = Union[
IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr
]
IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr)
CompliantSeries = Union[PandasLikeSeries, ArrowSeries]
CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries]
ListOfCompliantSeries = Union[
list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr]
list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries]
]
ListOfCompliantExpr = Union[
list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr], list[PolarsExpr]
]
ListOfCompliantExpr = Union[list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr]]
CompliantDataFrame = Union[PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame]

T = TypeVar("T")
Expand Down Expand Up @@ -133,14 +143,22 @@ def parse_into_exprs(
) -> list[DaskExpr]: ...


@overload
def parse_into_exprs(
*exprs: IntoPolarsExpr,
namespace: PolarsNamespace,
**named_exprs: IntoPolarsExpr,
) -> list[PolarsExpr]: ...


def parse_into_exprs(
*exprs: IntoCompliantExpr,
namespace: CompliantNamespace,
**named_exprs: IntoCompliantExpr,
) -> ListOfCompliantExpr:
"""Parse each input as an expression (if it's not already one). See `parse_into_expr` for
more details."""
return [ # type: ignore[return-value]
return [
parse_into_expr(into_expr, namespace=namespace) for into_expr in flatten(exprs)
] + [
parse_into_expr(expr, namespace=namespace).alias(name)
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ def all_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr:
def any_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr:
return reduce(lambda x, y: x | y, parse_into_exprs(*exprs, namespace=self))

def mean_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr:
pandas_like_exprs = parse_into_exprs(*exprs, namespace=self)
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in pandas_like_exprs))
n_non_zero = reduce(
lambda x, y: x + y, ((1 - e.is_null()) for e in pandas_like_exprs)
)
return total / n_non_zero

def concat(
self,
items: Iterable[PandasLikeDataFrame],
Expand Down
26 changes: 23 additions & 3 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Sequence

from narwhals import dtypes
from narwhals._expression_parsing import parse_into_exprs
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.dependencies import get_polars
Expand All @@ -15,6 +17,7 @@
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals._polars.dataframe import PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.typing import IntoPolarsExpr


class PolarsNamespace:
Expand Down Expand Up @@ -85,11 +88,28 @@ def lit(self, value: Any, dtype: dtypes.DType | None = None) -> PolarsExpr:
return PolarsExpr(pl.lit(value, dtype=narwhals_to_native_dtype(dtype)))
return PolarsExpr(pl.lit(value))

def mean(self, *column_names: str) -> Any:
def mean(self, *column_names: str) -> PolarsExpr:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return PolarsExpr to be able to keep chaining and interacting with other narwhals exprs if needed

from narwhals._polars.expr import PolarsExpr

pl = get_polars()
if self._backend_version < (0, 20, 4): # pragma: no cover
return pl.mean([*column_names])
return pl.mean(*column_names)
return PolarsExpr(pl.mean([*column_names]))
return PolarsExpr(pl.mean(*column_names))

def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
from narwhals._polars.expr import PolarsExpr

pl = get_polars()
polars_exprs = parse_into_exprs(*exprs, namespace=self)

if self._backend_version < (0, 20, 8): # pragma: no cover
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in polars_exprs))
n_non_zero = reduce(
lambda x, y: x + y, ((1 - e.is_null()) for e in polars_exprs)
)
return PolarsExpr(total._native_expr / n_non_zero._native_expr)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__div__ is not currently implemented and I did not want to make the PR even larger


return PolarsExpr(pl.mean_horizontal([e._native_expr for e in polars_exprs]))

@property
def selectors(self) -> PolarsSelectors:
Expand Down
17 changes: 17 additions & 0 deletions narwhals/_polars/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations # pragma: no cover

from typing import TYPE_CHECKING # pragma: no cover
from typing import Union # pragma: no cover

if TYPE_CHECKING:
import sys

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries

IntoPolarsExpr: TypeAlias = Union[PolarsExpr, str, PolarsSeries]
53 changes: 53 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4118,6 +4118,59 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
)


def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
"""
Compute the mean of all values horizontally across columns.

Arguments:
exprs: Name(s) of the columns to use in the aggregation function. Accepts
expression input.

Examples:
>>> import pandas as pd
>>> import polars as pl
>>> import narwhals as nw
>>> data = {
... "a": [1, 8, 3],
... "b": [4, 5, None],
... "c": ["x", "y", "z"],
... }
>>> df_pl = pl.DataFrame(data)
>>> df_pd = pd.DataFrame(data)

We define a dataframe-agnostic function that computes the horizontal mean of "a"
and "b" columns:

>>> @nw.narwhalify
... def func(df):
... return df.select(nw.mean_horizontal("a", "b"))

We can then pass either pandas or polars to `func`:

>>> func(df_pd)
a
0 2.5
1 6.5
2 3.0
>>> func(df_pl)
shape: (3, 1)
β”Œβ”€β”€β”€β”€β”€β”
β”‚ a β”‚
β”‚ --- β”‚
β”‚ f64 β”‚
β•žβ•β•β•β•β•β•‘
β”‚ 2.5 β”‚
β”‚ 6.5 β”‚
β”‚ 3.0 β”‚
β””β”€β”€β”€β”€β”€β”˜
"""
return Expr(
lambda plx: plx.mean_horizontal(
*[extract_compliant(plx, v) for v in flatten(exprs)]
)
)


__all__ = [
"Expr",
]
50 changes: 50 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,55 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
return _stableify(nw.any_horizontal(*exprs))


def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
"""
Compute the mean of all values horizontally across columns.

Arguments:
exprs: Name(s) of the columns to use in the aggregation function. Accepts
expression input.

Examples:
>>> import pandas as pd
>>> import polars as pl
>>> import narwhals.stable.v1 as nw
>>> data = {
... "a": [1, 8, 3],
... "b": [4, 5, None],
... "c": ["x", "y", "z"],
... }
>>> df_pl = pl.DataFrame(data)
>>> df_pd = pd.DataFrame(data)

We define a dataframe-agnostic function that computes the horizontal mean of "a"
and "b" columns:

>>> @nw.narwhalify
... def func(df):
... return df.select(nw.mean_horizontal("a", "b"))

We can then pass either pandas or polars to `func`:

>>> func(df_pd)
a
0 2.5
1 6.5
2 3.0
>>> func(df_pl)
shape: (3, 1)
β”Œβ”€β”€β”€β”€β”€β”
β”‚ a β”‚
β”‚ --- β”‚
β”‚ f64 β”‚
β•žβ•β•β•β•β•β•‘
β”‚ 2.5 β”‚
β”‚ 6.5 β”‚
β”‚ 3.0 β”‚
β””β”€β”€β”€β”€β”€β”˜
"""
return _stableify(nw.mean_horizontal(*exprs))


def is_ordered_categorical(series: Series) -> bool:
"""
Return whether indices of categories are semantically meaningful.
Expand Down Expand Up @@ -1558,6 +1607,7 @@ def from_dict(
"min",
"max",
"mean",
"mean_horizontal",
"sum",
"sum_horizontal",
"DataFrame",
Expand Down
15 changes: 15 additions & 0 deletions tests/expr_and_series/mean_horizontal_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from tests.utils import compare_dicts


@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
def test_meanh(constructor: Any, col_expr: Any) -> None:
data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
df = nw.from_native(constructor(data))
result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, nw.col("b")))
expected = {"horizontal_mean": [2.5, 3.0, 6.0, float("nan")]}
compare_dicts(result, expected)
Loading