From a55ff8972ba6a0cee19c7821ba5ba99ad5c4dcce Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Mon, 20 Jan 2025 09:29:16 +0100 Subject: [PATCH] feat: add `is_nan` and `is_finite` for duckdb, `is_nan` for pyspark (#1825) * feat: add is_nan for duckdb and pyspark * preserve nulls in pyspark is_finite --- narwhals/_duckdb/expr.py | 18 ++++++++++++++++++ narwhals/_spark_like/expr.py | 19 +++++++++++++------ tests/expr_and_series/is_finite_test.py | 14 ++++++-------- tests/expr_and_series/is_nan_test.py | 8 +++----- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index c79acdfaec..c6ff5e2de7 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -515,6 +515,24 @@ def is_null(self) -> Self: lambda _input: _input.isnull(), "is_null", returns_scalar=self._returns_scalar ) + def is_nan(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("isnan", _input), + "is_nan", + returns_scalar=self._returns_scalar, + ) + + def is_finite(self) -> Self: + from duckdb import FunctionExpression + + return self._from_call( + lambda _input: FunctionExpression("isfinite", _input), + "is_finite", + returns_scalar=self._returns_scalar, + ) + def is_in(self, other: Sequence[Any]) -> Self: from duckdb import ConstantExpression diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index ba063c6424..723acbc2b3 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -433,13 +433,12 @@ def is_finite(self) -> Self: def _is_finite(_input: Column) -> Column: from pyspark.sql import functions as F # noqa: N812 - # A value is finite if it's not NaN, not NULL, and not infinite - return ( - ~F.isnan(_input) - & ~F.isnull(_input) - & (_input != float("inf")) - & (_input != float("-inf")) + # A value is finite if it's not NaN, and not infinite, while NULLs should be + # preserved + is_finite_condition = ( + ~F.isnan(_input) & (_input != float("inf")) & (_input != float("-inf")) ) + return F.when(~F.isnull(_input), is_finite_condition).otherwise(None) return self._from_call( _is_finite, "is_finite", returns_scalar=self._returns_scalar @@ -535,6 +534,14 @@ def is_null(self: Self) -> Self: return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar) + def is_nan(self: Self) -> Self: + from pyspark.sql import functions as F # noqa: N812 + + def _is_nan(_input: Column) -> Column: + return F.when(F.isnull(_input), None).otherwise(F.isnan(_input)) + + return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar) + @property def str(self: Self) -> SparkLikeExprStringNamespace: return SparkLikeExprStringNamespace(self) diff --git a/tests/expr_and_series/is_finite_test.py b/tests/expr_and_series/is_finite_test.py index 4fb0246e91..8bd3bc0d20 100644 --- a/tests/expr_and_series/is_finite_test.py +++ b/tests/expr_and_series/is_finite_test.py @@ -11,15 +11,13 @@ @pytest.mark.filterwarnings("ignore:invalid value encountered in cast") -def test_is_finite_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - if "polars" in str(constructor) or "pyarrow_table" in str(constructor): +def test_is_finite_expr(constructor: Constructor) -> None: + if any( + x in str(constructor) for x in ("polars", "pyarrow_table", "duckdb", "pyspark") + ): expected = {"a": [False, False, True, None]} - elif ( - "pandas_constructor" in str(constructor) - or "dask" in str(constructor) - or "modin_constructor" in str(constructor) + elif any( + x in str(constructor) for x in ("pandas_constructor", "dask", "modin_constructor") ): expected = {"a": [False, False, True, False]} else: # pandas_nullable_constructor, pandas_pyarrow_constructor, modin_pyarrrow_constructor diff --git a/tests/expr_and_series/is_nan_test.py b/tests/expr_and_series/is_nan_test.py index 0280d6555f..44ca575c62 100644 --- a/tests/expr_and_series/is_nan_test.py +++ b/tests/expr_and_series/is_nan_test.py @@ -24,12 +24,10 @@ ] -def test_nan(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) - data_na = {"int": [0, 1, None]} +def test_nan(constructor: Constructor) -> None: + data_na = {"int": [-1, 1, None]} df = nw.from_native(constructor(data_na)).with_columns( - float=nw.col("int").cast(nw.Float64), float_na=nw.col("int") / nw.col("int") + float=nw.col("int").cast(nw.Float64), float_na=nw.col("int") ** 0.5 ) result = df.select( int=nw.col("int").is_nan(),