Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: catch Polars exceptions, unify exception raising more #1918

Merged
merged 9 commits into from
Feb 10, 2025
Merged
46 changes: 32 additions & 14 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import polars as pl

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import catch_polars_exception
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import native_to_narwhals_dtype
Expand Down Expand Up @@ -102,6 +103,15 @@ def _from_native_object(
# scalar
return obj

def __len__(self) -> int:
return len(self._native_frame)

def head(self, n: int) -> Self:
return self._from_native_frame(self._native_frame.head(n))

def tail(self, n: int) -> Self:
return self._from_native_frame(self._native_frame.tail(n))

def __getattr__(self: Self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment]
Expand All @@ -112,6 +122,8 @@ def func(*args: Any, **kwargs: Any) -> Any:
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
raise ColumnNotFoundError(msg) from e
except Exception as e: # noqa: BLE001
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could these be narrowed any further than Exception?
https://docs.astral.sh/ruff/rules/blind-except/

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this shouldn't be an issue because we re-raise the exception anyway

raise catch_polars_exception(e, self._backend_version) from None

return func

Expand All @@ -134,11 +146,12 @@ def collect_schema(self: Self) -> dict[str, DType]:
for name, dtype in self._native_frame.schema.items()
}
else:
collected_schema = self._native_frame.collect_schema()
return {
name: native_to_narwhals_dtype(
dtype, self._version, self._backend_version
)
for name, dtype in self._native_frame.collect_schema().items()
for name, dtype in collected_schema.items()
}

@property
Expand Down Expand Up @@ -342,14 +355,17 @@ def pivot(
if self._backend_version < (1, 0, 0): # pragma: no cover
msg = "`pivot` is only supported for Polars>=1.0.0"
raise NotImplementedError(msg)
result = self._native_frame.pivot(
on,
index=index,
values=values,
aggregate_function=aggregate_function,
sort_columns=sort_columns,
separator=separator,
)
try:
result = self._native_frame.pivot(
on,
index=index,
values=values,
aggregate_function=aggregate_function,
sort_columns=sort_columns,
separator=separator,
)
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None
return self._from_native_object(result)

def to_polars(self: Self) -> pl.DataFrame:
Expand Down Expand Up @@ -431,24 +447,26 @@ def collect_schema(self: Self) -> dict[str, DType]:
for name, dtype in self._native_frame.schema.items()
}
else:
try:
collected_schema = self._native_frame.collect_schema()
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None
return {
name: native_to_narwhals_dtype(
dtype, self._version, self._backend_version
)
for name, dtype in self._native_frame.collect_schema().items()
for name, dtype in collected_schema.items()
}

def collect(
self: Self,
backend: Implementation | None,
**kwargs: Any,
) -> CompliantDataFrame:
import polars as pl

try:
result = self._native_frame.collect(**kwargs)
except pl.exceptions.ColumnNotFoundError as e:
raise ColumnNotFoundError(str(e)) from e
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None

if backend is None or backend is Implementation.POLARS:
from narwhals._polars.dataframe import PolarsDataFrame
Expand Down
21 changes: 9 additions & 12 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import polars as pl

from narwhals._polars.utils import catch_polars_exception
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
Expand Down Expand Up @@ -223,14 +224,15 @@ def __invert__(self: Self) -> Self:

def is_nan(self: Self) -> Self:
native = self._native_series

try:
native_is_nan = native.is_nan()
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None
if self._backend_version < (1, 18): # pragma: no cover
return self._from_native_series(
pl.select(pl.when(native.is_not_null()).then(native.is_nan()))[
native.name
]
pl.select(pl.when(native.is_not_null()).then(native_is_nan))[native.name]
)
return self._from_native_series(native.is_nan())
return self._from_native_series(native_is_nan)

def median(self: Self) -> Any:
from narwhals.exceptions import InvalidOperationError
Expand Down Expand Up @@ -456,15 +458,10 @@ def cum_count(self: Self, *, reverse: bool) -> Self:
return self._from_native_series(result)

def __contains__(self: Self, other: Any) -> bool:
from polars.exceptions import InvalidOperationError as PlInvalidOperationError

