Skip to content

Commit

Permalink
solve tests conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 9, 2025
2 parents e4c8281 + 0f38521 commit 931993e
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 51 deletions.
2 changes: 1 addition & 1 deletion narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __native_namespace__(self) -> Any: # pragma: no cover
def __narwhals_namespace__(self) -> SparkLikeNamespace:
from narwhals._spark_like.namespace import SparkLikeNamespace

return SparkLikeNamespace( # type: ignore[abstract]
return SparkLikeNamespace(
backend_version=self._backend_version, version=self._version
)

Expand Down
87 changes: 80 additions & 7 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover
# Unused, just for compatibility with PandasLikeExpr
from narwhals._spark_like.namespace import SparkLikeNamespace

return SparkLikeNamespace( # type: ignore[abstract]
return SparkLikeNamespace(
backend_version=self._backend_version, version=self._version
)

Expand Down Expand Up @@ -139,32 +139,66 @@ def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]

def __add__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input + other,
lambda _input, other: _input.__add__(other),
"__add__",
other=other,
returns_scalar=False,
)

def __sub__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input - other,
lambda _input, other: _input.__sub__(other),
"__sub__",
other=other,
returns_scalar=False,
)

def __mul__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input * other,
lambda _input, other: _input.__mul__(other),
"__mul__",
other=other,
returns_scalar=False,
)

def __lt__(self, other: SparkLikeExpr) -> Self:
def __truediv__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input < other,
"__lt__",
lambda _input, other: _input.__truediv__(other),
"__truediv__",
other=other,
returns_scalar=False,
)

def __floordiv__(self, other: SparkLikeExpr) -> Self:
def _floordiv(_input: Column, other: Column) -> Column:
from pyspark.sql import functions as F # noqa: N812

return F.floor(_input / other)

return self._from_call(
_floordiv, "__floordiv__", other=other, returns_scalar=False
)

def __pow__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__pow__(other),
"__pow__",
other=other,
returns_scalar=False,
)

def __mod__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__mod__(other),
"__mod__",
other=other,
returns_scalar=False,
)

def __ge__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__ge__(other),
"__ge__",
other=other,
returns_scalar=False,
)
Expand All @@ -177,6 +211,45 @@ def __gt__(self, other: SparkLikeExpr) -> Self:
returns_scalar=False,
)

def __le__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__le__(other),
"__le__",
other=other,
returns_scalar=False,
)

def __lt__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__lt__(other),
"__lt__",
other=other,
returns_scalar=False,
)

def __and__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__and__(other),
"__and__",
other=other,
returns_scalar=False,
)

def __or__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__or__(other),
"__or__",
other=other,
returns_scalar=False,
)

def __invert__(self) -> Self:
return self._from_call(
lambda _input: _input.__invert__(),
"__invert__",
returns_scalar=self._returns_scalar,
)

def abs(self) -> Self:
from pyspark.sql import functions as F # noqa: N812

Expand Down
23 changes: 23 additions & 0 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import IntoSparkLikeExpr
from narwhals.dtypes import DType
from narwhals.utils import Version


Expand Down Expand Up @@ -67,6 +68,28 @@ def col(self, *column_names: str) -> SparkLikeExpr:
*column_names, backend_version=self._backend_version, version=self._version
)

def lit(self, value: object, dtype: DType | None) -> SparkLikeExpr:
if dtype is not None:
msg = "todo"
raise NotImplementedError(msg)

def _lit(_: SparkLikeLazyFrame) -> list[Column]:
import pyspark.sql.functions as F # noqa: N812

return [F.lit(value).alias("literal")]

return SparkLikeExpr( # type: ignore[abstract]
call=_lit,
depth=0,
function_name="lit",
root_names=None,
output_names=["literal"],
returns_scalar=True,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)

def sum_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

