diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 0bfea20ad..13da2c136 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -21,6 +21,7 @@ from narwhals._polars.dataframe import PolarsDataFrame from narwhals.dtypes import DType + from narwhals.series import Series as NWSeries from narwhals.utils import Version T = TypeVar("T") @@ -120,10 +121,14 @@ def dtype(self: Self) -> DType: def __getitem__(self: Self, item: int) -> Any: ... @overload - def __getitem__(self: Self, item: slice | Sequence[int]) -> Self: ... - - def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self: - return self._from_native_object(self._native_series.__getitem__(item)) + def __getitem__(self: Self, item: slice | Sequence[int] | NWSeries[Any]) -> Self: ... + + def __getitem__( + self: Self, item: int | slice | Sequence[int] | NWSeries[Any] + ) -> Any | Self: + if isinstance(item, (int, slice, Sequence)): + return self._from_native_object(self._native_series.__getitem__(item)) + return self._from_native_object(self._native_series.__getitem__(item.to_numpy())) def cast(self: Self, dtype: DType) -> Self: ser = self._native_series diff --git a/narwhals/series.py b/narwhals/series.py index 2846aebea..96cc2c009 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -64,9 +64,11 @@ def __array__(self: Self, dtype: Any = None, copy: bool | None = None) -> np.nda def __getitem__(self: Self, idx: int) -> Any: ... @overload - def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ... + def __getitem__(self: Self, idx: slice | Sequence[int] | Series[Any]) -> Self: ... - def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self: + def __getitem__( + self: Self, idx: int | slice | Sequence[int] | Series[Any] + ) -> Any | Self: if isinstance(idx, int): return self._compliant_series[idx] return self._from_compliant_series(self._compliant_series[idx]) diff --git a/tests/series_only/__getitem___test.py b/tests/series_only/__getitem___test.py new file mode 100644 index 000000000..c6dc125e2 --- /dev/null +++ b/tests/series_only/__getitem___test.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import polars as pl +import pytest + +import narwhals.stable.v1 as nw + + +def test_getitem() -> None: + spl = pl.Series([1, 2, 3]) + assert spl[spl[0, 1]].equals(pl.Series([2, 3])) + + snw = nw.from_native(spl, series_only=True) + assert snw[snw[0, 1]].to_native().equals(pl.Series([2, 3])) + + spl = pl.Series([1, 2, 3]) + snw = nw.from_native(spl, series_only=True) + assert pytest.raises(TypeError, lambda: snw[snw[True, False]])