try:
return self._native_series.__contains__(other)
except PlInvalidOperationError as exc:
from narwhals.exceptions import InvalidOperationError

msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
raise InvalidOperationError(msg) from exc
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None

def to_polars(self: Self) -> pl.Series:
return self._native_series
Expand Down
28 changes: 28 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

import polars as pl

from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import NarwhalsError
from narwhals.exceptions import ShapeError
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
Expand Down Expand Up @@ -216,3 +220,27 @@ def convert_str_slice_to_int_slice(
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)


def catch_polars_exception(
exception: Exception, backend_version: tuple[int, ...]
) -> NarwhalsError | Exception:
if isinstance(exception, pl.exceptions.ColumnNotFoundError):
return ColumnNotFoundError(str(exception))
elif isinstance(exception, pl.exceptions.ShapeError):
return ShapeError(str(exception))
elif isinstance(exception, pl.exceptions.InvalidOperationError):
return InvalidOperationError(str(exception))
elif isinstance(exception, pl.exceptions.ComputeError):
# We don't (yet?) have a Narwhals ComputeError.
return NarwhalsError(str(exception))
if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError):
# Old versions of Polars didn't have PolarsError.
return NarwhalsError(str(exception))
elif backend_version < (1,) and "polars.exceptions" in str(
type(exception)
): # pragma: no cover
# Last attempt, for old Polars versions.
return NarwhalsError(str(exception))
# Just return exception as-is.
return exception
23 changes: 9 additions & 14 deletions tests/expr_and_series/is_nan_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Any

import pytest
from polars.exceptions import ComputeError

import narwhals.stable.v1 as nw
from narwhals.exceptions import NarwhalsError
from tests.conftest import dask_lazy_p1_constructor
from tests.conftest import dask_lazy_p2_constructor
from tests.conftest import modin_constructor
Expand Down Expand Up @@ -53,7 +53,8 @@ def test_nan(constructor: Constructor) -> None:

