chore: move pyspark tests into main test suite (#1761)
* chore: move pyspark tests into main test suite

* delay call to pyspark constructor

* xfail from_dict, from_numpy

* one more

* feedback and tests

* missing condition to xfail

* move warnings to pyproject

* statement order?

* pragma no cover branch

---------

Co-authored-by: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
FBruzzesi and MarcoGorelli authored Jan 9, 2025
1 parent ab21e72 commit 20eb53b
Showing 82 changed files with 556 additions and 1,433 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/pytest.yml
@@ -50,6 +50,10 @@ jobs:
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
run: uv pip install -e ".[dev, core, extra, dask, modin]" --system
- name: install pyspark
run: uv pip install -e ".[pyspark]" --system
# PySpark is not yet available on Python3.12+
if: matrix.python-version != '3.12'
- name: show-deps
run: uv pip freeze
- name: Run pytest
13 changes: 11 additions & 2 deletions narwhals/_spark_like/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
@@ -8,6 +9,7 @@

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
@@ -106,9 +108,11 @@ def select(
new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def filter(self, *predicates: SparkLikeExpr) -> Self:
def filter(self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
)
# `[0]` is safe as all_horizontal's expression only returns a single column
condition = expr._call(self)[0]
spark_df = self._native_frame.where(condition)
@@ -203,6 +207,11 @@ def unique(
if keep != "any":
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)

if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)

subset = [subset] if isinstance(subset, str) else subset
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))

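The `filter` signature now accepts keyword constraints alongside positional predicates: each `name=value` pair is turned into `plx.col(name) == value` and folded into `all_horizontal` together with the predicates. A minimal sketch of how this surfaces through the public API (the local SparkSession and sample data are illustrative assumptions, not part of this commit):

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
native = spark.createDataFrame([(1, 2), (3, 2), (3, 4)], ["a", "b"])
lf = nw.from_native(native)  # pyspark frames are wrapped as LazyFrames

# The keyword form is equivalent to spelling out the equality predicate:
by_kwarg = lf.filter(nw.col("a") > 1, b=2)
by_expr = lf.filter(nw.col("a") > 1, nw.col("b") == 2)

print(by_kwarg.to_native().collect())  # [Row(a=3, b=2)]
```
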
32 changes: 16 additions & 16 deletions narwhals/_spark_like/expr.py
@@ -121,6 +121,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
kwargs=kwargs,
)

def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__eq__(other),
"__eq__",
other=other,
returns_scalar=False,
)

def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__ne__(other),
"__ne__",
other=other,
returns_scalar=False,
)

def __add__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__add__(other),
@@ -179,22 +195,6 @@ def __mod__(self, other: SparkLikeExpr) -> Self:
returns_scalar=False,
)

def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__eq__(other),
"__eq__",
other=other,
returns_scalar=False,
)

def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__ne__(other),
"__ne__",
other=other,
returns_scalar=False,
)

def __ge__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__ge__(other),
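
`__eq__` and `__ne__` are simply regrouped with the other comparison dunders; the `type: ignore[override]` is needed because they return an expression rather than the `bool` the object protocol promises, which is what lets comparisons compose lazily. A hedged illustration through the public API (backend and data are arbitrary choices for the example):

```python
import narwhals as nw
import polars as pl  # any supported backend works; polars keeps the example small

df = nw.from_native(pl.DataFrame({"a": [1, 2, 2]}), eager_only=True)
# The comparison builds an expression column rather than returning a bool.
result = df.select((nw.col("a") == 2).alias("is_two"))
print(result.to_native())  # is_two: [False, True, True]
```
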
17 changes: 11 additions & 6 deletions narwhals/_spark_like/group_by.py
@@ -80,6 +80,8 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:


def get_spark_function(function_name: str, **kwargs: Any) -> Column:
from pyspark.sql import functions as F # noqa: N812

if function_name in {"std", "var"}:
import numpy as np # ignore-banned-import

@@ -88,9 +90,15 @@ def get_spark_function(function_name: str, **kwargs: Any) -> Column:
ddof=kwargs["ddof"],
np_version=parse_version(np.__version__),
)
from pyspark.sql import functions as F # noqa: N812
elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return getattr(F, function_name)
return _count

else:
return getattr(F, function_name)


def agg_pyspark(
@@ -138,10 +146,7 @@ def agg_pyspark(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
function_name, function_name
)
agg_func = get_spark_function(pyspark_function, **expr._kwargs)
agg_func = get_spark_function(function_name, **expr._kwargs)

simple_aggregations.update(
{
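
`len` is special-cased because `F.count(column)` skips nulls, while Polars-style `len` counts every row; `F.count("*")` gives the row-count semantics. A standalone sketch of the difference (the local session and data are assumptions for illustration):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("x", 1), ("x", None), ("y", 3)], ["g", "v"])

result = df.groupBy("g").agg(
    F.count("*").alias("len"),       # counts rows including nulls: x -> 2, y -> 1
    F.count("v").alias("non_null"),  # drops the null in group "x":  x -> 1, y -> 1
)
result.show()
```
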
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -164,11 +164,16 @@ filterwarnings = [
'ignore:.*Passing a BlockManager to DataFrame:DeprecationWarning',
# This warning was temporarily raised by Polars but then reverted.
'ignore:.*The default coalesce behavior of left join will change:DeprecationWarning',
'ignore: unclosed <socket.socket',
'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:pyspark',
'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning:pyspark',

]
xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
env = [
"MODIN_ENGINE=python",
"PYARROW_IGNORE_TIMEZONE=1"
]

[tool.coverage.run]
42 changes: 33 additions & 9 deletions tests/conftest.py
@@ -1,10 +1,10 @@
from __future__ import annotations

import os
import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generator
from typing import Sequence

import pandas as pd
@@ -14,7 +14,6 @@

if TYPE_CHECKING:
import duckdb
from pyspark.sql import SparkSession

from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
@@ -129,23 +128,23 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


@pytest.fixture(scope="session")
def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]: # pragma: no cover
try:
from pyspark.sql import SparkSession
except ImportError: # pragma: no cover
pytest.skip("pyspark is not installed")
return
return None

import warnings
from atexit import register

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
with warnings.catch_warnings():
# The spark session seems to trigger a polars warning.
# Polars is imported in the tests, but not used in the spark operations
warnings.filterwarnings(
"ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
)

session = (
SparkSession.builder.appName("unit-tests")
.master("local[1]")
@@ -155,8 +154,26 @@ def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
.config("spark.sql.shuffle.partitions", "2")
.getOrCreate()
)
yield session
session.stop()

register(session.stop)

def _constructor(obj: Any) -> IntoFrame:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r".*is_datetime64tz_dtype is deprecated and will be removed in a future version.*",
module="pyspark",
category=DeprecationWarning,
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)

return _constructor


EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = {
@@ -173,6 +190,7 @@ def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
"dask": dask_lazy_p2_constructor,
"polars[lazy]": polars_lazy_constructor,
"duckdb": duckdb_lazy_constructor,
"pyspark": pyspark_lazy_constructor, # type: ignore[dict-item]
}
GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}

@@ -201,7 +219,13 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
constructors.append(EAGER_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
elif constructor in LAZY_CONSTRUCTORS:
constructors.append(LAZY_CONSTRUCTORS[constructor])
if constructor == "pyspark":
if sys.version_info < (3, 12): # pragma: no cover
constructors.append(pyspark_lazy_constructor())
else: # pragma: no cover
continue
else:
constructors.append(LAZY_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
else: # pragma: no cover
msg = f"Expected one of {EAGER_CONSTRUCTORS.keys()} or {LAZY_CONSTRUCTORS.keys()}, got {constructor}"
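
With the constructor registered in `LAZY_CONSTRUCTORS`, pyspark now flows through the same parametrized `constructor` fixture as every other backend, so an ordinary test exercises it without any pyspark-specific code. A hedged sketch of what such a test looks like (the test body is illustrative, not taken from this commit):

```python
import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import assert_equal_data


def test_select_plus_one(constructor: Constructor) -> None:
    # `constructor` is parametrized by pytest_generate_tests; for the "pyspark"
    # id it is the callable returned by pyspark_lazy_constructor().
    df = nw.from_native(constructor({"a": [1, 2, 3]}))
    result = df.select(nw.col("a") + 1)
    assert_equal_data(result, {"a": [2, 3, 4]})
```
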
2 changes: 1 addition & 1 deletion tests/expr_and_series/all_horizontal_test.py
@@ -57,7 +57,7 @@ def test_allh_nth(
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
7 changes: 6 additions & 1 deletion tests/expr_and_series/any_all_test.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


def test_any_all(constructor: Constructor) -> None:
def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(
constructor(
{
10 changes: 8 additions & 2 deletions tests/expr_and_series/any_horizontal_test.py
@@ -11,7 +11,11 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
def test_anyh(
request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
@@ -23,7 +27,9 @@ def test_anyh_all(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_anyh_all(constructor: Constructor) -> None:
def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
1 change: 0 additions & 1 deletion tests/expr_and_series/arithmetic_test.py
@@ -76,7 +76,6 @@ def test_right_arithmetic_expr(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
4 changes: 3 additions & 1 deletion tests/expr_and_series/binary_test.py
@@ -9,7 +9,9 @@


def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
constructor
):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
8 changes: 5 additions & 3 deletions tests/expr_and_series/cast_test.py
@@ -60,7 +60,7 @@ def test_cast(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
15,
@@ -180,7 +180,7 @@ def test_cast_string() -> None:
def test_cast_raises_for_unknown_dtype(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
# Unsupported cast from string to dictionary using function cast_dictionary
@@ -204,6 +204,7 @@ def test_cast_datetime_tz_aware(
or "duckdb" in str(constructor)
or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973
or ("pyarrow_table" in str(constructor) and is_windows())
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)

@@ -229,7 +230,8 @@

def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
backend in str(constructor) for backend in ("dask", "modin", "cudf", "duckdb")
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
):
request.applymarker(pytest.mark.xfail)

2 changes: 1 addition & 1 deletion tests/expr_and_series/concat_str_test.py
@@ -27,7 +27,7 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
2 changes: 2 additions & 0 deletions tests/expr_and_series/convert_time_zone_test.py
@@ -29,6 +29,7 @@ def test_convert_time_zone(
or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
data = {
@@ -86,6 +87,7 @@ def test_convert_time_zone_from_none(
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 7):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_count_test.py
@@ -21,7 +21,7 @@ def test_cum_count_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

name = "reverse_cum_count" if reverse else "cum_count"