From 53de4f384a2ecaa3981b769e90edb09c2c89aee9 Mon Sep 17 00:00:00 2001 From: devdanzin <74280297+devdanzin@users.noreply.github.com> Date: Sat, 7 Dec 2024 09:21:09 -0300 Subject: [PATCH 1/2] enh: allow passing Series to Series.__getitem__. --- narwhals/_polars/series.py | 9 +++++++-- narwhals/series.py | 2 +- tests/series_only/__getitem___test.py | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 tests/series_only/__getitem___test.py diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 0bfea20ad..3497096c9 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -21,6 +21,7 @@ from narwhals._polars.dataframe import PolarsDataFrame from narwhals.dtypes import DType + from narwhals.series import Series as NWSeries from narwhals.utils import Version T = TypeVar("T") @@ -122,8 +123,12 @@ def __getitem__(self: Self, item: int) -> Any: ... @overload def __getitem__(self: Self, item: slice | Sequence[int]) -> Self: ... - def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self: - return self._from_native_object(self._native_series.__getitem__(item)) + def __getitem__( + self: Self, item: int | slice | Sequence[int] | NWSeries + ) -> Any | Self: + if isinstance(item, (int, slice, Sequence)): + return self._from_native_object(self._native_series.__getitem__(item)) + return self._from_native_object(self._native_series.__getitem__(item.to_numpy())) def cast(self: Self, dtype: DType) -> Self: ser = self._native_series diff --git a/narwhals/series.py b/narwhals/series.py index 2846aebea..f5d05bc67 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -66,7 +66,7 @@ def __getitem__(self: Self, idx: int) -> Any: ... @overload def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ... - def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self: + def __getitem__(self: Self, idx: int | slice | Sequence[int] | Series) -> Any | Self: if isinstance(idx, int): return self._compliant_series[idx] return self._from_compliant_series(self._compliant_series[idx]) diff --git a/tests/series_only/__getitem___test.py b/tests/series_only/__getitem___test.py new file mode 100644 index 000000000..5da0f51b8 --- /dev/null +++ b/tests/series_only/__getitem___test.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +import pytest +import polars as pl + +import narwhals.stable.v1 as nw + +def test_getitem() -> None: + spl = pl.Series([1, 2, 3]) + assert spl[spl[0, 1]].equals(pl.Series([2, 3])) + + snw = nw.from_native(spl, series_only=True) + assert snw[snw[0, 1]].to_native().equals(pl.Series([2, 3])) + + spl = pl.Series([1, 2, 3]) + snw = nw.from_native(spl, series_only=True) + assert pytest.raises(TypeError, lambda: snw[snw[True, False]]) \ No newline at end of file From 9f0139cadd3152425c0e84b50833ef48653d9abf Mon Sep 17 00:00:00 2001 From: devdanzin <74280297+devdanzin@users.noreply.github.com> Date: Sat, 7 Dec 2024 10:00:17 -0300 Subject: [PATCH 2/2] Address pre-commit issues. --- narwhals/_polars/series.py | 4 ++-- narwhals/series.py | 6 ++++-- tests/series_only/__getitem___test.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 3497096c9..13da2c136 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -121,10 +121,10 @@ def dtype(self: Self) -> DType: def __getitem__(self: Self, item: int) -> Any: ... @overload - def __getitem__(self: Self, item: slice | Sequence[int]) -> Self: ... + def __getitem__(self: Self, item: slice | Sequence[int] | NWSeries[Any]) -> Self: ... def __getitem__( - self: Self, item: int | slice | Sequence[int] | NWSeries + self: Self, item: int | slice | Sequence[int] | NWSeries[Any] ) -> Any | Self: if isinstance(item, (int, slice, Sequence)): return self._from_native_object(self._native_series.__getitem__(item)) diff --git a/narwhals/series.py b/narwhals/series.py index f5d05bc67..96cc2c009 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -64,9 +64,11 @@ def __array__(self: Self, dtype: Any = None, copy: bool | None = None) -> np.nda def __getitem__(self: Self, idx: int) -> Any: ... @overload - def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ... + def __getitem__(self: Self, idx: slice | Sequence[int] | Series[Any]) -> Self: ... - def __getitem__(self: Self, idx: int | slice | Sequence[int] | Series) -> Any | Self: + def __getitem__( + self: Self, idx: int | slice | Sequence[int] | Series[Any] + ) -> Any | Self: if isinstance(idx, int): return self._compliant_series[idx] return self._from_compliant_series(self._compliant_series[idx]) diff --git a/tests/series_only/__getitem___test.py b/tests/series_only/__getitem___test.py index 5da0f51b8..c6dc125e2 100644 --- a/tests/series_only/__getitem___test.py +++ b/tests/series_only/__getitem___test.py @@ -1,10 +1,11 @@ from __future__ import annotations -import pytest import polars as pl +import pytest import narwhals.stable.v1 as nw + def test_getitem() -> None: spl = pl.Series([1, 2, 3]) assert spl[spl[0, 1]].equals(pl.Series([2, 3])) @@ -14,4 +15,4 @@ def test_getitem() -> None: spl = pl.Series([1, 2, 3]) snw = nw.from_native(spl, series_only=True) - assert pytest.raises(TypeError, lambda: snw[snw[True, False]]) \ No newline at end of file + assert pytest.raises(TypeError, lambda: snw[snw[True, False]])