context = (
pytest.raises(
ComputeError, match="NAN is not supported in a Non-floating point type column"
NarwhalsError,
match="NAN is not supported in a Non-floating point type column",
)
if "polars_lazy" in str(constructor)
and os.environ.get("NARWHALS_POLARS_GPU", False)
Expand Down Expand Up @@ -96,37 +97,31 @@ def test_nan_series(constructor_eager: ConstructorEager) -> None:
def test_nan_non_float(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
from polars.exceptions import InvalidOperationError as PlInvalidOperationError
from pyarrow.lib import ArrowNotImplementedError

from narwhals.exceptions import InvalidOperationError as NwInvalidOperationError
from narwhals.exceptions import InvalidOperationError

data = {"a": ["x", "y"]}
df = nw.from_native(constructor(data))

exc = NwInvalidOperationError
if "polars" in str(constructor):
exc = PlInvalidOperationError
elif "pyarrow_table" in str(constructor):
exc = InvalidOperationError
if "pyarrow_table" in str(constructor):
exc = ArrowNotImplementedError

with pytest.raises(exc):
df.select(nw.col("a").is_nan()).lazy().collect()


def test_nan_non_float_series(constructor_eager: ConstructorEager) -> None:
from polars.exceptions import InvalidOperationError as PlInvalidOperationError
from pyarrow.lib import ArrowNotImplementedError

from narwhals.exceptions import InvalidOperationError as NwInvalidOperationError
from narwhals.exceptions import InvalidOperationError

data = {"a": ["x", "y"]}
df = nw.from_native(constructor_eager(data), eager_only=True)

exc = NwInvalidOperationError
if "polars" in str(constructor_eager):
exc = PlInvalidOperationError
elif "pyarrow_table" in str(constructor_eager):
exc = InvalidOperationError
if "pyarrow_table" in str(constructor_eager):
exc = ArrowNotImplementedError

with pytest.raises(exc):
Expand Down
5 changes: 2 additions & 3 deletions tests/expr_and_series/median_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,17 @@ def test_median_expr_raises_on_str(
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
from polars.exceptions import InvalidOperationError as PlInvalidOperationError

df = nw.from_native(constructor(data))
if isinstance(df, nw.LazyFrame):
with pytest.raises(
(InvalidOperationError, PlInvalidOperationError),
InvalidOperationError,
match="`median` operation not supported",
):
df.select(expr).lazy().collect()
else:
with pytest.raises(
(InvalidOperationError, PlInvalidOperationError),
InvalidOperationError,
match="`median` operation not supported",
):
df.select(expr)
Expand Down
7 changes: 3 additions & 4 deletions tests/expr_and_series/replace_strict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.exceptions import NarwhalsError
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand Down Expand Up @@ -56,20 +57,18 @@ def test_replace_strict_series(
def test_replace_non_full(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
from polars.exceptions import PolarsError

if "dask" in str(constructor):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor({"a": [1, 2, 3]}))
if isinstance(df, nw.LazyFrame):
with pytest.raises((ValueError, PolarsError)):
with pytest.raises((ValueError, NarwhalsError)):
df.select(
nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64)
).collect()
else:
with pytest.raises((ValueError, PolarsError)):
with pytest.raises((ValueError, NarwhalsError)):
df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64))


Expand Down
3 changes: 1 addition & 2 deletions tests/frame/drop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any

import pytest
from polars.exceptions import ColumnNotFoundError as PlColumnNotFoundError

import narwhals.stable.v1 as nw
from narwhals.exceptions import ColumnNotFoundError
Expand Down Expand Up @@ -36,7 +35,7 @@ def test_drop(constructor: Constructor, to_drop: list[str], expected: list[str])
[
(
True,
pytest.raises((ColumnNotFoundError, PlColumnNotFoundError), match="z"),
pytest.raises(ColumnNotFoundError, match="z"),
),
(False, does_not_raise()),
],
Expand Down
6 changes: 2 additions & 4 deletions tests/frame/explode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import Sequence

import pytest
from polars.exceptions import InvalidOperationError as PlInvalidOperationError
from polars.exceptions import ShapeError as PlShapeError

import narwhals.stable.v1 as nw
from narwhals.exceptions import InvalidOperationError
Expand Down Expand Up @@ -119,7 +117,7 @@ def test_explode_shape_error(
request.applymarker(pytest.mark.xfail)

with pytest.raises(
(ShapeError, PlShapeError, NotImplementedError),
(ShapeError, NotImplementedError),
match=r".*exploded columns (must )?have matching element counts",
):
_ = (
Expand All @@ -141,7 +139,7 @@ def test_explode_invalid_operation_error(
request.applymarker(pytest.mark.xfail)

with pytest.raises(
(InvalidOperationError, PlInvalidOperationError),
InvalidOperationError,
match="`explode` operation not supported for dtype",
):
_ = nw.from_native(constructor(data)).lazy().explode("a").collect()
3 changes: 1 addition & 2 deletions tests/frame/filter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from contextlib import nullcontext as does_not_raise

import pytest
from polars.exceptions import ShapeError as PlShapeError

import narwhals as nw
from narwhals.exceptions import LengthChangingExprError
Expand Down Expand Up @@ -44,5 +43,5 @@ def test_filter_raise_on_agg_predicate(constructor: Constructor) -> None:
def test_filter_raise_on_shape_mismatch(constructor: Constructor) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df = nw.from_native(constructor(data))
with pytest.raises((LengthChangingExprError, ShapeError, PlShapeError)):
with pytest.raises((LengthChangingExprError, ShapeError)):
df.filter(nw.col("b").unique() > 2).lazy().collect()
4 changes: 2 additions & 2 deletions tests/frame/pivot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from contextlib import nullcontext as does_not_raise
from typing import Any

import polars as pl
import pytest

import narwhals.stable.v1 as nw
from narwhals.exceptions import NarwhalsError
from tests.utils import PANDAS_VERSION
from tests.utils import POLARS_VERSION
from tests.utils import ConstructorEager
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_pivot(
("data_", "context"),
[
(data_no_dups, does_not_raise()),
(data, pytest.raises((ValueError, pl.exceptions.ComputeError))),
(data, pytest.raises((ValueError, NarwhalsError))),
],
)
def test_pivot_no_agg(
Expand Down
Loading
Loading