chore: move pyspark tests into main test suite #1761

Merged (20 commits) on Jan 9, 2025
4 changes: 4 additions & 0 deletions .github/workflows/pytest.yml
@@ -50,6 +50,10 @@ jobs:
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
run: uv pip install -e ".[dev, core, extra, dask, modin]" --system
- name: install pyspark
run: uv pip install -e ".[pyspark]" --system
# PySpark is not yet available on Python3.12+
if: matrix.python-version != '3.12'
- name: show-deps
run: uv pip freeze
- name: Run pytest
13 changes: 11 additions & 2 deletions narwhals/_spark_like/dataframe.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
@@ -8,6 +9,7 @@

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
@@ -106,9 +108,11 @@ def select(
new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def filter(self, *predicates: SparkLikeExpr) -> Self:
def filter(self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
)
Comment on lines +113 to +115 (Member Author): Needed to implement Expr.__eq__ to get this to work. It overlaps with @EdAbati's PR.

# `[0]` is safe as all_horizontal's expression only returns a single column
condition = expr._call(self)[0]
spark_df = self._native_frame.where(condition)
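
A short sketch of the rewriting the keyword-constraint support above performs (the helper below is illustrative, not code from the repository): each name=value pair becomes an equality expression and is folded into the horizontal AND, which only works because SparkLikeExpr.__eq__ now returns an expression rather than a bool.

# Illustrative sketch; `col` and `all_horizontal` stand in for the namespace methods used above.
from itertools import chain

def combine_predicates(predicates, constraints, col, all_horizontal):
    # col(name) == value relies on the overridden __eq__ returning an expression
    equalities = (col(name) == value for name, value in constraints.items())
    return all_horizontal(*chain(predicates, equalities))

With this in place, lf.filter(a=1) is expected to behave like lf.filter(nw.col("a") == 1) for the PySpark backend.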
@@ -203,6 +207,11 @@ def unique(
if keep != "any":
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)

if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)

subset = [subset] if isinstance(subset, str) else subset
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
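
A hedged sketch of the user-facing behaviour the new subset check is intended to give (the Spark frame here is assumed, not taken from the tests):

# Illustrative only: an unknown subset column should now raise ColumnNotFoundError
# eagerly instead of failing later inside PySpark's dropDuplicates.
import narwhals as nw
from narwhals.exceptions import ColumnNotFoundError

def unique_or_message(native_spark_df):
    lf = nw.from_native(native_spark_df)  # assumes a PySpark DataFrame
    try:
        return lf.unique(subset=["not_a_column"])
    except ColumnNotFoundError as exc:
        return str(exc)  # e.g. "Column(s) ['not_a_column'] not found in [...]"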

32 changes: 16 additions & 16 deletions narwhals/_spark_like/expr.py
@@ -121,6 +121,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
kwargs=kwargs,
)

def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__eq__(other),
"__eq__",
other=other,
returns_scalar=False,
)

def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__ne__(other),
"__ne__",
other=other,
returns_scalar=False,
)

def __add__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__add__(other),
@@ -179,22 +195,6 @@ def __mod__(self, other: SparkLikeExpr) -> Self:
returns_scalar=False,
)

def __eq__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__eq__(other),
"__eq__",
other=other,
returns_scalar=False,
)

def __ne__(self, other: SparkLikeExpr) -> Self: # type: ignore[override]
return self._from_call(
lambda _input, other: _input.__ne__(other),
"__ne__",
other=other,
returns_scalar=False,
)

def __ge__(self, other: SparkLikeExpr) -> Self:
return self._from_call(
lambda _input, other: _input.__ge__(other),
17 changes: 11 additions & 6 deletions narwhals/_spark_like/group_by.py
@@ -80,6 +80,8 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:


def get_spark_function(function_name: str, **kwargs: Any) -> Column:
from pyspark.sql import functions as F # noqa: N812

if function_name in {"std", "var"}:
import numpy as np # ignore-banned-import

@@ -88,9 +90,15 @@ def get_spark_function(function_name: str, **kwargs: Any) -> Column:
ddof=kwargs["ddof"],
np_version=parse_version(np.__version__),
)
from pyspark.sql import functions as F # noqa: N812
elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return getattr(F, function_name)
return _count

else:
return getattr(F, function_name)


def agg_pyspark(
@@ -138,10 +146,7 @@ def agg_pyspark(
raise AssertionError(msg)

function_name = remove_prefix(expr._function_name, "col->")
pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
function_name, function_name
)
agg_func = get_spark_function(pyspark_function, **expr._kwargs)
agg_func = get_spark_function(function_name, **expr._kwargs)

simple_aggregations.update(
{
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -164,11 +164,16 @@ filterwarnings = [
'ignore:.*Passing a BlockManager to DataFrame:DeprecationWarning',
# This warning was temporarily raised by Polars but then reverted.
'ignore:.*The default coalesce behavior of left join will change:DeprecationWarning',
'ignore: unclosed <socket.socket',
'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:pyspark',
'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning:pyspark',
Comment on lines +168 to +169 (Member Author): @MarcoGorelli I moved these back to pyproject.toml, while targeting the pyspark module. Would that work for you?

Reply (Member): TIL, nice!

]
xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
env = [
"MODIN_ENGINE=python",
"PYARROW_IGNORE_TIMEZONE=1"
]

[tool.coverage.run]
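
For context, the module-targeted entries above use the standard warnings filter spec, action:message:category:module:lineno, so the trailing pyspark field scopes each ignore to warnings raised from pyspark. A rough programmatic equivalent of one entry (a sketch, not what pytest runs verbatim):

import warnings

warnings.filterwarnings(
    "ignore",
    message=r".*distutils Version classes are deprecated.*",
    category=DeprecationWarning,
    module="pyspark",  # only suppress when the warning originates from pyspark
)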
42 changes: 33 additions & 9 deletions tests/conftest.py
@@ -1,10 +1,10 @@
from __future__ import annotations

import os
import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generator
from typing import Sequence

import pandas as pd
@@ -14,7 +14,6 @@

if TYPE_CHECKING:
import duckdb
from pyspark.sql import SparkSession

from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoFrame
@@ -129,23 +128,23 @@ def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


@pytest.fixture(scope="session")
def spark_session() -> Generator[SparkSession, None, None]: # pragma: no cover
def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]: # pragma: no cover
try:
from pyspark.sql import SparkSession
except ImportError: # pragma: no cover
pytest.skip("pyspark is not installed")
return
return None

import warnings
from atexit import register

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
with warnings.catch_warnings():
# The spark session seems to trigger a polars warning.
# Polars is imported in the tests, but not used in the spark operations
warnings.filterwarnings(
"ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
)

session = (
SparkSession.builder.appName("unit-tests")
.master("local[1]")
@@ -155,8 +154,26 @@ def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
.config("spark.sql.shuffle.partitions", "2")
.getOrCreate()
)
yield session
session.stop()

register(session.stop)
Comment (Collaborator): TIL atexit.register, nice!

def _constructor(obj: Any) -> IntoFrame:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
r".*is_datetime64tz_dtype is deprecated and will be removed in a future version.*",
module="pyspark",
category=DeprecationWarning,
)
pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
Comment (Contributor @camriddell, Jan 8, 2025): If the objects that come into these constructors are (always?) dictionaries, I think we can skip the trip through pandas and construct from a built-in Python object that Spark knows how to ingest directly (a list of dictionaries). This could be overly cautious, but Spark may infer data types differently if it is handed a pandas DataFrame rather than lists of Python objects.

Since pyspark supports a list of records, we could convert dict → list of dicts like so:

if isinstance(obj, dict):
    obj = [{k: v for k, v in zip(obj, row)} for row in zip(*obj.values())]

Or pass in the rows and schema separately:

if isinstance(obj, dict):
    df = ...createDataFrame([*zip(*obj.values())], schema=[*obj.keys()])

Reply (Collaborator): I remember having issues with some tests where we needed to specify the schema with column types (I don't remember exactly what the problem was). But if we can skip pandas here, it would be 👌👌👌

Reply (Member Author): I had the same thought when migrating the codebase, yet I can confirm the data type being an issue for a subset of the tests. I would say to keep it like this for now and eventually address it.

return ( # type: ignore[no-any-return]
session.createDataFrame(pd_df)
.repartition(2)
.orderBy("index")
.drop("index")
)

return _constructor
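
Switching from a yield-based session fixture to atexit.register moves cleanup from fixture teardown to interpreter exit. A tiny standalone sketch of the pattern, with an illustrative resource rather than a real SparkSession:

from atexit import register

class FakeSession:
    def stop(self) -> None:
        print("session stopped")

session = FakeSession()
register(session.stop)  # runs once, when the Python process exits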


EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = {
@@ -173,6 +190,7 @@ def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
"dask": dask_lazy_p2_constructor,
"polars[lazy]": polars_lazy_constructor,
"duckdb": duckdb_lazy_constructor,
"pyspark": pyspark_lazy_constructor, # type: ignore[dict-item]
}
GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}

@@ -201,7 +219,13 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
constructors.append(EAGER_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
elif constructor in LAZY_CONSTRUCTORS:
constructors.append(LAZY_CONSTRUCTORS[constructor])
if constructor == "pyspark":
if sys.version_info < (3, 12): # pragma: no cover
constructors.append(pyspark_lazy_constructor())
else: # pragma: no cover
continue
else:
constructors.append(LAZY_CONSTRUCTORS[constructor])
constructors_ids.append(constructor)
else: # pragma: no cover
msg = f"Expected one of {EAGER_CONSTRUCTORS.keys()} or {LAZY_CONSTRUCTORS.keys()}, got {constructor}"
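
Picking up the review thread above about skipping the pandas round-trip, a hedged sketch of a pandas-free constructor with an explicit column list (untested against the suite; session is assumed to be the SparkSession created above):

from typing import Any

def pyspark_constructor_no_pandas(session: Any, obj: dict[str, Any]) -> Any:
    # attach an explicit index column so row order survives the repartition,
    # mirroring the reset_index/orderBy/drop dance in the pandas-based version
    columns = ["index", *obj.keys()]
    rows = [(i, *vals) for i, vals in enumerate(zip(*obj.values()))]
    return (
        session.createDataFrame(rows, schema=columns)
        .repartition(2)
        .orderBy("index")
        .drop("index")
    )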
2 changes: 1 addition & 1 deletion tests/expr_and_series/all_horizontal_test.py
@@ -57,7 +57,7 @@ def test_allh_nth(
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
7 changes: 6 additions & 1 deletion tests/expr_and_series/any_all_test.py
@@ -1,12 +1,17 @@
from __future__ import annotations

import pytest

import narwhals.stable.v1 as nw
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data


def test_any_all(constructor: Constructor) -> None:
def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(
constructor(
{
10 changes: 8 additions & 2 deletions tests/expr_and_series/any_horizontal_test.py
@@ -11,7 +11,11 @@

@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
def test_anyh(
request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
@@ -23,7 +27,9 @@ def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
assert_equal_data(result, expected)


def test_anyh_all(constructor: Constructor) -> None:
def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if "pyspark" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
1 change: 0 additions & 1 deletion tests/expr_and_series/arithmetic_test.py
@@ -76,7 +76,6 @@ def test_right_arithmetic_expr(
x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
):
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
4 changes: 3 additions & 1 deletion tests/expr_and_series/binary_test.py
@@ -9,7 +9,9 @@


def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
constructor
):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
df_raw = constructor(data)
8 changes: 5 additions & 3 deletions tests/expr_and_series/cast_test.py
@@ -60,7 +60,7 @@ def test_cast(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
15,
@@ -180,7 +180,7 @@ def test_cast_string() -> None:
def test_cast_raises_for_unknown_dtype(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
# Unsupported cast from string to dictionary using function cast_dictionary
@@ -204,6 +204,7 @@ def test_cast_datetime_tz_aware(
or "duckdb" in str(constructor)
or "cudf" in str(constructor) # https://github.com/rapidsai/cudf/issues/16973
or ("pyarrow_table" in str(constructor) and is_windows())
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)

@@ -229,7 +230,8 @@

def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
backend in str(constructor) for backend in ("dask", "modin", "cudf", "duckdb")
backend in str(constructor)
for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
):
request.applymarker(pytest.mark.xfail)

2 changes: 1 addition & 1 deletion tests/expr_and_series/concat_str_test.py
@@ -27,7 +27,7 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
2 changes: 2 additions & 0 deletions tests/expr_and_series/convert_time_zone_test.py
@@ -29,6 +29,7 @@ def test_convert_time_zone(
or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
data = {
@@ -86,6 +87,7 @@ def test_convert_time_zone_from_none(
or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
or ("cudf" in str(constructor))
or ("duckdb" in str(constructor))
or ("pyspark" in str(constructor))
):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 7):
2 changes: 1 addition & 1 deletion tests/expr_and_series/cum_count_test.py
@@ -21,7 +21,7 @@ def test_cum_count_expr(
) -> None:
if "dask" in str(constructor) and reverse:
request.applymarker(pytest.mark.xfail)
if "duckdb" in str(constructor):
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)

name = "reverse_cum_count" if reverse else "cum_count"