Skip to content

Commit 2700588

Browse files
committed
old polars and some more
1 parent 2c19b25 commit 2700588

File tree

5 files changed

+69
-15
lines changed

5 files changed

+69
-15
lines changed

narwhals/_expression_parsing.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,27 @@
3030
from narwhals._pandas_like.namespace import PandasLikeNamespace
3131
from narwhals._pandas_like.series import PandasLikeSeries
3232
from narwhals._pandas_like.typing import IntoPandasLikeExpr
33+
from narwhals._polars.expr import PolarsExpr
34+
from narwhals._polars.namespace import PolarsNamespace
35+
from narwhals._polars.series import PolarsSeries
36+
from narwhals._polars.typing import IntoPolarsExpr
3337

34-
CompliantNamespace = Union[PandasLikeNamespace, ArrowNamespace, DaskNamespace]
35-
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr]
36-
IntoCompliantExpr = Union[IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr]
38+
CompliantNamespace = Union[
39+
PandasLikeNamespace, ArrowNamespace, DaskNamespace, PolarsNamespace
40+
]
41+
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr]
42+
IntoCompliantExpr = Union[
43+
IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr
44+
]
3745
IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr)
3846
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr)
39-
CompliantSeries = Union[PandasLikeSeries, ArrowSeries]
47+
CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries]
4048
ListOfCompliantSeries = Union[
41-
list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr]
49+
list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries]
50+
]
51+
ListOfCompliantExpr = Union[
52+
list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr], list[PolarsExpr]
4253
]
43-
ListOfCompliantExpr = Union[list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr]]
4454
CompliantDataFrame = Union[PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame]
4555

4656
T = TypeVar("T")
@@ -133,14 +143,22 @@ def parse_into_exprs(
133143
) -> list[DaskExpr]: ...
134144

135145

146+
@overload
147+
def parse_into_exprs(
148+
*exprs: IntoPolarsExpr,
149+
namespace: PolarsNamespace,
150+
**named_exprs: IntoPolarsExpr,
151+
) -> list[PolarsExpr]: ...
152+
153+
136154
def parse_into_exprs(
137155
*exprs: IntoCompliantExpr,
138156
namespace: CompliantNamespace,
139157
**named_exprs: IntoCompliantExpr,
140158
) -> ListOfCompliantExpr:
141159
"""Parse each input as an expression (if it's not already one). See `parse_into_expr` for
142160
more details."""
143-
return [ # type: ignore[return-value]
161+
return [
144162
parse_into_expr(into_expr, namespace=namespace) for into_expr in flatten(exprs)
145163
] + [
146164
parse_into_expr(expr, namespace=namespace).alias(name)

narwhals/_polars/namespace.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
from functools import reduce
34
from typing import TYPE_CHECKING
45
from typing import Any
56
from typing import Iterable
67
from typing import Sequence
78

89
from narwhals import dtypes
10+
from narwhals._expression_parsing import parse_into_exprs
911
from narwhals._polars.utils import extract_args_kwargs
1012
from narwhals._polars.utils import narwhals_to_native_dtype
1113
from narwhals.dependencies import get_polars
@@ -15,6 +17,7 @@
1517
from narwhals._polars.dataframe import PolarsDataFrame
1618
from narwhals._polars.dataframe import PolarsLazyFrame
1719
from narwhals._polars.expr import PolarsExpr
20+
from narwhals._polars.typing import IntoPolarsExpr
1821

1922

2023
class PolarsNamespace:
@@ -85,11 +88,28 @@ def lit(self, value: Any, dtype: dtypes.DType | None = None) -> PolarsExpr:
8588
return PolarsExpr(pl.lit(value, dtype=narwhals_to_native_dtype(dtype)))
8689
return PolarsExpr(pl.lit(value))
8790

88-
def mean(self, *column_names: str) -> Any:
91+
def mean(self, *column_names: str) -> PolarsExpr:
92+
from narwhals._polars.expr import PolarsExpr
93+
8994
pl = get_polars()
9095
if self._backend_version < (0, 20, 4): # pragma: no cover
91-
return pl.mean([*column_names])
92-
return pl.mean(*column_names)
96+
return PolarsExpr(pl.mean([*column_names]))
97+
return PolarsExpr(pl.mean(*column_names))
98+
99+
def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
100+
from narwhals._polars.expr import PolarsExpr
101+
102+
pl = get_polars()
103+
polars_exprs = parse_into_exprs(*exprs, namespace=self)
104+
105+
if self._backend_version < (0, 20, 8): # pragma: no cover
106+
total = reduce(lambda x, y: x + y, (e.fill_null(0.0) for e in polars_exprs))
107+
n_non_zero = reduce(
108+
lambda x, y: x + y, ((1 - e.is_null()) for e in polars_exprs)
109+
)
110+
return PolarsExpr(total._native_expr / n_non_zero._native_expr)
111+
112+
return PolarsExpr(pl.mean_horizontal([e._native_expr for e in polars_exprs]))
93113

94114
@property
95115
def selectors(self) -> PolarsSelectors:

narwhals/_polars/typing.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from __future__ import annotations # pragma: no cover
2+
3+
from typing import TYPE_CHECKING # pragma: no cover
4+
from typing import Union # pragma: no cover
5+
6+
if TYPE_CHECKING:
7+
import sys
8+
9+
if sys.version_info >= (3, 10):
10+
from typing import TypeAlias
11+
else:
12+
from typing_extensions import TypeAlias
13+
14+
from narwhals._polars.expr import PolarsExpr
15+
from narwhals._polars.series import PolarsSeries
16+
17+
IntoPolarsExpr: TypeAlias = Union[PolarsExpr, str, PolarsSeries]

narwhals/stable/v1.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,7 @@ def from_dict(
16071607
"min",
16081608
"max",
16091609
"mean",
1610+
"mean_horizontal",
16101611
"sum",
16111612
"sum_horizontal",
16121613
"DataFrame",

tests/expr_and_series/mean_horizontal_test.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88

99
@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
10-
def test_sumh(constructor: Any, col_expr: Any) -> None:
11-
data = {"a": [1, 3, None], "b": [4, None, 6]}
10+
def test_meanh(constructor: Any, col_expr: Any) -> None:
11+
data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
1212
df = nw.from_native(constructor(data))
1313
result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, nw.col("b")))
14-
expected = {
15-
"horizontal_mean": [2.5, 3.0, 6.0],
16-
}
14+
expected = {"horizontal_mean": [2.5, 3.0, 6.0, float("nan")]}
1715
compare_dicts(result, expected)

0 commit comments

Comments
 (0)