chore: move pyspark tests into main test suite #1761

Merged Jan 9, 2025 · 20 commits
13 changes: 11 additions & 2 deletions narwhals/_spark_like/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
@@ -8,6 +9,7 @@

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
@@ -106,9 +108,11 @@ def select(
new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def filter(self, *predicates: SparkLikeExpr) -> Self:
def filter(self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
)
Comment on lines +113 to +115 (Member Author): Needed to implement `Expr.__eq__` to get this to work; it overlaps with @EdAbati's PR.

# `[0]` is safe as all_horizontal's expression only returns a single column
condition = expr._call(self)[0]
spark_df = self._native_frame.where(condition)
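
For context, a rough sketch of what the new keyword-constraint path enables at the narwhals level (`lf` is a hypothetical PySpark-backed LazyFrame; this assumes the public `LazyFrame.filter` forwards keyword constraints to the backend, as Polars does):

import narwhals as nw

# Each keyword constraint becomes `col(name) == value`; all predicates
# are then AND-ed together via `all_horizontal`, so these two calls
# should be equivalent.
result = lf.filter(a=1, b="x")
result = lf.filter(nw.col("a") == 1, nw.col("b") == "x")
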
@@ -203,6 +207,11 @@ def unique(
if keep != "any":
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)

if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)

subset = [subset] if isinstance(subset, str) else subset
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
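
A sketch of the new `subset` validation (`lf` again a hypothetical PySpark-backed LazyFrame, here with columns "a" and "b"): an unknown column now raises narwhals' own error up front instead of surfacing as a PySpark failure later.

from narwhals.exceptions import ColumnNotFoundError

try:
    lf.unique(subset=["c"])  # "c" is not a column of `lf`
except ColumnNotFoundError as exc:
    print(exc)  # Column(s) ['c'] not found in ['a', 'b']
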

16 changes: 16 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -120,6 +120,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
kwargs=kwargs,
)

def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__eq__(other),
"__eq__",
other=other,
returns_scalar=False,
)

def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__ne__(other),
"__ne__",
other=other,
returns_scalar=False,
)

def __add__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input + other,
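
The `type: ignore[override]` comments are needed because Python types `__eq__` and `__ne__` as returning `bool`, while here they build a new expression instead. A minimal sketch of what that enables (`lf` again hypothetical):

# Comparing an expression to a value yields another expression, not a
# bool; the backend evaluates it lazily against the native Spark column.
expr = nw.col("a") == 1
filtered = lf.filter(expr)
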
2 changes: 2 additions & 0 deletions narwhals/_spark_like/group_by.py
@@ -146,6 +146,8 @@ def agg_pyspark(
simple_aggregations.update(
{
output_name: agg_func(root_name)
if function_name != "len"
else agg_func("*")
for root_name, output_name in zip(expr._root_names, expr._output_names)
}
)
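
The `len` special case exists because PySpark's `F.count(column)` skips nulls, whereas counting rows requires `F.count("*")`. A minimal illustration, assuming a local SparkSession is available:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (None,)], schema=["a"])

df.select(F.count("a")).show()  # 1 -- nulls in "a" are skipped
df.select(F.count("*")).show()  # 2 -- counts rows, nulls included
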
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -164,6 +164,11 @@ filterwarnings = [
'ignore:.*Passing a BlockManager to DataFrame:DeprecationWarning',
# This warning was temporarily raised by Polars but then reverted.
'ignore:.*The default coalesce behavior of left join will change:DeprecationWarning',
'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning',
'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning',
'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning',
'ignore: unclosed <socket.socket',

]
xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
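
For reference, each `filterwarnings` entry follows the `action:message:category` syntax of Python's warnings filters; the first new entry corresponds roughly to this call (a sketch, with the message regex abbreviated):

import warnings

warnings.filterwarnings(
    "ignore",
    message=".*The distutils package is deprecated and slated for removal in Python 3.12",
    category=DeprecationWarning,
)
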
29 changes: 21 additions & 8 deletions tests/conftest.py
@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generator
from typing import Sequence

import pandas as pd
@@ -14,7 +13,6 @@

if TYPE_CHECKING:
import duckdb
from pyspark.sql import SparkSession

from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
@@ -129,15 +127,15 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


@pytest.fixture(scope="session")
def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]: # pragma: no cover
try:
from pyspark.sql import SparkSession
except ImportError: # pragma: no cover
pytest.skip("pyspark is not installed")
return
return None

import warnings
from atexit import register

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
with warnings.catch_warnings():
Expand All @@ -155,8 +153,19 @@ def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
.config("spark.sql.shuffle.partitions", "2")
.getOrCreate()
)
yield session
session.stop()

register(session.stop)
Comment (Collaborator): TIL `atexit.register`, nice!
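
For anyone unfamiliar with it: `atexit.register` queues a callable to run at interpreter shutdown, which is what replaces the fixture's `yield`/`session.stop()` teardown above. A standalone sketch of the pattern, with a stand-in class:

from atexit import register

class FakeSession:
    # Stand-in for pyspark.sql.SparkSession in this sketch.
    def stop(self) -> None:
        print("session stopped")

session = FakeSession()
register(session.stop)  # runs automatically when the interpreter exits
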


def _constructor(obj: Any) -> IntoFrame:
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)

return _constructor


EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = {
@@ -173,6 +182,7 @@ def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
"dask": dask_lazy_p2_constructor,
"polars[lazy]": polars_lazy_constructor,
"duckdb": duckdb_lazy_constructor,
"pyspark": pyspark_lazy_constructor, # type: ignore[dict-item]
}
GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}

@@ -201,7 +211,10 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
constructors.append(EAGER_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
elif constructor in LAZY_CONSTRUCTORS:
constructors.append(LAZY_CONSTRUCTORS[constructor])
if constructor == "pyspark":
constructors.append(pyspark_lazy_constructor())
else:
constructors.append(LAZY_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
else: # pragma: no cover
msg = f"Expected one of {EAGER_CONSTRUCTORS.keys()} or {LAZY_CONSTRUCTORS.keys()}, got {constructor}"
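
The extra branch is needed because, unlike the other `LAZY_CONSTRUCTORS` values (which are themselves constructors), `pyspark_lazy_constructor` is a factory: calling it once at collection time starts the shared SparkSession and returns the closure that tests actually receive. A generic, runnable sketch of the pattern, with hypothetical names:

from typing import Any, Callable

def make_constructor() -> Callable[[Any], list]:
    # Expensive shared setup happens once, here (for PySpark: building
    # the session and registering its cleanup).
    shared = {"ready": True}

    def _constructor(obj: Any) -> list:
        # The closure reuses the shared state on every call.
        assert shared["ready"]
        return [obj]

    return _constructor

constructor = make_constructor()  # called once, as at pytest collection
print(constructor({"a": [1, 2]}))  # [{'a': [1, 2]}]
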
2 changes: 1 addition & 1 deletion tests/expr_and_series/all_horizontal_test.py
@@ -57,7 +57,7 @@ def test_allh_nth(
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
7 changes: 6 additions & 1 deletion tests/expr_and_series/any_all_test.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


def test_any_all(constructor: Constructor) -> None:
def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(
constructor(
{
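
The pattern repeated across these test files: `request.applymarker(pytest.mark.xfail)` attaches an expected-failure marker at runtime, keyed off which constructor the test was parametrized with. A self-contained sketch:

import pytest

@pytest.mark.parametrize("backend", ["pandas", "pyspark"])
def test_feature(request: pytest.FixtureRequest, backend: str) -> None:
    if backend == "pyspark":
        # Not implemented for this backend yet: expect failure, not error.
        request.applymarker(pytest.mark.xfail)
    assert backend == "pandas"
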
10 changes: 8 additions & 2 deletions tests/expr_and_series/any_horizontal_test.py
@@ -11,7 +11,11 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
def test_anyh(
request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
@@ -23,7 +27,9 @@ def test_anyh_all(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_anyh_all(constructor: Constructor) -> None:
def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
17 changes: 14 additions & 3 deletions tests/expr_and_series/arithmetic_test.py
@@ -45,6 +45,14 @@ def test_arithmetic_expr(
):
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor) and attr in {
"__pow__",
"__mod__",
"__truediv__",
"__floordiv__",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1.0, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
@@ -76,7 +84,8 @@ def test_right_arithmetic_expr(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
request.applymarker(pytest.mark.xfail)

if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
@@ -246,8 +255,10 @@ def test_arithmetic_expr_left_literal(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if ("duckdb" in str(constructor) and attr == "__floordiv__") or (
"dask" in str(constructor) and DASK_VERSION < (2024, 10)
if (
("duckdb" in str(constructor) and attr == "__floordiv__")
or ("dask" in str(constructor) and DASK_VERSION < (2024, 10))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
if attr == "__mod__" and any(
4 changes: 3 additions & 1 deletion tests/expr_and_series/binary_test.py
@@ -9,7 +9,9 @@


def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
constructor
):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
8 changes: 5 additions & 3 deletions tests/expr_and_series/cast_test.py
@@ -60,7 +60,7 @@ def test_cast(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
15,
@@ -180,7 +180,7 @@ def test_cast_string() -> None:
def test_cast_raises_for_unknown_dtype(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
# Unsupported cast from string to dictionary using function cast_dictionary
@@ -204,6 +204,7 @@ def test_cast_datetime_tz_aware(
or "duckdb" in str(constructor)
or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973
or ("pyarrow_table" in str(constructor) and is_windows())
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)

@@ -229,7 +230,8 @@

def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
backend in str(constructor) for backend in ("dask", "modin", "cudf", "duckdb")
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
):
request.applymarker(pytest.mark.xfail)

2 changes: 1 addition & 1 deletion tests/expr_and_series/concat_str_test.py
@@ -27,7 +27,7 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
2 changes: 2 additions & 0 deletions tests/expr_and_series/convert_time_zone_test.py
@@ -29,6 +29,7 @@ def test_convert_time_zone(
or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
data = {
@@ -86,6 +87,7 @@ def test_convert_time_zone_from_none(
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 7):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_count_test.py
@@ -21,7 +21,7 @@ def test_cum_count_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

name = "reverse_cum_count" if reverse else "cum_count"
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_max_test.py
@@ -23,7 +23,7 @@ def test_cum_max_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_min_test.py
@@ -23,7 +23,7 @@ def test_cum_min_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_prod_test.py
@@ -23,7 +23,7 @@ def test_cum_prod_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_sum_test.py
@@ -18,7 +18,7 @@
def test_cum_sum_expr(
request: pytest.FixtureRequest, constructor: Constructor, *, reverse: bool
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
3 changes: 3 additions & 0 deletions tests/expr_and_series/dt/datetime_attributes_test.py
@@ -51,6 +51,8 @@ def test_datetime_attributes(
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor) and attribute in ("date", "weekday", "ordinal_day"):
request.applymarker(pytest.mark.xfail)
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a").dt, attribute)())
@@ -121,6 +123,7 @@ def test_to_date(request: pytest.FixtureRequest, constructor: Constructor) -> None:
"cudf",
"modin_constructor",
"duckdb",
"pyspark",
)
):
request.applymarker(pytest.mark.xfail)
2 changes: 1 addition & 1 deletion tests/expr_and_series/dt/datetime_duration_test.py
@@ -46,7 +46,7 @@ def test_duration_attributes(
) -> None:
if PANDAS_VERSION < (2, 2) and "pandas_pyarrow" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))