Expand Down
16 changes: 2 additions & 14 deletions tests/expr_and_series/arithmetic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,6 @@ def test_arithmetic_expr(
):
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor) and attr in {
"__pow__",
"__mod__",
"__truediv__",
"__floordiv__",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1.0, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
Expand Down Expand Up @@ -84,8 +76,6 @@ def test_right_arithmetic_expr(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
request.applymarker(pytest.mark.xfail)
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
Expand Down Expand Up @@ -255,10 +245,8 @@ def test_arithmetic_expr_left_literal(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if (
("duckdb" in str(constructor) and attr == "__floordiv__")
or ("dask" in str(constructor) and DASK_VERSION < (2024, 10))
or ("pyspark" in str(constructor))
if ("duckdb" in str(constructor) and attr == "__floordiv__") or (
"dask" in str(constructor) and DASK_VERSION < (2024, 10)
):
request.applymarker(pytest.mark.xfail)
if attr == "__mod__" and any(
Expand Down
14 changes: 7 additions & 7 deletions tests/expr_and_series/lit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_lit(
dtype: DType | None,
expected_lit: list[Any],
) -> None:
if "pyspark" in str(constructor):
if "pyspark" in str(constructor) and dtype is not None:
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
Expand Down Expand Up @@ -60,9 +60,7 @@ def test_lit_error(constructor: Constructor) -> None:
_ = df.with_columns(nw.lit([1, 2]).alias("lit"))


def test_lit_out_name(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_lit_out_name(constructor: Constructor) -> None:
data = {"a": [1, 3, 2]}
df_raw = constructor(data)
df = nw.from_native(df_raw).lazy()
Expand Down Expand Up @@ -107,9 +105,11 @@ def test_lit_operation(
and DASK_VERSION < (2024, 10)
):
request.applymarker(pytest.mark.xfail)
if "pyspark" in str(constructor) and col_name not in {
"right_scalar_with_agg",
"right_scalar",
if "pyspark" in str(constructor) and col_name in {
"left_lit_with_agg",
"left_scalar_with_agg",
"right_lit_with_agg",
"right_lit",
}:
request.applymarker(pytest.mark.xfail)

Expand Down
17 changes: 1 addition & 16 deletions tests/expr_and_series/operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,10 @@
],
)
def test_comparand_operators_scalar_expr(
request: pytest.FixtureRequest,
constructor: Constructor,
operator: str,
expected: list[bool],
) -> None:
if "pyspark" in str(constructor) and operator in {
"__le__",
"__ge__",
}:
request.applymarker(pytest.mark.xfail)
data = {"a": [0, 1, 2]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), operator)(1))
Expand All @@ -49,16 +43,10 @@ def test_comparand_operators_scalar_expr(
],
)
def test_comparand_operators_expr(
request: pytest.FixtureRequest,
constructor: Constructor,
operator: str,
expected: list[bool],
) -> None:
if "pyspark" in str(constructor) and operator in {
"__le__",
"__ge__",
}:
request.applymarker(pytest.mark.xfail)
data = {"a": [0, 1, 1], "b": [0, 0, 2]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), operator)(nw.col("b")))
Expand All @@ -73,13 +61,10 @@ def test_comparand_operators_expr(
],
)
def test_logic_operators_expr(
request: pytest.FixtureRequest,
constructor: Constructor,
operator: str,
expected: list[bool],
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [True, True, False, False], "b": [True, False, True, False]}
df = nw.from_native(constructor(data))

Expand All @@ -106,7 +91,7 @@ def test_logic_operators_expr_scalar(
"dask" in str(constructor)
and DASK_VERSION < (2024, 10)
and operator in ("__rand__", "__ror__")
) or "pyspark" in str(constructor):
) or ("pyspark" in str(constructor) and operator in ("__and__", "__or__")):
request.applymarker(pytest.mark.xfail)
data = {"a": [True, True, False, False]}
df = nw.from_native(constructor(data))
Expand Down
7 changes: 1 addition & 6 deletions tests/expr_and_series/pipe_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
Expand All @@ -11,10 +9,7 @@
expected = [4, 16, 36, 64]


def test_pipe_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

def test_pipe_expr(constructor: Constructor) -> None:
df = nw.from_native(constructor(input_list))
e = df.select(nw.col("a").pipe(lambda x: x**2))
assert_equal_data(e, {"a": expected})
Expand Down

0 comments on commit 931993e

Please sign in to comment.