Skip to content

Commit

Permalink
feat: add is_nan and is_finite for duckdb, is_nan for pyspark (#1825)
Browse files Browse the repository at this point in the history

* feat: add is_nan for duckdb and pyspark

* preserve nulls in pyspark is_finite
  • Loading branch information
FBruzzesi authored Jan 20, 2025
1 parent 79098f1 commit a55ff89
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
18 changes: 18 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,24 @@ def is_null(self) -> Self:
lambda _input: _input.isnull(), "is_null", returns_scalar=self._returns_scalar
)

def is_nan(self) -> Self:
    """Return an expression that is true where the value is NaN.

    Delegates to DuckDB's ``isnan`` SQL function; NULL inputs propagate
    as NULL per DuckDB semantics.
    """
    from duckdb import FunctionExpression

    def _is_nan(_input):
        # Wrap the column in DuckDB's native isnan() call.
        return FunctionExpression("isnan", _input)

    return self._from_call(
        _is_nan,
        "is_nan",
        returns_scalar=self._returns_scalar,
    )

def is_finite(self) -> Self:
    """Return an expression that is true where the value is finite.

    Delegates to DuckDB's ``isfinite`` SQL function (false for NaN and
    +/-inf); NULL inputs propagate as NULL per DuckDB semantics.
    """
    from duckdb import FunctionExpression

    def _is_finite(_input):
        # Wrap the column in DuckDB's native isfinite() call.
        return FunctionExpression("isfinite", _input)

    return self._from_call(
        _is_finite,
        "is_finite",
        returns_scalar=self._returns_scalar,
    )

def is_in(self, other: Sequence[Any]) -> Self:
from duckdb import ConstantExpression

Expand Down
19 changes: 13 additions & 6 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,12 @@ def is_finite(self) -> Self:
def _is_finite(_input: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

# A value is finite if it's not NaN, not NULL, and not infinite
return (
~F.isnan(_input)
& ~F.isnull(_input)
& (_input != float("inf"))
& (_input != float("-inf"))
# A value is finite if it's not NaN, and not infinite, while NULLs should be
# preserved
is_finite_condition = (
~F.isnan(_input) & (_input != float("inf")) & (_input != float("-inf"))
)
return F.when(~F.isnull(_input), is_finite_condition).otherwise(None)

return self._from_call(
_is_finite, "is_finite", returns_scalar=self._returns_scalar
Expand Down Expand Up @@ -535,6 +534,14 @@ def is_null(self: Self) -> Self:

return self._from_call(F.isnull, "is_null", returns_scalar=self._returns_scalar)

def is_nan(self: Self) -> Self:
    """Return an expression that is true where the value is NaN.

    NULL inputs stay NULL (they are neither NaN nor not-NaN), matching
    the null-preserving behavior of the other backends.
    """
    from pyspark.sql import functions as F  # noqa: N812

    def _is_nan(_input: Column) -> Column:
        # Only evaluate isnan() on non-null values; nulls fall through
        # to the otherwise(None) branch and are preserved.
        return F.when(~F.isnull(_input), F.isnan(_input)).otherwise(None)

    return self._from_call(_is_nan, "is_nan", returns_scalar=self._returns_scalar)

@property
def str(self: Self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
Expand Down
14 changes: 6 additions & 8 deletions tests/expr_and_series/is_finite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@


@pytest.mark.filterwarnings("ignore:invalid value encountered in cast")
def test_is_finite_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) or "pyarrow_table" in str(constructor):
def test_is_finite_expr(constructor: Constructor) -> None:
if any(
x in str(constructor) for x in ("polars", "pyarrow_table", "duckdb", "pyspark")
):
expected = {"a": [False, False, True, None]}
elif (
"pandas_constructor" in str(constructor)
or "dask" in str(constructor)
or "modin_constructor" in str(constructor)
elif any(
x in str(constructor) for x in ("pandas_constructor", "dask", "modin_constructor")
):
expected = {"a": [False, False, True, False]}
else: # pandas_nullable_constructor, pandas_pyarrow_constructor, modin_pyarrrow_constructor
Expand Down
8 changes: 3 additions & 5 deletions tests/expr_and_series/is_nan_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@
]


def test_nan(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data_na = {"int": [0, 1, None]}
def test_nan(constructor: Constructor) -> None:
data_na = {"int": [-1, 1, None]}
df = nw.from_native(constructor(data_na)).with_columns(
float=nw.col("int").cast(nw.Float64), float_na=nw.col("int") / nw.col("int")
float=nw.col("int").cast(nw.Float64), float_na=nw.col("int") ** 0.5
)
result = df.select(
int=nw.col("int").is_nan(),
Expand Down

0 comments on commit a55ff89

Please sign in to comment.