diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index c89ab2cd7..bb46b4f0d 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -50,6 +50,10 @@ jobs:
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
         run: uv pip install -e ".[dev, core, extra, dask, modin]" --system
+      - name: install pyspark
+        run: uv pip install -e ".[pyspark]" --system
+        # PySpark is not yet available on Python 3.12+
+        if: matrix.python-version != '3.12'
      - name: show-deps
        run: uv pip freeze
      - name: Run pytest
diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py
index e54a05997..101d5ad24 100644
--- a/narwhals/_spark_like/dataframe.py
+++ b/narwhals/_spark_like/dataframe.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from itertools import chain
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Iterable
@@ -8,6 +9,7 @@
 
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
+from narwhals.exceptions import ColumnNotFoundError
 from narwhals.utils import Implementation
 from narwhals.utils import flatten
 from narwhals.utils import parse_columns_to_drop
@@ -106,9 +108,11 @@ def select(
         new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
         return self._from_native_frame(self._native_frame.select(*new_columns_list))
 
-    def filter(self, *predicates: SparkLikeExpr) -> Self:
+    def filter(self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
         plx = self.__narwhals_namespace__()
-        expr = plx.all_horizontal(*predicates)
+        expr = plx.all_horizontal(
+            *chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
+        )
         # `[0]` is safe as all_horizontal's expression only returns a single column
         condition = expr._call(self)[0]
         spark_df = self._native_frame.where(condition)
@@ -203,6 +207,11 @@ def unique(
         if keep != "any":
             msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
             raise ValueError(msg)
+
+        if subset is not None and any(x not in self.columns for x in subset):
+            msg = f"Column(s) {subset} not found in {self.columns}"
+            raise ColumnNotFoundError(msg)
+
         subset = [subset] if isinstance(subset, str) else subset
         return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index 10fb76227..a8cafccfd 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -121,6 +121,22 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             kwargs=kwargs,
         )
 
+    def __eq__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+        return self._from_call(
+            lambda _input, other: _input.__eq__(other),
+            "__eq__",
+            other=other,
+            returns_scalar=False,
+        )
+
+    def __ne__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+        return self._from_call(
+            lambda _input, other: _input.__ne__(other),
+            "__ne__",
+            other=other,
+            returns_scalar=False,
+        )
+
     def __add__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
             lambda _input, other: _input.__add__(other),
@@ -179,22 +195,6 @@ def __mod__(self, other: SparkLikeExpr) -> Self:
             returns_scalar=False,
         )
 
-    def __eq__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
-        return self._from_call(
-            lambda _input, other: _input.__eq__(other),
-            "__eq__",
-            other=other,
-            returns_scalar=False,
-        )
-
-    def __ne__(self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
-        return self._from_call(
-            lambda _input, other: _input.__ne__(other),
-            "__ne__",
-            other=other,
-            returns_scalar=False,
-        )
-
     def __ge__(self, other: SparkLikeExpr) -> Self:
         return self._from_call(
             lambda _input, other: _input.__ge__(other),
diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py
index 0100500ff..cbcf87692 100644
--- a/narwhals/_spark_like/group_by.py
+++ b/narwhals/_spark_like/group_by.py
@@ -80,6 +80,8 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
 
 
 def get_spark_function(function_name: str, **kwargs: Any) -> Column:
+    from pyspark.sql import functions as F  # noqa: N812
+
     if function_name in {"std", "var"}:
         import numpy as np  # ignore-banned-import
 
@@ -88,9 +90,15 @@ def get_spark_function(function_name: str, **kwargs: Any) -> Column:
             ddof=kwargs["ddof"],
             np_version=parse_version(np.__version__),
         )
-    from pyspark.sql import functions as F  # noqa: N812
+    elif function_name == "len":
+        # Use count(*) to count all rows including nulls
+        def _count(*_args: Any, **_kwargs: Any) -> Column:
+            return F.count("*")
 
-    return getattr(F, function_name)
+        return _count
+
+    else:
+        return getattr(F, function_name)
 
 
 def agg_pyspark(
@@ -138,10 +146,7 @@ def agg_pyspark(
             raise AssertionError(msg)
 
         function_name = remove_prefix(expr._function_name, "col->")
-        pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
-            function_name, function_name
-        )
-        agg_func = get_spark_function(pyspark_function, **expr._kwargs)
+        agg_func = get_spark_function(function_name, **expr._kwargs)
 
         simple_aggregations.update(
             {
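A note on how the dataframe.py and expr.py changes compose: keyword constraints passed to `filter` are desugared into equality expressions and AND-ed with the positional predicates, which is exactly what the new `SparkLikeExpr.__eq__` enables. A minimal sketch at the narwhals API level (the frame `df` and its column names are illustrative, not taken from this PR):

    import narwhals as nw

    # Equivalent spellings under the new `filter(**constraints)` support:
    # `b=2` becomes `nw.col("b") == 2` and is combined with the positional
    # predicate via all_horizontal (logical AND).
    result = df.filter(nw.col("a") > 1, b=2)
    result = df.filter((nw.col("a") > 1) & (nw.col("b") == 2))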
diff --git a/pyproject.toml b/pyproject.toml
index bea188a59..91770923e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -164,11 +164,16 @@ filterwarnings = [
     'ignore:.*Passing a BlockManager to DataFrame:DeprecationWarning',
     # This warning was temporarily raised by Polars but then reverted.
     'ignore:.*The default coalesce behavior of left join will change:DeprecationWarning',
+    'ignore: unclosed <socket.socket',
+    'ignore: unclosed event loop <_UnixSelectorEventLoop',
+    'ignore:.*The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:pyspark',
+    'ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning:pyspark',
+    'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning:pyspark',
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ ... @@
 def pyarrow_table_constructor(obj: Any) -> IntoDataFrame:
     return pa.table(obj)  # type: ignore[no-any-return]
 
 
-@pytest.fixture(scope="session")
-def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
+def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]:  # pragma: no cover
     try:
         from pyspark.sql import SparkSession
     except ImportError:  # pragma: no cover
         pytest.skip("pyspark is not installed")
-        return
+        return None
 
     import warnings
+    from atexit import register
 
-    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
     with warnings.catch_warnings():
         # The spark session seems to trigger a polars warning.
         # Polars is imported in the tests, but not used in the spark operations
         warnings.filterwarnings(
             "ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
         )
+
         session = (
             SparkSession.builder.appName("unit-tests")
             .master("local[1]")
@@ -155,8 +154,26 @@ def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
             .config("spark.sql.shuffle.partitions", "2")
             .getOrCreate()
         )
-        yield session
-    session.stop()
+
+        register(session.stop)
+
+    def _constructor(obj: Any) -> IntoFrame:
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                r".*is_datetime64tz_dtype is deprecated and will be removed in a future version.*",
+                module="pyspark",
+                category=DeprecationWarning,
+            )
+            pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index()
+        return (  # type: ignore[no-any-return]
+            session.createDataFrame(pd_df)
+            .repartition(2)
+            .orderBy("index")
+            .drop("index")
+        )
+
+    return _constructor
 
 
 EAGER_CONSTRUCTORS: dict[str, Callable[[Any], IntoDataFrame]] = {
@@ -173,6 +190,7 @@ def spark_session() -> Generator[SparkSession, None, None]:  # pragma: no cover
     "dask": dask_lazy_p2_constructor,
     "polars[lazy]": polars_lazy_constructor,
     "duckdb": duckdb_lazy_constructor,
+    "pyspark": pyspark_lazy_constructor,  # type: ignore[dict-item]
 }
 GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor}
 
@@ -201,7 +219,13 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
             constructors.append(EAGER_CONSTRUCTORS[constructor])
             constructors_ids.append(constructor)
         elif constructor in LAZY_CONSTRUCTORS:
-            constructors.append(LAZY_CONSTRUCTORS[constructor])
+            if constructor == "pyspark":
+                if sys.version_info < (3, 12):  # pragma: no cover
+                    constructors.append(pyspark_lazy_constructor())
+                else:  # pragma: no cover
+                    continue
+            else:
+                constructors.append(LAZY_CONSTRUCTORS[constructor])
             constructors_ids.append(constructor)
         else:  # pragma: no cover
             msg = f"Expected one of {EAGER_CONSTRUCTORS.keys()} or {LAZY_CONSTRUCTORS.keys()}, got {constructor}"
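For orientation, the constructor returned by `pyspark_lazy_constructor` is what the parametrized tests below receive: it converts a plain dict into a two-partition Spark DataFrame whose row order is made deterministic by the temporary `index` column. A rough sketch of how a test body ends up using it, assuming pyspark is installed (illustrative only, not part of the diff):

    import narwhals.stable.v1 as nw

    constructor = pyspark_lazy_constructor()  # calls pytest.skip() when pyspark is missing
    df = nw.from_native(constructor({"a": [1, 2, 3]}))  # a narwhals LazyFrame
    result = df.select(nw.col("a") + 1).collect()  # hypothetical assertion target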
diff --git a/tests/expr_and_series/all_horizontal_test.py b/tests/expr_and_series/all_horizontal_test.py
index 6eb98c3a3..826c0fe19 100644
--- a/tests/expr_and_series/all_horizontal_test.py
+++ b/tests/expr_and_series/all_horizontal_test.py
@@ -57,7 +57,7 @@ def test_allh_nth(
 ) -> None:
     if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {
         "a": [False, False, True],
diff --git a/tests/expr_and_series/any_all_test.py b/tests/expr_and_series/any_all_test.py
index c5f22ad9a..7fd81f04d 100644
--- a/tests/expr_and_series/any_all_test.py
+++ b/tests/expr_and_series/any_all_test.py
@@ -1,12 +1,17 @@
 from __future__ import annotations
 
+import pytest
+
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
 from tests.utils import assert_equal_data
 
 
-def test_any_all(constructor: Constructor) -> None:
+def test_any_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(
         constructor(
             {
diff --git a/tests/expr_and_series/any_horizontal_test.py b/tests/expr_and_series/any_horizontal_test.py
index 4eb082b51..06157f393 100644
--- a/tests/expr_and_series/any_horizontal_test.py
+++ b/tests/expr_and_series/any_horizontal_test.py
@@ -11,7 +11,11 @@
 
 @pytest.mark.parametrize("expr1", ["a", nw.col("a")])
 @pytest.mark.parametrize("expr2", ["b", nw.col("b")])
-def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
+def test_anyh(
+    request: pytest.FixtureRequest, constructor: Constructor, expr1: Any, expr2: Any
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
     data = {
         "a": [False, False, True],
         "b": [False, True, True],
@@ -23,7 +27,9 @@ def test_anyh(constructor: Constructor, expr1: Any, expr2: Any) -> None:
     assert_equal_data(result, expected)
 
 
-def test_anyh_all(constructor: Constructor) -> None:
+def test_anyh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
     data = {
         "a": [False, False, True],
         "b": [False, True, True],
diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py
index aec586c62..1baae44e5 100644
--- a/tests/expr_and_series/arithmetic_test.py
+++ b/tests/expr_and_series/arithmetic_test.py
@@ -76,7 +76,6 @@ def test_right_arithmetic_expr(
         x in str(constructor) for x in ["pandas_pyarrow", "modin_pyarrow"]
     ):
         request.applymarker(pytest.mark.xfail)
-
     data = {"a": [1, 2, 3]}
     df = nw.from_native(constructor(data))
     result = df.select(getattr(nw.col("a"), attr)(rhs))
diff --git a/tests/expr_and_series/binary_test.py b/tests/expr_and_series/binary_test.py
index 0808810bc..308745cb4 100644
--- a/tests/expr_and_series/binary_test.py
+++ b/tests/expr_and_series/binary_test.py
@@ -9,7 +9,9 @@
 
 
 def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
+    if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
+        constructor
+    ):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
     df_raw = constructor(data)
diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py
index b6ce43573..ba2b82493 100644
--- a/tests/expr_and_series/cast_test.py
+++ b/tests/expr_and_series/cast_test.py
@@ -60,7 +60,7 @@ def test_cast(
     constructor: Constructor,
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table_constructor" in str(constructor) and PYARROW_VERSION <= (
         15,
@@ -180,7 +180,7 @@ def test_cast_string() -> None:
 def test_cast_raises_for_unknown_dtype(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,):
         # Unsupported cast from string to dictionary using function cast_dictionary
@@ -204,6 +204,7 @@ def test_cast_datetime_tz_aware(
         or "duckdb" in str(constructor)
         or "cudf" in str(constructor)  # https://github.com/rapidsai/cudf/issues/16973
         or ("pyarrow_table" in str(constructor) and is_windows())
+        or ("pyspark" in str(constructor))
     ):
         request.applymarker(pytest.mark.xfail)
 
@@ -229,7 +230,8 @@ def test_cast_datetime_tz_aware(
 def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if any(
-        backend in str(constructor) for backend in ("dask", "modin", "cudf", "duckdb")
+        backend in str(constructor)
+        for backend in ("dask", "modin", "cudf", "duckdb", "pyspark")
     ):
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/expr_and_series/concat_str_test.py b/tests/expr_and_series/concat_str_test.py
index 7c9f259ba..37d4a581d 100644
--- a/tests/expr_and_series/concat_str_test.py
+++ b/tests/expr_and_series/concat_str_test.py
@@ -27,7 +27,7 @@ def test_concat_str(
     expected: list[str],
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
     result = (
diff --git a/tests/expr_and_series/convert_time_zone_test.py b/tests/expr_and_series/convert_time_zone_test.py
index 6b3cf5b41..9a18ee07f 100644
--- a/tests/expr_and_series/convert_time_zone_test.py
+++ b/tests/expr_and_series/convert_time_zone_test.py
@@ -29,6 +29,7 @@ def test_convert_time_zone(
         or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1))
         or ("cudf" in str(constructor))
         or ("duckdb" in str(constructor))
+        or ("pyspark" in str(constructor))
     ):
         request.applymarker(pytest.mark.xfail)
     data = {
@@ -86,6 +87,7 @@ def test_convert_time_zone_from_none(
         or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
         or ("cudf" in str(constructor))
         or ("duckdb" in str(constructor))
+        or ("pyspark" in str(constructor))
     ):
         request.applymarker(pytest.mark.xfail)
     if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 7):
diff --git a/tests/expr_and_series/cum_count_test.py b/tests/expr_and_series/cum_count_test.py
index 1a2377f34..dab77ebbc 100644
--- a/tests/expr_and_series/cum_count_test.py
+++ b/tests/expr_and_series/cum_count_test.py
@@ -21,7 +21,7 @@ def test_cum_count_expr(
 ) -> None:
     if "dask" in str(constructor) and reverse:
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     name = "reverse_cum_count" if reverse else "cum_count"
diff --git a/tests/expr_and_series/cum_max_test.py b/tests/expr_and_series/cum_max_test.py
index 22b7c73fa..3df5a6ad4 100644
--- a/tests/expr_and_series/cum_max_test.py
+++ b/tests/expr_and_series/cum_max_test.py
@@ -23,7 +23,7 @@ def test_cum_max_expr(
 ) -> None:
     if "dask" in str(constructor) and reverse:
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
diff --git a/tests/expr_and_series/cum_min_test.py b/tests/expr_and_series/cum_min_test.py
index b34672219..a758dc8b4 100644
--- a/tests/expr_and_series/cum_min_test.py
+++ b/tests/expr_and_series/cum_min_test.py
@@ -23,7 +23,7 @@ def test_cum_min_expr(
 ) -> None:
     if "dask" in str(constructor) and reverse:
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
diff --git a/tests/expr_and_series/cum_prod_test.py b/tests/expr_and_series/cum_prod_test.py
index 4dd5207dc..2d6861b8d 100644
--- a/tests/expr_and_series/cum_prod_test.py
+++ b/tests/expr_and_series/cum_prod_test.py
@@ -23,7 +23,7 @@ def test_cum_prod_expr(
 ) -> None:
     if "dask" in str(constructor) and reverse:
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
diff --git a/tests/expr_and_series/cum_sum_test.py b/tests/expr_and_series/cum_sum_test.py
index 5878222fb..8a419c9a9 100644
--- a/tests/expr_and_series/cum_sum_test.py
+++ b/tests/expr_and_series/cum_sum_test.py
@@ -18,7 +18,7 @@ def test_cum_sum_expr(
     request: pytest.FixtureRequest, constructor: Constructor, *, reverse: bool
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "dask" in str(constructor) and reverse:
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/expr_and_series/dt/datetime_attributes_test.py b/tests/expr_and_series/dt/datetime_attributes_test.py
index e1af276e4..9f578d3c1 100644
--- a/tests/expr_and_series/dt/datetime_attributes_test.py
+++ b/tests/expr_and_series/dt/datetime_attributes_test.py
@@ -51,6 +51,8 @@ def test_datetime_attributes(
         request.applymarker(pytest.mark.xfail)
     if "duckdb" in str(constructor) and attribute in ("date", "weekday", "ordinal_day"):
         request.applymarker(pytest.mark.xfail)
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
     result = df.select(getattr(nw.col("a").dt, attribute)())
@@ -121,6 +123,7 @@ def test_to_date(request: pytest.FixtureRequest, constructor: Constructor) -> None:
             "cudf",
             "modin_constructor",
             "duckdb",
+            "pyspark",
         )
     ):
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/expr_and_series/dt/datetime_duration_test.py b/tests/expr_and_series/dt/datetime_duration_test.py
index bda3e4703..7ec281daa 100644
--- a/tests/expr_and_series/dt/datetime_duration_test.py
+++ b/tests/expr_and_series/dt/datetime_duration_test.py
@@ -46,7 +46,7 @@ def test_duration_attributes(
 ) -> None:
     if PANDAS_VERSION < (2, 2) and "pandas_pyarrow" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
diff --git a/tests/expr_and_series/dt/to_string_test.py b/tests/expr_and_series/dt/to_string_test.py
index 6fa500024..3cc3f0edd 100644
--- a/tests/expr_and_series/dt/to_string_test.py
+++ b/tests/expr_and_series/dt/to_string_test.py
@@ -62,7 +62,7 @@ def test_dt_to_string_series(constructor_eager: ConstructorEager, fmt: str) -> None:
 def test_dt_to_string_expr(
     constructor: Constructor, fmt: str, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     input_frame = nw.from_native(constructor(data))
@@ -141,7 +141,7 @@ def test_dt_to_string_iso_local_datetime_expr(
     expected: str,
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = constructor({"a": [data]})
 
@@ -180,7 +180,7 @@ def test_dt_to_string_iso_local_date_expr(
     expected: str,
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = constructor({"a": [data]})
     result = nw.from_native(df).with_columns(
diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py
index 58ef5c890..39b0a3c64 100644
--- a/tests/expr_and_series/fill_null_test.py
+++ b/tests/expr_and_series/fill_null_test.py
@@ -12,7 +12,9 @@
 from tests.utils import assert_equal_data
 
 
-def test_fill_null(constructor: Constructor) -> None:
+def test_fill_null(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
     data = {
         "a": [0.0, None, 2, 3, 4],
         "b": [1.0, None, None, 5, 3],
@@ -50,7 +52,7 @@ def test_fill_null_exceptions(constructor: Constructor) -> None:
 def test_fill_null_strategies_with_limit_as_none(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data_limits = {
         "a": [1, None, None, None, 5, 6, None, None, None, 10],
@@ -120,7 +122,7 @@ def test_fill_null_strategies_with_limit_as_none(
 def test_fill_null_limits(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     context: Any = (
         pytest.raises(NotImplementedError, match="The limit keyword is not supported")
diff --git a/tests/expr_and_series/is_duplicated_test.py b/tests/expr_and_series/is_duplicated_test.py
index fe8b45bf1..d97d30cbd 100644
--- a/tests/expr_and_series/is_duplicated_test.py
+++ b/tests/expr_and_series/is_duplicated_test.py
@@ -11,7 +11,7 @@
 def test_is_duplicated_expr(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 1, 2], "b": [1, 2, 3], "index": [0, 1, 2]}
     df = nw.from_native(constructor(data))
@@ -23,7 +23,7 @@ def test_is_duplicated_expr(
 def test_is_duplicated_w_nulls_expr(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 1, None], "b": [1, None, None], "index": [0, 1, 2]}
     df = nw.from_native(constructor(data))
"polars" in str(constructor) or "pyarrow_table" in str(constructor): expected = {"a": [False, False, True, None]} diff --git a/tests/expr_and_series/is_first_distinct_test.py b/tests/expr_and_series/is_first_distinct_test.py index 786f2ade7..6870c3394 100644 --- a/tests/expr_and_series/is_first_distinct_test.py +++ b/tests/expr_and_series/is_first_distinct_test.py @@ -16,7 +16,7 @@ def test_is_first_distinct_expr( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().is_first_distinct()) diff --git a/tests/expr_and_series/is_last_distinct_test.py b/tests/expr_and_series/is_last_distinct_test.py index c5d73c8d7..9362cd02a 100644 --- a/tests/expr_and_series/is_last_distinct_test.py +++ b/tests/expr_and_series/is_last_distinct_test.py @@ -16,7 +16,7 @@ def test_is_last_distinct_expr( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.all().is_last_distinct()) diff --git a/tests/expr_and_series/is_nan_test.py b/tests/expr_and_series/is_nan_test.py index 7bae35a52..0280d6555 100644 --- a/tests/expr_and_series/is_nan_test.py +++ b/tests/expr_and_series/is_nan_test.py @@ -25,7 +25,7 @@ def test_nan(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data_na = {"int": [0, 1, None]} df = nw.from_native(constructor(data_na)).with_columns( @@ -96,7 +96,7 @@ def test_nan_series(constructor_eager: ConstructorEager) -> None: def test_nan_non_float(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) from polars.exceptions import InvalidOperationError as PlInvalidOperationError from pyarrow.lib import ArrowNotImplementedError diff --git a/tests/expr_and_series/is_null_test.py b/tests/expr_and_series/is_null_test.py index 5d5250da9..cf4d2e73b 100644 --- a/tests/expr_and_series/is_null_test.py +++ b/tests/expr_and_series/is_null_test.py @@ -1,12 +1,17 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data -def test_null(constructor: Constructor) -> None: +def test_null(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + data_na = {"a": [None, 3, 2], "z": [7.0, None, None]} expected = {"a": [True, False, False], "z": [True, False, False]} df = nw.from_native(constructor(data_na)) diff --git a/tests/expr_and_series/is_unique_test.py b/tests/expr_and_series/is_unique_test.py index 3e9259c03..92e725623 100644 --- a/tests/expr_and_series/is_unique_test.py +++ b/tests/expr_and_series/is_unique_test.py @@ -9,7 +9,7 @@ def test_is_unique_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): 
diff --git a/tests/expr_and_series/is_unique_test.py b/tests/expr_and_series/is_unique_test.py
index 3e9259c03..92e725623 100644
--- a/tests/expr_and_series/is_unique_test.py
+++ b/tests/expr_and_series/is_unique_test.py
@@ -9,7 +9,7 @@
 
 def test_is_unique_expr(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {
         "a": [1, 1, 2],
@@ -29,7 +29,7 @@ def test_is_unique_expr(constructor: Constructor, request: pytest.FixtureRequest
 def test_is_unique_w_nulls_expr(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {
         "a": [None, 1, 2],
diff --git a/tests/expr_and_series/len_test.py b/tests/expr_and_series/len_test.py
index fffcbd4a3..142fe488b 100644
--- a/tests/expr_and_series/len_test.py
+++ b/tests/expr_and_series/len_test.py
@@ -34,7 +34,10 @@ def test_len_chaining(
     assert_equal_data(df, expected)
 
 
-def test_namespace_len(constructor: Constructor) -> None:
+def test_namespace_len(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})).select(
         nw.len(), a=nw.len()
     )
diff --git a/tests/expr_and_series/list/len_test.py b/tests/expr_and_series/list/len_test.py
index 7066fc6cf..375cfc7d8 100644
--- a/tests/expr_and_series/list/len_test.py
+++ b/tests/expr_and_series/list/len_test.py
@@ -17,7 +17,9 @@ def test_len_expr(
     request: pytest.FixtureRequest,
     constructor: Constructor,
 ) -> None:
-    if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")):
+    if any(
+        backend in str(constructor) for backend in ("dask", "modin", "cudf", "pyspark")
+    ):
         request.applymarker(pytest.mark.xfail)
 
     if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2):
diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py
index 505d99bf8..f24e6d4a1 100644
--- a/tests/expr_and_series/lit_test.py
+++ b/tests/expr_and_series/lit_test.py
@@ -22,8 +22,13 @@
     [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
 )
 def test_lit(
-    constructor: Constructor, dtype: DType | None, expected_lit: list[Any]
+    request: pytest.FixtureRequest,
+    constructor: Constructor,
+    dtype: DType | None,
+    expected_lit: list[Any],
 ) -> None:
+    if "pyspark" in str(constructor) and dtype is not None:
+        request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
     df_raw = constructor(data)
     df = nw.from_native(df_raw).lazy()
@@ -100,6 +105,14 @@ def test_lit_operation(
         and DASK_VERSION < (2024, 10)
     ):
         request.applymarker(pytest.mark.xfail)
+    if "pyspark" in str(constructor) and col_name in {
+        "left_lit_with_agg",
+        "left_scalar_with_agg",
+        "right_lit_with_agg",
+        "right_lit",
+    }:
+        request.applymarker(pytest.mark.xfail)
+
     data = {"a": [1, 3, 2]}
     df_raw = constructor(data)
     df = nw.from_native(df_raw).lazy()
@@ -110,7 +123,7 @@ def test_lit_operation(
 
 @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow")
 def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "dask" in str(constructor):
+    if "dask" in str(constructor) or "pyspark" in str(constructor):
         # https://github.com/dask/dask/issues/11637
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor({"a": [1]}))
diff --git a/tests/expr_and_series/max_horizontal_test.py b/tests/expr_and_series/max_horizontal_test.py
index c86e11318..9df17fed3 100644
--- a/tests/expr_and_series/max_horizontal_test.py
+++ b/tests/expr_and_series/max_horizontal_test.py
@@ -14,7 +14,12 @@
 
 @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
 @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning")
-def test_maxh(constructor: Constructor, col_expr: Any) -> None:
+def test_maxh(
+    request: pytest.FixtureRequest, constructor: Constructor, col_expr: Any
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(horizontal_max=nw.max_horizontal(col_expr, nw.col("b"), "z"))
     expected = {"horizontal_max": expected_values}
@@ -22,7 +27,10 @@ def test_maxh(constructor: Constructor, col_expr: Any) -> None:
 
 @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning")
-def test_maxh_all(constructor: Constructor) -> None:
+def test_maxh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(nw.max_horizontal(nw.all()), c=nw.max_horizontal(nw.all()))
     expected = {"a": expected_values, "c": expected_values}
diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py
index c1652c837..5ed472e31 100644
--- a/tests/expr_and_series/mean_horizontal_test.py
+++ b/tests/expr_and_series/mean_horizontal_test.py
@@ -13,7 +13,7 @@ def test_meanh(
     constructor: Constructor, col_expr: Any, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
     df = nw.from_native(constructor(data))
@@ -23,7 +23,7 @@ def test_meanh(
 
 def test_meanh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [2, 4, 6], "b": [10, 20, 30]}
     df = nw.from_native(constructor(data))
diff --git a/tests/expr_and_series/median_test.py b/tests/expr_and_series/median_test.py
index b0b6edcba..9c509a182 100644
--- a/tests/expr_and_series/median_test.py
+++ b/tests/expr_and_series/median_test.py
@@ -43,7 +43,7 @@ def test_median_series(
 def test_median_expr_raises_on_str(
     constructor: Constructor, expr: nw.Expr, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     from polars.exceptions import InvalidOperationError as PlInvalidOperationError
diff --git a/tests/expr_and_series/min_horizontal_test.py b/tests/expr_and_series/min_horizontal_test.py
index 787e3e2a4..bbb0b9149 100644
--- a/tests/expr_and_series/min_horizontal_test.py
+++ b/tests/expr_and_series/min_horizontal_test.py
@@ -14,7 +14,12 @@
 
 @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
 @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning")
-def test_minh(constructor: Constructor, col_expr: Any) -> None:
+def test_minh(
+    request: pytest.FixtureRequest, constructor: Constructor, col_expr: Any
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(horizontal_min=nw.min_horizontal(col_expr, nw.col("b"), "z"))
     expected = {"horizontal_min": expected_values}
@@ -22,7 +27,10 @@ def test_minh(constructor: Constructor, col_expr: Any) -> None:
 
 @pytest.mark.filterwarnings(r"ignore:.*All-NaN slice encountered:RuntimeWarning")
-def test_minh_all(constructor: Constructor) -> None:
+def test_minh_all(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(nw.min_horizontal(nw.all()), c=nw.min_horizontal(nw.all()))
     expected = {"a": expected_values, "c": expected_values}
diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py
index 90bffb04b..cfa14e0d7 100644
--- a/tests/expr_and_series/n_unique_test.py
+++ b/tests/expr_and_series/n_unique_test.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import pytest
+
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
@@ -11,7 +13,9 @@
 }
 
 
-def test_n_unique(constructor: Constructor) -> None:
+def test_n_unique(constructor: Constructor, request: pytest.FixtureRequest) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
     result = df.select(nw.all().n_unique())
     expected = {"a": [3], "b": [4]}
diff --git a/tests/expr_and_series/name/keep_test.py b/tests/expr_and_series/name/keep_test.py
index 6c89d09fc..e382db733 100644
--- a/tests/expr_and_series/name/keep_test.py
+++ b/tests/expr_and_series/name/keep_test.py
@@ -12,21 +12,34 @@
 data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
 
 
-def test_keep(constructor: Constructor) -> None:
+def test_keep(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.keep())
     expected = {k: [e * 2 for e in v] for k, v in data.items()}
     assert_equal_data(result, expected)
 
 
-def test_keep_after_alias(constructor: Constructor) -> None:
+def test_keep_after_alias(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.keep())
     expected = {"foo": data["foo"]}
     assert_equal_data(result, expected)
 
 
-def test_keep_raise_anonymous(constructor: Constructor) -> None:
+def test_keep_raise_anonymous(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
diff --git a/tests/expr_and_series/name/map_test.py b/tests/expr_and_series/name/map_test.py
index 5afda2ee8..276138ef9 100644
--- a/tests/expr_and_series/name/map_test.py
+++ b/tests/expr_and_series/name/map_test.py
@@ -16,21 +16,34 @@ def map_func(s: str | None) -> str:
     return str(s)[::-1].lower()
 
 
-def test_map(constructor: Constructor) -> None:
+def test_map(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.map(function=map_func))
     expected = {map_func(k): [e * 2 for e in v] for k, v in data.items()}
     assert_equal_data(result, expected)
 
 
-def test_map_after_alias(constructor: Constructor) -> None:
+def test_map_after_alias(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.map(function=map_func))
     expected = {map_func("foo"): data["foo"]}
     assert_equal_data(result, expected)
 
 
-def test_map_raise_anonymous(constructor: Constructor) -> None:
+def test_map_raise_anonymous(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
diff --git a/tests/expr_and_series/name/prefix_test.py b/tests/expr_and_series/name/prefix_test.py
index 6f3fb3c9b..934d1d664 100644
--- a/tests/expr_and_series/name/prefix_test.py
+++ b/tests/expr_and_series/name/prefix_test.py
@@ -13,21 +13,34 @@
 prefix = "with_prefix_"
 
 
-def test_prefix(constructor: Constructor) -> None:
+def test_prefix(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.prefix(prefix))
     expected = {prefix + str(k): [e * 2 for e in v] for k, v in data.items()}
     assert_equal_data(result, expected)
 
 
-def test_suffix_after_alias(constructor: Constructor) -> None:
+def test_suffix_after_alias(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix))
     expected = {prefix + "foo": data["foo"]}
     assert_equal_data(result, expected)
 
 
-def test_prefix_raise_anonymous(constructor: Constructor) -> None:
+def test_prefix_raise_anonymous(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
diff --git a/tests/expr_and_series/name/suffix_test.py b/tests/expr_and_series/name/suffix_test.py
index 1c5816154..479546630 100644
--- a/tests/expr_and_series/name/suffix_test.py
+++ b/tests/expr_and_series/name/suffix_test.py
@@ -13,21 +13,34 @@
 suffix = "_with_suffix"
 
 
-def test_suffix(constructor: Constructor) -> None:
+def test_suffix(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.suffix(suffix))
     expected = {str(k) + suffix: [e * 2 for e in v] for k, v in data.items()}
     assert_equal_data(result, expected)
 
 
-def test_suffix_after_alias(constructor: Constructor) -> None:
+def test_suffix_after_alias(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.suffix(suffix))
     expected = {"foo" + suffix: data["foo"]}
     assert_equal_data(result, expected)
 
 
-def test_suffix_raise_anonymous(constructor: Constructor) -> None:
+def test_suffix_raise_anonymous(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
diff --git a/tests/expr_and_series/name/to_lowercase_test.py b/tests/expr_and_series/name/to_lowercase_test.py
index 882663f60..1b39fc726 100644
--- a/tests/expr_and_series/name/to_lowercase_test.py
+++ b/tests/expr_and_series/name/to_lowercase_test.py
@@ -12,21 +12,34 @@
 data = {"foo": [1, 2, 3], "BAR": [4, 5, 6]}
 
 
-def test_to_lowercase(constructor: Constructor) -> None:
+def test_to_lowercase(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.to_lowercase())
     expected = {k.lower(): [e * 2 for e in v] for k, v in data.items()}
     assert_equal_data(result, expected)
 
 
-def test_to_lowercase_after_alias(constructor: Constructor) -> None:
+def test_to_lowercase_after_alias(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("BAR")).alias("ALIAS_FOR_BAR").name.to_lowercase())
     expected = {"bar": data["BAR"]}
     assert_equal_data(result, expected)
 
 
-def test_to_lowercase_raise_anonymous(constructor: Constructor) -> None:
+def test_to_lowercase_raise_anonymous(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
diff --git a/tests/expr_and_series/nth_test.py b/tests/expr_and_series/nth_test.py
index 4dd453528..a7dc7f648 100644
--- a/tests/expr_and_series/nth_test.py
+++ b/tests/expr_and_series/nth_test.py
@@ -25,7 +25,7 @@ def test_nth(
     expected: dict[str, list[int]],
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0):
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/expr_and_series/null_count_test.py b/tests/expr_and_series/null_count_test.py
index d10258901..3bd15c66c 100644
--- a/tests/expr_and_series/null_count_test.py
+++ b/tests/expr_and_series/null_count_test.py
@@ -16,7 +16,7 @@
 def test_null_count_expr(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
     result = df.select(nw.all().null_count())
False, False], "b": [True, False, True, False]} df = nw.from_native(constructor(data)) @@ -85,7 +91,7 @@ def test_logic_operators_expr_scalar( "dask" in str(constructor) and DASK_VERSION < (2024, 10) and operator in ("__rand__", "__ror__") - ): + ) or ("pyspark" in str(constructor) and operator in ("__and__", "__or__")): request.applymarker(pytest.mark.xfail) data = {"a": [True, True, False, False]} df = nw.from_native(constructor(data)) diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py index f42bdca54..45b64eba0 100644 --- a/tests/expr_and_series/over_test.py +++ b/tests/expr_and_series/over_test.py @@ -24,7 +24,7 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -42,7 +42,7 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -60,7 +60,7 @@ def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "polars" in str(constructor): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -73,7 +73,7 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) @@ -92,7 +92,7 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) - def test_over_cumcount(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) @@ -115,7 +115,7 @@ def test_over_cummax(request: pytest.FixtureRequest, constructor: Constructor) - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data_cum)) expected = { @@ -134,7 +134,7 @@ def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) - request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1): request.applymarker(pytest.mark.xfail) - if "duckdb" in 
diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py
index f42bdca54..45b64eba0 100644
--- a/tests/expr_and_series/over_test.py
+++ b/tests/expr_and_series/over_test.py
@@ -24,7 +24,7 @@
 def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "dask_lazy_p2" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
@@ -42,7 +42,7 @@ def test_over_single(request: pytest.FixtureRequest, constructor: Constructor) -
 def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "dask_lazy_p2" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
@@ -60,7 +60,7 @@ def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor)
 def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "polars" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
@@ -73,7 +73,7 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) -
         request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data_cum))
@@ -92,7 +92,7 @@ def test_over_cumsum(request: pytest.FixtureRequest, constructor: Constructor) -
 def test_over_cumcount(request: pytest.FixtureRequest, constructor: Constructor) -> None:
     if "pyarrow_table" in str(constructor) or "dask_lazy_p2" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data_cum))
@@ -115,7 +115,7 @@ def test_over_cummax(request: pytest.FixtureRequest, constructor: Constructor) -
         request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data_cum))
     expected = {
@@ -134,7 +134,7 @@ def test_over_cummin(request: pytest.FixtureRequest, constructor: Constructor) -
         request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data_cum))
@@ -155,7 +155,7 @@ def test_over_cumprod(request: pytest.FixtureRequest, constructor: Constructor)
         request.applymarker(pytest.mark.xfail)
     if "pandas_pyarrow" in str(constructor) and PANDAS_VERSION < (2, 1):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data_cum))
@@ -184,7 +184,7 @@ def test_over_shift(request: pytest.FixtureRequest, constructor: Constructor) -> None:
         constructor
     ) or "dask_lazy_p2_constructor" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
diff --git a/tests/expr_and_series/quantile_test.py b/tests/expr_and_series/quantile_test.py
index d52fae16c..a9207cebd 100644
--- a/tests/expr_and_series/quantile_test.py
+++ b/tests/expr_and_series/quantile_test.py
@@ -31,7 +31,7 @@ def test_quantile_expr(
     if (
         any(x in str(constructor) for x in ("dask", "duckdb"))
         and interpolation != "linear"
-    ):
+    ) or "pyspark" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     q = 0.3
diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py
index 4f2faa0ce..49a3fddba 100644
--- a/tests/expr_and_series/reduction_test.py
+++ b/tests/expr_and_series/reduction_test.py
@@ -28,11 +28,21 @@
     ids=range(5),
 )
 def test_scalar_reduction_select(
-    constructor: Constructor, expr: list[Any], expected: dict[str, list[Any]]
+    constructor: Constructor,
+    expr: list[Any],
+    expected: dict[str, list[Any]],
+    request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
-        # First one passes, the others fail.
-        return
+    if "pyspark" in str(constructor) and request.node.callspec.id in {
+        "pyspark-2",
+        "pyspark-3",
+        "pyspark-4",
+    }:
+        request.applymarker(pytest.mark.xfail)
+
+    if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}:
+        request.applymarker(pytest.mark.xfail)
+
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
     df = nw.from_native(constructor(data))
     result = df.select(*expr)
@@ -62,7 +72,9 @@ def test_scalar_reduction_with_columns(
     expected: dict[str, list[Any]],
     request: pytest.FixtureRequest,
 ) -> None:
-    if "duckdb" in str(constructor):
+    if "duckdb" in str(constructor) or (
+        "pyspark" in str(constructor) and request.node.callspec.id != "pyspark-1"
+    ):
         request.applymarker(pytest.mark.xfail)
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
     df = nw.from_native(constructor(data))
@@ -73,7 +85,7 @@ def test_scalar_reduction_with_columns(
 def test_empty_scalar_reduction_select(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if "pyspark" in str(constructor) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     data = {
         "str": [*"abcde"],
@@ -106,7 +118,7 @@ def test_empty_scalar_reduction_select(
 def test_empty_scalar_reduction_with_columns(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if "pyspark" in str(constructor) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     from itertools import chain
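The `request.node.callspec.id` gating above leans on pytest's generated test ids: the constructor id and the `ids=range(5)` case id are joined with a dash, so each instance of `test_scalar_reduction_select` is individually addressable. Roughly (behaviour per pytest's id generation; the ids listed are the ones this diff relies on):

    # Combined parametrize ids take the form "<constructor>-<case>":
    #   "duckdb-0", "duckdb-1", ..., "pyspark-4"
    # so `callspec.id not in {"duckdb-0"}` xfails every duckdb case except
    # the first, and the pyspark set xfails only cases 2 through 4.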
diff --git a/tests/expr_and_series/replace_strict_test.py b/tests/expr_and_series/replace_strict_test.py
index 07e349bc6..33c56bae6 100644
--- a/tests/expr_and_series/replace_strict_test.py
+++ b/tests/expr_and_series/replace_strict_test.py
@@ -23,7 +23,7 @@ def test_replace_strict(
 ) -> None:
     if "dask" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor({"a": [1, 2, 3]}))
     result = df.select(
@@ -60,7 +60,7 @@ def test_replace_non_full(
     if "dask" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor({"a": [1, 2, 3]}))
     if isinstance(df, nw.LazyFrame):
@@ -81,7 +81,7 @@ def test_replace_strict_mapping(
 ) -> None:
     if "dask" in str(constructor):
         request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor({"a": [1, 2, 3]}))
diff --git a/tests/expr_and_series/replace_time_zone_test.py b/tests/expr_and_series/replace_time_zone_test.py
index 132c4efc5..6876c318a 100644
--- a/tests/expr_and_series/replace_time_zone_test.py
+++ b/tests/expr_and_series/replace_time_zone_test.py
@@ -27,6 +27,7 @@ def test_replace_time_zone(
         or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
         or ("cudf" in str(constructor))
         or ("duckdb" in str(constructor))
+        or ("pyspark" in str(constructor))
     ):
         request.applymarker(pytest.mark.xfail)
     data = {
@@ -54,6 +55,7 @@ def test_replace_time_zone_none(
         or ("modin_pyarrow" in str(constructor) and PANDAS_VERSION < (2,))
         or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (12,))
         or ("duckdb" in str(constructor))
+        or ("pyspark" in str(constructor))
     ):
         request.applymarker(pytest.mark.xfail)
     data = {
diff --git a/tests/expr_and_series/shift_test.py b/tests/expr_and_series/shift_test.py
index 07f5d2b58..4f7894939 100644
--- a/tests/expr_and_series/shift_test.py
+++ b/tests/expr_and_series/shift_test.py
@@ -17,7 +17,7 @@
 
 def test_shift(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
     result = df.with_columns(nw.col("a", "b", "c").shift(2)).filter(nw.col("i") > 1)
diff --git a/tests/expr_and_series/str/contains_test.py b/tests/expr_and_series/str/contains_test.py
index 06c6913aa..c1024d53a 100644
--- a/tests/expr_and_series/str/contains_test.py
+++ b/tests/expr_and_series/str/contains_test.py
@@ -13,7 +13,7 @@
 def test_contains_case_insensitive(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "cudf" in str(constructor):
+    if "cudf" in str(constructor) or "pyspark" in str(constructor):
         request.applymarker(pytest.mark.xfail)
 
     df = nw.from_native(constructor(data))
@@ -40,7 +40,12 @@ def test_contains_series_case_insensitive(
     assert_equal_data(result, expected)
 
 
-def test_contains_case_sensitive(constructor: Constructor) -> None:
+def test_contains_case_sensitive(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(nw.col("pets").str.contains("parrot|Dove").alias("default_match"))
     expected = {
@@ -58,7 +63,12 @@ def test_contains_series_case_sensitive(constructor_eager: ConstructorEager) ->
     assert_equal_data(result, expected)
 
 
-def test_contains_literal(constructor: Constructor) -> None:
+def test_contains_literal(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(
         nw.col("pets").str.contains("Parrot|dove").alias("default_match"),
diff --git a/tests/expr_and_series/str/head_test.py b/tests/expr_and_series/str/head_test.py
index cf6cbd758..97fbbc6f3 100644
--- a/tests/expr_and_series/str/head_test.py
+++ b/tests/expr_and_series/str/head_test.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import pytest
+
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import ConstructorEager
@@ -8,7 +10,10 @@
 data = {"a": ["foo", "bars"]}
 
 
-def test_str_head(constructor: Constructor) -> None:
+def test_str_head(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw.from_native(constructor(data))
     result = df.select(nw.col("a").str.head(3))
     expected = {
diff --git a/tests/expr_and_series/str/len_chars_test.py b/tests/expr_and_series/str/len_chars_test.py
index 1a318801a..812f193b2 100644
--- a/tests/expr_and_series/str/len_chars_test.py
+++ b/tests/expr_and_series/str/len_chars_test.py
@@ -11,7 +11,7 @@
 
 def test_str_len_chars(constructor: Constructor, request: pytest.FixtureRequest) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     df = nw.from_native(constructor(data))
     result = df.select(nw.col("a").str.len_chars())
b/tests/expr_and_series/str/replace_test.py index 7d57eeb7d..53904be73 100644 --- a/tests/expr_and_series/str/replace_test.py +++ b/tests/expr_and_series/str/replace_test.py @@ -101,7 +101,7 @@ def test_str_replace_expr( literal: bool, # noqa: FBT001 expected: dict[str, list[str]], ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result_df = df.select( @@ -123,7 +123,9 @@ def test_str_replace_all_expr( literal: bool, # noqa: FBT001 expected: dict[str, list[str]], ) -> None: - if "duckdb" in str(constructor) and literal is False: + if ("pyspark" in str(constructor)) or ( + "duckdb" in str(constructor) and literal is False + ): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select( diff --git a/tests/expr_and_series/str/slice_test.py b/tests/expr_and_series/str/slice_test.py index 1e7115a8a..6f9b4dc4f 100644 --- a/tests/expr_and_series/str/slice_test.py +++ b/tests/expr_and_series/str/slice_test.py @@ -17,8 +17,15 @@ [(1, 2, {"a": ["da", "df"]}), (-2, None, {"a": ["as", "as"]})], ) def test_str_slice( - constructor: Constructor, offset: int, length: int | None, expected: Any + request: pytest.FixtureRequest, + constructor: Constructor, + offset: int, + length: int | None, + expected: Any, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.slice(offset, length)) assert_equal_data(result_frame, expected) diff --git a/tests/expr_and_series/str/starts_with_ends_with_test.py b/tests/expr_and_series/str/starts_with_ends_with_test.py index 0b11a7537..dac70c288 100644 --- a/tests/expr_and_series/str/starts_with_ends_with_test.py +++ b/tests/expr_and_series/str/starts_with_ends_with_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -11,7 +13,10 @@ data = {"a": ["fdas", "edfas"]} -def test_ends_with(constructor: Constructor) -> None: +def test_ends_with(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.select(nw.col("a").str.ends_with("das")) expected = { @@ -29,7 +34,10 @@ def test_ends_with_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_starts_with(constructor: Constructor) -> None: +def test_starts_with(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)).lazy() result = df.select(nw.col("a").str.starts_with("fda")) expected = { diff --git a/tests/expr_and_series/str/strip_chars_test.py b/tests/expr_and_series/str/strip_chars_test.py index d765e99e3..f369bbbf9 100644 --- a/tests/expr_and_series/str/strip_chars_test.py +++ b/tests/expr_and_series/str/strip_chars_test.py @@ -20,8 +20,13 @@ ], ) def test_str_strip_chars( - constructor: Constructor, characters: str | None, expected: Any + request: pytest.FixtureRequest, + constructor: Constructor, + characters: str | None, + expected: Any, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result_frame = 
df.select(nw.col("a").str.strip_chars(characters)) assert_equal_data(result_frame, expected) diff --git a/tests/expr_and_series/str/tail_test.py b/tests/expr_and_series/str/tail_test.py index e2543de0a..cdb2c024e 100644 --- a/tests/expr_and_series/str/tail_test.py +++ b/tests/expr_and_series/str/tail_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw from tests.utils import Constructor from tests.utils import ConstructorEager @@ -8,7 +10,9 @@ data = {"a": ["foo", "bars"]} -def test_str_tail(constructor: Constructor) -> None: +def test_str_tail(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) expected = {"a": ["foo", "ars"]} diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 3f8df65a7..bfb2a4dfb 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -18,7 +18,7 @@ def test_to_datetime(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = "2020-01-01T12:34:56.000000000" @@ -80,7 +80,7 @@ def test_to_datetime_infer_fmt( request.applymarker(pytest.mark.xfail) if "cudf" in str(constructor): expected = expected_cudf - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) result = ( nw.from_native(constructor(data)) @@ -133,7 +133,7 @@ def test_to_datetime_series_infer_fmt( def test_to_datetime_infer_fmt_from_date( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"z": ["2020-01-01", "2020-01-02", None]} expected = [datetime(2020, 1, 1), datetime(2020, 1, 2), None] diff --git a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py index 1057b33de..087e26a0e 100644 --- a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py +++ b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py @@ -30,8 +30,8 @@ def test_str_to_uppercase( expected: dict[str, list[str]], request: pytest.FixtureRequest, ) -> None: - df = nw.from_native(constructor(data)) - result_frame = df.select(nw.col("a").str.to_uppercase()) + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) if any("ß" in s for value in data.values() for s in value) & ( constructor.__name__ @@ -48,6 +48,9 @@ def test_str_to_uppercase( # smaller cap 'ß' to upper cap 'ẞ' instead of 'SS' request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) + result_frame = df.select(nw.col("a").str.to_uppercase()) + assert_equal_data(result_frame, expected) @@ -110,10 +113,13 @@ def test_str_to_uppercase_series( ], ) def test_str_to_lowercase( + request: pytest.FixtureRequest, constructor: Constructor, data: dict[str, list[str]], expected: dict[str, list[str]], ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.to_lowercase()) assert_equal_data(result_frame, expected) diff 
--git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index f3e01d80f..82f616a64 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -11,7 +11,7 @@ def test_unary(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = { "a": [1, 3, 2], @@ -82,7 +82,7 @@ def test_unary_series(constructor_eager: ConstructorEager) -> None: def test_unary_two_elements( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 2], "b": [2, 10], "c": [2.0, None]} result = nw.from_native(constructor(data)).select( @@ -126,7 +126,11 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_unary_one_element(constructor: Constructor) -> None: +def test_unary_one_element( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1], "b": [2], "c": [None]} # Dask runs into a divide by zero RuntimeWarning for 1 element skew. context = ( diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 94e37aaa3..0faf59172 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -17,7 +17,9 @@ } -def test_when(constructor: Constructor) -> None: +def test_when(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { @@ -26,7 +28,9 @@ def test_when(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_otherwise(constructor: Constructor) -> None: +def test_when_otherwise(constructor: Constructor, request: pytest.FixtureRequest) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(3).otherwise(6).alias("a_when")) expected = { @@ -35,7 +39,11 @@ def test_when_otherwise(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_multiple_conditions(constructor: Constructor) -> None: +def test_multiple_conditions( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") @@ -77,7 +85,11 @@ def test_value_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_value_expression(constructor: Constructor) -> None: +def test_value_expression( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { @@ -110,7 +122,11 @@ def test_otherwise_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def 
test_otherwise_expression(constructor: Constructor) -> None: +def test_otherwise_expression( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select( nw.when(nw.col("a") == 1).then(-1).otherwise(nw.col("a") + 7).alias("a_when") @@ -124,7 +140,7 @@ def test_otherwise_expression(constructor: Constructor) -> None: def test_when_then_otherwise_into_expr( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) @@ -135,7 +151,7 @@ def test_when_then_otherwise_into_expr( def test_when_then_otherwise_lit_str( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z"))) diff --git a/tests/frame/clone_test.py b/tests/frame/clone_test.py index e142ed0a7..316638c06 100644 --- a/tests/frame/clone_test.py +++ b/tests/frame/clone_test.py @@ -10,7 +10,7 @@ def test_clone(request: pytest.FixtureRequest, constructor: Constructor) -> None: if "dask" in str(constructor): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor): request.applymarker(pytest.mark.xfail) diff --git a/tests/frame/concat_test.py b/tests/frame/concat_test.py index 4d5f3ebc9..6d8fdbda0 100644 --- a/tests/frame/concat_test.py +++ b/tests/frame/concat_test.py @@ -10,7 +10,7 @@ def test_concat_horizontal( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = nw.from_native(constructor(data)).lazy() @@ -32,7 +32,12 @@ def test_concat_horizontal( nw.concat([]) -def test_concat_vertical(constructor: Constructor) -> None: +def test_concat_vertical( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df_left = ( nw.from_native(constructor(data)).lazy().rename({"a": "c", "b": "d"}).drop("z") @@ -63,7 +68,7 @@ def test_concat_vertical(constructor: Constructor) -> None: def test_concat_diagonal( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data_1 = {"a": [1, 3], "b": [4, 6]} data_2 = {"a": [100, 200], "z": ["x", "y"]} diff --git a/tests/frame/explode_test.py b/tests/frame/explode_test.py index b79215a18..f3b096194 100644 --- a/tests/frame/explode_test.py +++ b/tests/frame/explode_test.py @@ -40,7 +40,7 @@ def test_explode_single_col( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") + for backend in 
("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") ): request.applymarker(pytest.mark.xfail) @@ -89,7 +89,7 @@ def test_explode_multiple_cols( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") ): request.applymarker(pytest.mark.xfail) @@ -110,7 +110,7 @@ def test_explode_shape_error( ) -> None: if any( backend in str(constructor) - for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb") + for backend in ("dask", "modin", "cudf", "pyarrow_table", "duckdb", "pyspark") ): request.applymarker(pytest.mark.xfail) @@ -133,7 +133,7 @@ def test_explode_shape_error( def test_explode_invalid_operation_error( request: pytest.FixtureRequest, constructor: Constructor ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb")): + if any(x in str(constructor) for x in ("pyarrow_table", "dask", "duckdb", "pyspark")): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor) and POLARS_VERSION < (0, 20, 6): diff --git a/tests/frame/gather_every_test.py b/tests/frame/gather_every_test.py index 40e9291de..c151f4503 100644 --- a/tests/frame/gather_every_test.py +++ b/tests/frame/gather_every_test.py @@ -14,7 +14,7 @@ def test_gather_every( constructor: Constructor, n: int, offset: int, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.gather_every(n=n, offset=offset) diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index f176aca67..f15a1b79e 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -21,7 +21,7 @@ def test_inner_join_two_keys(constructor: Constructor) -> None: "antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9], - "index": [0, 1, 2], + "idx": [0, 1, 2], } df = nw.from_native(constructor(data)) df_right = df @@ -32,13 +32,13 @@ def test_inner_join_two_keys(constructor: Constructor) -> None: how="inner", ) result_on = df.join(df_right, on=["antananarivo", "bob"], how="inner") # type: ignore[arg-type] - result = result.sort("index").drop("index_right") - result_on = result_on.sort("index").drop("index_right") + result = result.sort("idx").drop("idx_right") + result_on = result_on.sort("idx").drop("idx_right") expected = { "antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9], - "index": [0, 1, 2], + "idx": [0, 1, 2], "zor ro_right": [7.0, 8, 9], } assert_equal_data(result, expected) @@ -50,7 +50,7 @@ def test_inner_join_single_key(constructor: Constructor) -> None: "antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9], - "index": [0, 1, 2], + "idx": [0, 1, 2], } df = nw.from_native(constructor(data)) df_right = df @@ -59,15 +59,15 @@ def test_inner_join_single_key(constructor: Constructor) -> None: left_on="antananarivo", right_on="antananarivo", how="inner", - ).sort("index") - result_on = df.join(df_right, on="antananarivo", how="inner").sort("index") # type: ignore[arg-type] - result = result.drop("index_right") - result_on = result_on.drop("index_right") + ).sort("idx") + result_on = df.join(df_right, on="antananarivo", how="inner").sort("idx") # type: ignore[arg-type] + result = result.drop("idx_right") + result_on = result_on.drop("idx_right") expected = { "antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zor ro": [7.0, 8, 9], - 
"index": [0, 1, 2], + "idx": [0, 1, 2], "bob_right": [4, 4, 6], "zor ro_right": [7.0, 8, 9], } @@ -235,34 +235,34 @@ def test_left_join(constructor: Constructor) -> None: data_left = { "antananarivo": [1.0, 2, 3], "bob": [4.0, 5, 6], - "index": [0.0, 1.0, 2.0], + "idx": [0.0, 1.0, 2.0], } data_right = { "antananarivo": [1.0, 2, 3], "co": [4.0, 5, 7], - "index": [0.0, 1.0, 2.0], + "idx": [0.0, 1.0, 2.0], } df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) result = df_left.join(df_right, left_on="bob", right_on="co", how="left") # type: ignore[arg-type] - result = result.sort("index") - result = result.drop("index_right") + result = result.sort("idx") + result = result.drop("idx_right") expected = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], - "index": [0, 1, 2], + "idx": [0, 1, 2], "antananarivo_right": [1, 2, None], } result_on_list = df_left.join( df_right, # type: ignore[arg-type] - on=["antananarivo", "index"], + on=["antananarivo", "idx"], how="left", ) - result_on_list = result_on_list.sort("index") + result_on_list = result_on_list.sort("idx") expected_on_list = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], - "index": [0, 1, 2], + "idx": [0, 1, 2], "co": [4, 5, 7], } assert_equal_data(result, expected) @@ -270,8 +270,8 @@ def test_left_join(constructor: Constructor) -> None: def test_left_join_multiple_column(constructor: Constructor) -> None: - data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "index": [0, 1, 2]} - data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "index": [0, 1, 2]} + data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]} + data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "idx": [0, 1, 2]} df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) result = df_left.join( @@ -280,9 +280,9 @@ def test_left_join_multiple_column(constructor: Constructor) -> None: right_on=["antananarivo", "c"], how="left", ) - result = result.sort("index") - result = result.drop("index_right") - expected = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "index": [0, 1, 2]} + result = result.sort("idx") + result = result.drop("idx_right") + expected = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]} assert_equal_data(result, expected) @@ -291,23 +291,23 @@ def test_left_join_overlapping_column(constructor: Constructor) -> None: "antananarivo": [1.0, 2, 3], "bob": [4.0, 5, 6], "d": [1.0, 4, 2], - "index": [0.0, 1.0, 2.0], + "idx": [0.0, 1.0, 2.0], } data_right = { "antananarivo": [1.0, 2, 3], "c": [4.0, 5, 6], "d": [1.0, 4, 2], - "index": [0.0, 1.0, 2.0], + "idx": [0.0, 1.0, 2.0], } df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on="bob", right_on="c", how="left").sort("index") # type: ignore[arg-type] - result = result.drop("index_right") + result = df_left.join(df_right, left_on="bob", right_on="c", how="left").sort("idx") # type: ignore[arg-type] + result = result.drop("idx_right") expected: dict[str, list[Any]] = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], "d": [1, 4, 2], - "index": [0, 1, 2], + "idx": [0, 1, 2], "antananarivo_right": [1, 2, 3], "d_right": [1, 4, 2], } @@ -318,13 +318,13 @@ def test_left_join_overlapping_column(constructor: Constructor) -> None: right_on="d", how="left", ) - result = result.sort("index") - result = result.drop("index_right") + result = result.sort("idx") + result = result.drop("idx_right") expected = { 
"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "d": [1, 4, 2], - "index": [0, 1, 2], + "idx": [0, 1, 2], "antananarivo_right": [1.0, 3.0, None], "c": [4.0, 6.0, None], } @@ -362,7 +362,7 @@ def test_joinasof_numeric( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) @@ -421,7 +421,7 @@ def test_joinasof_time( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ("pandas_pyarrow" in str(constructor)): request.applymarker(pytest.mark.xfail) @@ -502,7 +502,7 @@ def test_joinasof_by( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb")): + if any(x in str(constructor) for x in ("pyarrow_table", "cudf", "duckdb", "pyspark")): request.applymarker(pytest.mark.xfail) if PANDAS_VERSION < (2, 1) and ( ("pandas_pyarrow" in str(constructor)) or ("pandas_nullable" in str(constructor)) diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index 9d601e468..946e58203 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -80,7 +80,7 @@ def test_comparison_with_list_error_message() -> None: def test_missing_columns( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} df = nw.from_native(constructor(data)) @@ -126,7 +126,7 @@ def test_left_to_right_broadcasting( ) -> None: if "dask" in str(constructor) and DASK_VERSION < (2024, 10): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]})) result = df.select(nw.col("a") + nw.col("b").sum()) diff --git a/tests/frame/tail_test.py b/tests/frame/tail_test.py index a4d265797..75f46a4a1 100644 --- a/tests/frame/tail_test.py +++ b/tests/frame/tail_test.py @@ -9,7 +9,10 @@ from tests.utils import assert_equal_data -def test_tail(constructor: Constructor) -> None: +def test_tail(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9]} diff --git a/tests/frame/unique_test.py b/tests/frame/unique_test.py index ca34d29b4..a193ab98b 100644 --- a/tests/frame/unique_test.py +++ b/tests/frame/unique_test.py @@ -39,6 +39,13 @@ def test_unique( "last", }: context: Any = pytest.raises(ValueError, match="row order") + elif ( + keep == "none" and df.implementation is nw.Implementation.PYSPARK + ): # pragma: no cover + context = pytest.raises( + ValueError, + match="`LazyFrame.unique` with PySpark backend only supports `keep='any'`.", + ) elif keep == "foo": context = 
pytest.raises(ValueError, match=": foo") else: diff --git a/tests/frame/unpivot_test.py b/tests/frame/unpivot_test.py index 2867720a7..72aa81f2d 100644 --- a/tests/frame/unpivot_test.py +++ b/tests/frame/unpivot_test.py @@ -37,8 +37,14 @@ [("b", expected_b_only), (["b", "c"], expected_b_c), (None, expected_b_c)], ) def test_unpivot_on( - constructor: Constructor, on: str | list[str] | None, expected: dict[str, list[float]] + request: pytest.FixtureRequest, + constructor: Constructor, + on: str | list[str] | None, + expected: dict[str, list[float]], ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.unpivot(on=on, index=["a"]).sort("variable", "a") assert_equal_data(result, expected) @@ -53,10 +59,14 @@ def test_unpivot_on( ], ) def test_unpivot_var_value_names( + request: pytest.FixtureRequest, constructor: Constructor, variable_name: str | None, value_name: str | None, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.unpivot( on=["b", "c"], index=["a"], variable_name=variable_name, value_name=value_name @@ -65,7 +75,12 @@ def test_unpivot_var_value_names( assert result.collect_schema().names()[-2:] == [variable_name, value_name] -def test_unpivot_default_var_value_names(constructor: Constructor) -> None: +def test_unpivot_default_var_value_names( + request: pytest.FixtureRequest, constructor: Constructor +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.unpivot(on=["b", "c"], index=["a"]) @@ -87,10 +102,13 @@ def test_unpivot_mixed_types( data: dict[str, Any], expected_dtypes: list[DType], ) -> None: - if "cudf" in str(constructor) or ( - "pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0) + if ( + "cudf" in str(constructor) + or "pyspark" in str(constructor) + or ("pyarrow_table" in str(constructor) and PYARROW_VERSION < (14, 0, 0)) ): request.applymarker(pytest.mark.xfail) + df = nw.from_native(constructor(data)) result = df.unpivot(on=["a", "b"], index="idx") diff --git a/tests/frame/with_columns_test.py b/tests/frame/with_columns_test.py index 335c53896..6fa3ab825 100644 --- a/tests/frame/with_columns_test.py +++ b/tests/frame/with_columns_test.py @@ -52,7 +52,7 @@ def test_with_columns_dtypes_single_row( ) -> None: if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (15,): request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": ["foo"]} df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical)) diff --git a/tests/frame/with_row_index_test.py b/tests/frame/with_row_index_test.py index bc514fa70..96f2b1547 100644 --- a/tests/frame/with_row_index_test.py +++ b/tests/frame/with_row_index_test.py @@ -13,7 +13,7 @@ def test_with_row_index(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor(data)).with_row_index() expected = {"index": [0, 1], "a": ["foo", "bars"], "ab": ["foo", "bars"]} diff --git a/tests/from_dict_test.py b/tests/from_dict_test.py index 86fe07eda..0630cac43 100644 --- a/tests/from_dict_test.py +++ b/tests/from_dict_test.py 
@@ -12,7 +12,7 @@ def test_from_dict(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) native_namespace = nw.get_native_namespace(df) @@ -25,7 +25,7 @@ def test_from_dict(constructor: Constructor, request: pytest.FixtureRequest) -> def test_from_dict_schema( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) schema = {"c": nw_v1.Int16(), "d": nw_v1.Float32()} df = nw_v1.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) @@ -62,7 +62,7 @@ def test_from_dict_one_native_one_narwhals( def test_from_dict_v1(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw_v1.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]})) native_namespace = nw_v1.get_native_namespace(df) diff --git a/tests/from_numpy_test.py b/tests/from_numpy_test.py index b736d5cbd..7a40136e7 100644 --- a/tests/from_numpy_test.py +++ b/tests/from_numpy_test.py @@ -19,7 +19,7 @@ def test_from_numpy(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) native_namespace = nw.get_native_namespace(df) @@ -31,7 +31,7 @@ def test_from_numpy(constructor: Constructor, request: pytest.FixtureRequest) -> def test_from_numpy_schema_dict( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) schema = { "c": nw_v1.Int16(), @@ -52,7 +52,7 @@ def test_from_numpy_schema_dict( def test_from_numpy_schema_list( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) schema = ["c", "d", "e", "f"] df = nw_v1.from_native(constructor(data)) @@ -68,7 +68,7 @@ def test_from_numpy_schema_list( def test_from_numpy_schema_notvalid( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) @@ -79,7 +79,7 @@ def test_from_numpy_schema_notvalid( def test_from_numpy_v1(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "dask" in str(constructor): + if "dask" in str(constructor) or "pyspark" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw_v1.from_native(constructor(data)) native_namespace = nw_v1.get_native_namespace(df) diff --git a/tests/group_by_test.py b/tests/group_by_test.py index c854da453..64b3844d0 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -115,6 +115,8 @@ def test_group_by_depth_1_agg( expected: dict[str, list[int | float]], request: pytest.FixtureRequest, ) -> None: + if "pyspark" in str(constructor) and attr == "n_unique": + 
request.applymarker(pytest.mark.xfail) if "pandas_pyarrow" in str(constructor) and attr == "var" and PANDAS_VERSION < (2, 1): # Known issue with variance calculation in pandas 2.0.x with pyarrow backend in groupby operations request.applymarker(pytest.mark.xfail) @@ -164,7 +166,11 @@ def test_group_by_median(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_group_by_n_unique_w_missing(constructor: Constructor) -> None: +def test_group_by_n_unique_w_missing( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 1, 2], "b": [4, None, 5], "c": [None, None, 7], "d": [1, 1, 3]} result = ( nw.from_native(constructor(data)) @@ -269,6 +275,10 @@ def test_key_with_nulls( if "modin" in str(constructor): # TODO(unassigned): Modin flaky here? request.applymarker(pytest.mark.skip) + + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) + context = ( pytest.raises(NotImplementedError, match="null values") if ("pandas_constructor" in str(constructor) and PANDAS_VERSION < (1, 1, 0)) @@ -290,7 +300,7 @@ def test_key_with_nulls( def test_key_with_nulls_ignored( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"b": [4, 5, None], "a": [1, 2, 3]} result = ( @@ -332,7 +342,9 @@ def test_key_with_nulls_iter( assert len(result) == 4 -def test_no_agg(constructor: Constructor) -> None: +def test_no_agg(request: pytest.FixtureRequest, constructor: Constructor) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor(data)).group_by(["a", "b"]).agg().sort("a", "b") expected = {"a": [1, 3], "b": [4, 6]} @@ -343,7 +355,7 @@ def test_group_by_categorical( constructor: Constructor, request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) if "pyarrow_table" in str(constructor) and PYARROW_VERSION < ( 15, @@ -370,7 +382,7 @@ def test_group_by_categorical( def test_group_by_shift_raises( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) if "polars" in str(constructor): # Polars supports all kinds of crazy group-by aggregations, so @@ -412,7 +424,7 @@ def test_all_kind_of_aggs( # and modin lol https://github.com/modin-project/modin/issues/7414 # and cudf https://github.com/rapidsai/cudf/issues/17649 request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (1, 4): # Bug in old pandas, can't do DataFrameGroupBy[['b', 'b']] diff --git a/tests/read_scan_test.py b/tests/read_scan_test.py index dbb2cf624..55869b46b 100644 --- a/tests/read_scan_test.py +++ b/tests/read_scan_test.py @@ -52,8 +52,11 @@ def test_read_csv_kwargs(tmpdir: pytest.TempdirFactory) -> None: def test_scan_csv( tmpdir: pytest.TempdirFactory, + request: pytest.FixtureRequest, constructor: Constructor, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df_pl =
pl.DataFrame(data) filepath = str(tmpdir / "file.csv") # type: ignore[operator] df_pl.write_csv(filepath) @@ -66,8 +69,11 @@ def test_scan_csv( def test_scan_csv_v1( tmpdir: pytest.TempdirFactory, + request: pytest.FixtureRequest, constructor: Constructor, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df_pl = pl.DataFrame(data) filepath = str(tmpdir / "file.csv") # type: ignore[operator] df_pl.write_csv(filepath) @@ -128,8 +134,11 @@ def test_read_parquet_kwargs(tmpdir: pytest.TempdirFactory) -> None: @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") def test_scan_parquet( tmpdir: pytest.TempdirFactory, + request: pytest.FixtureRequest, constructor: Constructor, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df_pl = pl.DataFrame(data) filepath = str(tmpdir / "file.parquet") # type: ignore[operator] df_pl.write_parquet(filepath) @@ -143,8 +152,11 @@ def test_scan_parquet( @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") def test_scan_parquet_v1( tmpdir: pytest.TempdirFactory, + request: pytest.FixtureRequest, constructor: Constructor, ) -> None: + if "pyspark" in str(constructor): + request.applymarker(pytest.mark.xfail) df_pl = pl.DataFrame(data) filepath = str(tmpdir / "file.parquet") # type: ignore[operator] df_pl.write_parquet(filepath) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 103ea666d..80aa64803 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -24,7 +24,7 @@ def test_selectors(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(by_dtype([nw.Int64, nw.Float64]) + 1) @@ -33,7 +33,7 @@ def test_selectors(constructor: Constructor, request: pytest.FixtureRequest) -> def test_numeric(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(numeric() + 1) @@ -42,7 +42,7 @@ def test_numeric(constructor: Constructor, request: pytest.FixtureRequest) -> No def test_boolean(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(boolean()) @@ -51,7 +51,7 @@ def test_boolean(constructor: Constructor, request: pytest.FixtureRequest) -> No def test_string(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(string()) @@ -67,7 +67,7 @@ def test_categorical( 15, ): # pragma: no cover request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) expected = {"b": ["a", "b", "c"]} @@ -96,7 +96,7 @@ def test_set_ops( expected: list[str], request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): + if ("pyspark" in str(constructor)) or "duckdb" in 
str(constructor): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(selector).collect_schema().names() diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py deleted file mode 100644 index f7cd9e6a9..000000000 --- a/tests/spark_like_test.py +++ /dev/null @@ -1,1211 +0,0 @@ -"""PySpark support in Narwhals is still _very_ limited. - -Start with a simple test file whilst we develop the basics. -Once we're a bit further along, we can integrate PySpark tests into the main test suite. -""" - -from __future__ import annotations - -from contextlib import nullcontext as does_not_raise -from typing import TYPE_CHECKING -from typing import Any -from typing import Literal - -import pandas as pd -import pytest - -import narwhals.stable.v1 as nw -from narwhals.exceptions import ColumnNotFoundError -from tests.utils import assert_equal_data - -if TYPE_CHECKING: - from pyspark.sql import SparkSession - - from narwhals.dtypes import DType - from narwhals.typing import IntoFrame - from tests.utils import Constructor - - -# Apply filterwarnings to all tests in this module -pytestmark = [ - pytest.mark.filterwarnings( - "ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning" - ), - pytest.mark.filterwarnings( - "ignore:.*distutils Version classes are deprecated. Use packaging.version instead.*:DeprecationWarning" - ), - pytest.mark.filterwarnings("ignore: unclosed <socket.socket"), -] - - -def _pyspark_constructor_with_session( - obj: Any, spark_session: SparkSession -) -> IntoFrame: - # NaN and NULL are not the same in PySpark - pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index() - return ( # type: ignore[no-any-return] - spark_session.createDataFrame(pd_df).repartition(2).orderBy("index").drop("index") - ) - - -@pytest.fixture(params=[_pyspark_constructor_with_session]) -def pyspark_constructor( - request: pytest.FixtureRequest, spark_session: SparkSession -) -> Constructor: - def _constructor(obj: Any) -> IntoFrame: - return request.param(obj, spark_session) # type: ignore[no-any-return] - - return _constructor - - -# copied from tests/translate/from_native_test.py -def test_series_only(pyspark_constructor: Constructor) -> None: - obj = pyspark_constructor({"a": [1, 2, 3]}) - with pytest.raises(TypeError, match="Cannot only use `series_only`"): - _ = nw.from_native(obj, series_only=True) - - -def test_eager_only_lazy(pyspark_constructor: Constructor) -> None: - dframe = pyspark_constructor({"a": [1, 2, 3]}) - with pytest.raises(TypeError, match="Cannot only use `eager_only`"): - _ = nw.from_native(dframe, eager_only=True) - - -# copied from tests/frame/with_columns_test.py -def test_columns(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.columns - expected = ["a", "b", "z"] - assert result == expected - - -# copied from tests/frame/with_columns_test.py -def test_with_columns_order(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) - assert result.collect_schema().names() == ["a", "b", "z", "d"] - expected = {"a": [2, 4, 3], "b": [4, 4, 6], "z": [7.0, 8, 9], "d": [0, 2, 1]} - assert_equal_data(result, expected) - - -@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") -def test_with_columns_empty(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6],
"z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select().with_columns() - assert_equal_data(result, {}) - - -def test_with_columns_order_single_row(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "i": [0, 1, 2]} - df = nw.from_native(pyspark_constructor(data)).filter(nw.col("i") < 1).drop("i") - result = df.with_columns(nw.col("a") + 1, d=nw.col("a") - 1) - assert result.collect_schema().names() == ["a", "b", "z", "d"] - expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]} - assert_equal_data(result, expected) - - -# copied from tests/frame/select_test.py -def test_select(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select("a") - expected = {"a": [1, 3, 2]} - assert_equal_data(result, expected) - - -@pytest.mark.filterwarnings("ignore:If `index_col` is not specified for `to_spark`") -def test_empty_select(pyspark_constructor: Constructor) -> None: - result = nw.from_native(pyspark_constructor({"a": [1, 2, 3]})).lazy().select() - assert result.collect().shape == (0, 0) - - -# copied from tests/frame/filter_test.py -def test_filter(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.filter(nw.col("a") > 1) - expected = {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]} - assert_equal_data(result, expected) - - -# copied from tests/frame/schema_test.py -@pytest.mark.filterwarnings("ignore:Determining|Resolving.*") -def test_schema(pyspark_constructor: Constructor) -> None: - df = nw.from_native( - pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) - ) - result = df.schema - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - - result = df.schema - assert result == expected - result = df.lazy().collect().schema - assert result == expected - - -def test_collect_schema(pyspark_constructor: Constructor) -> None: - df = nw.from_native( - pyspark_constructor({"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}) - ) - expected = {"a": nw.Int64, "b": nw.Int64, "z": nw.Float64} - - result = df.collect_schema() - assert result == expected - result = df.lazy().collect().collect_schema() - assert result == expected - - -# copied from tests/frame/drop_test.py -@pytest.mark.parametrize( - ("to_drop", "expected"), - [ - ("abc", ["b", "z"]), - (["abc"], ["b", "z"]), - (["abc", "b"], ["z"]), - ], -) -def test_drop( - pyspark_constructor: Constructor, to_drop: list[str], expected: list[str] -) -> None: - data = {"abc": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - assert df.drop(to_drop).collect_schema().names() == expected - if not isinstance(to_drop, str): - assert df.drop(*to_drop).collect_schema().names() == expected - - -@pytest.mark.parametrize( - ("strict", "context"), - [ - (True, pytest.raises(ColumnNotFoundError, match="z")), - (False, does_not_raise()), - ], -) -def test_drop_strict( - pyspark_constructor: Constructor, context: Any, *, strict: bool -) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6]} - to_drop = ["a", "z"] - - df = nw.from_native(pyspark_constructor(data)) - - with context: - names_out = df.drop(to_drop, strict=strict).collect_schema().names() - assert names_out == ["b"] - - -# copied from tests/frame/head_test.py -def test_head(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 
2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} - - df_raw = pyspark_constructor(data) - df = nw.from_native(df_raw) - - result = df.head(2) - assert_equal_data(result, expected) - - result = df.head(2) - assert_equal_data(result, expected) - - # negative indices not allowed for lazyframes - result = df.lazy().collect().head(-1) - assert_equal_data(result, expected) - - -# copied from tests/frame/sort_test.py -def test_sort(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.sort("a", "b") - expected = { - "a": [1, 2, 3], - "b": [4, 6, 4], - "z": [7.0, 9.0, 8.0], - } - assert_equal_data(result, expected) - result = df.sort("a", "b", descending=[True, False]).lazy().collect() - expected = { - "a": [3, 2, 1], - "b": [4, 6, 4], - "z": [8.0, 9.0, 7.0], - } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("nulls_last", "expected"), - [ - (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, None]}), - (False, {"a": [-1, 0, 2, 0], "b": [None, 3, 2, 1]}), - ], -) -def test_sort_nulls( - pyspark_constructor: Constructor, *, nulls_last: bool, expected: dict[str, float] -) -> None: - data = {"a": [0, 0, 2, -1], "b": [1, 3, 2, None]} - df = nw.from_native(pyspark_constructor(data)) - result = df.sort("b", descending=True, nulls_last=nulls_last).lazy().collect() - assert_equal_data(result, expected) - - -# copied from tests/frame/add_test.py -def test_add(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns( - c=nw.col("a") + nw.col("b"), - d=nw.col("a") - nw.col("a").mean(), - e=nw.col("a") - nw.col("a").std(), - ) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "c": [5, 7, 8], - "d": [-1.0, 1.0, 0.0], - "e": [0.0, 2.0, 1.0], - } - assert_equal_data(result, expected) - - -def test_abs(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3, -4, 5]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a").abs()) - expected = {"a": [1, 2, 3, 4, 5]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/all_horizontal_test.py -@pytest.mark.parametrize("expr1", ["a", nw.col("a")]) -@pytest.mark.parametrize("expr2", ["b", nw.col("b")]) -def test_allh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(all=nw.all_horizontal(expr1, expr2)) - - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - - -def test_allh_all(pyspark_constructor: Constructor) -> None: - data = { - "a": [False, False, True], - "b": [False, True, True], - } - df = nw.from_native(pyspark_constructor(data)) - result = df.select(all=nw.all_horizontal(nw.all())) - expected = {"all": [False, False, True]} - assert_equal_data(result, expected) - result = df.select(nw.all_horizontal(nw.all())) - expected = {"a": [False, False, True]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/sum_horizontal_test.py -@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) -def test_sumh(pyspark_constructor: Constructor, col_expr: Any) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = 
df.with_columns(horizontal_sum=nw.sum_horizontal(col_expr, nw.col("b"))) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "z": [7.0, 8.0, 9.0], - "horizontal_sum": [5, 7, 8], - } - assert_equal_data(result, expected) - - -def test_sumh_nullable(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 8, 3], "b": [4, 5, None], "idx": [0, 1, 2]} - expected = {"hsum": [5, 13, 3]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select("idx", hsum=nw.sum_horizontal("a", "b")).sort("idx").drop("idx") - assert_equal_data(result, expected) - - -def test_sumh_all(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3], "b": [10, 20, 30]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.sum_horizontal(nw.all())) - expected = { - "a": [11, 22, 33], - } - assert_equal_data(result, expected) - result = df.select(c=nw.sum_horizontal(nw.all())) - expected = { - "c": [11, 22, 33], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/count_test.py -def test_count(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 2, 3], "b": [4, None, 6], "z": [7.0, None, None]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").count()) - expected = {"a": [3], "b": [2], "z": [1]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/double_test.py -def test_double(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.all() * 2) - expected = {"a": [2, 6, 4], "b": [8, 8, 12], "z": [14.0, 16.0, 18.0]} - assert_equal_data(result, expected) - - -def test_double_alias(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.with_columns(nw.col("a").alias("o"), nw.all() * 2) - expected = { - "a": [2, 6, 4], - "b": [8, 8, 12], - "z": [14.0, 16.0, 18.0], - "o": [1, 3, 2], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/max_test.py -def test_expr_max_expr(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").max()) - expected = {"a": [3], "b": [6], "z": [9.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/min_test.py -def test_expr_min_expr(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a", "b", "z").min()) - expected = {"a": [1], "b": [4], "z": [7.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/min_test.py -@pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")]) -def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(expr) - expected = {"a": [6], "b": [14], "z": [24.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/std_test.py -def test_std(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - - df = nw.from_native(pyspark_constructor(data)) - result = df.select( - 
nw.col("a").std().alias("a_ddof_default"), - nw.col("a").std(ddof=1).alias("a_ddof_1"), - nw.col("a").std(ddof=0).alias("a_ddof_0"), - nw.col("b").std(ddof=2).alias("b_ddof_2"), - nw.col("z").std(ddof=0).alias("z_ddof_0"), - ) - expected = { - "a_ddof_default": [1.0], - "a_ddof_1": [1.0], - "a_ddof_0": [0.816497], - "b_ddof_2": [1.632993], - "z_ddof_0": [0.816497], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/var_test.py -def test_var(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2, None], "b": [4, 4, 6, None], "z": [7.0, 8, 9, None]} - - expected_results = { - "a_ddof_1": [1.0], - "a_ddof_0": [0.6666666666666666], - "b_ddof_2": [2.666666666666667], - "z_ddof_0": [0.6666666666666666], - } - - df = nw.from_native(pyspark_constructor(data)) - result = df.select( - nw.col("a").var(ddof=1).alias("a_ddof_1"), - nw.col("a").var(ddof=0).alias("a_ddof_0"), - nw.col("b").var(ddof=2).alias("b_ddof_2"), - nw.col("z").var(ddof=0).alias("z_ddof_0"), - ) - assert_equal_data(result, expected_results) - - -# copied from tests/group_by_test.py -def test_group_by_std(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2, 2], "b": [5, 4, 3, 2]} - result = ( - nw.from_native(pyspark_constructor(data)) - .group_by("a") - .agg(nw.col("b").std()) - .sort("a") - ) - expected = {"a": [1, 2], "b": [0.707107] * 2} - assert_equal_data(result, expected) - - -def test_group_by_simple_named(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a") - .agg( - b_min=nw.col("b").min(), - b_max=nw.col("b").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b_min": [4, 6], - "b_max": [5, 6], - } - assert_equal_data(result, expected) - - -def test_group_by_simple_unnamed(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a") - .agg( - nw.col("b").min(), - nw.col("c").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b": [4, 6], - "c": [7, 1], - } - assert_equal_data(result, expected) - - -def test_group_by_multiple_keys(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 1, 2], "b": [4, 4, 6], "c": [7, 2, 1]} - df = nw.from_native(pyspark_constructor(data)).lazy() - result = ( - df.group_by("a", "b") - .agg( - c_min=nw.col("c").min(), - c_max=nw.col("c").max(), - ) - .collect() - .sort("a") - ) - expected = { - "a": [1, 2], - "b": [4, 6], - "c_min": [2, 1], - "c_max": [7, 1], - } - assert_equal_data(result, expected) - - -# copied from tests/group_by_test.py -@pytest.mark.parametrize( - ("attr", "ddof"), - [ - ("std", 0), - ("var", 0), - ("std", 2), - ("var", 2), - ], -) -def test_group_by_depth_1_std_var( - pyspark_constructor: Constructor, - attr: str, - ddof: int, -) -> None: - data = {"a": [1, 1, 1, 2, 2, 2], "b": [4, 5, 6, 0, 5, 5]} - _pow = 0.5 if attr == "std" else 1 - expected = { - "a": [1, 2], - "b": [ - (sum((v - 5) ** 2 for v in [4, 5, 6]) / (3 - ddof)) ** _pow, - (sum((v - 10 / 3) ** 2 for v in [0, 5, 5]) / (3 - ddof)) ** _pow, - ], - } - expr = getattr(nw.col("b"), attr)(ddof=ddof) - result = nw.from_native(pyspark_constructor(data)).group_by("a").agg(expr).sort("a") - assert_equal_data(result, expected) - - -# copied from tests/frame/drop_nulls_test.py -def test_drop_nulls(pyspark_constructor: Constructor) -> None: - data = { - 
"a": [1.0, 2.0, None, 4.0], - "b": [None, 3.0, None, 5.0], - } - - result = nw.from_native(pyspark_constructor(data)).drop_nulls() - expected = { - "a": [2.0, 4.0], - "b": [3.0, 5.0], - } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("subset", "expected"), - [ - ("a", {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), - (["a"], {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), - (["a", "b"], {"a": [2.0, 4.0], "b": [3.0, 5.0]}), - ], -) -def test_drop_nulls_subset( - pyspark_constructor: Constructor, - subset: str | list[str], - expected: dict[str, float], -) -> None: - data = { - "a": [1.0, 2.0, None, 4.0], - "b": [None, 3.0, None, 5.0], - } - - result = nw.from_native(pyspark_constructor(data)).drop_nulls(subset=subset) - assert_equal_data(result, expected) - - -# copied from tests/frame/rename_test.py -def test_rename(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.rename({"a": "x", "b": "y"}) - expected = {"x": [1, 3, 2], "y": [4, 4, 6], "z": [7.0, 8, 9]} - assert_equal_data(result, expected) - - -# adapted from tests/frame/unique_test.py -@pytest.mark.parametrize("subset", ["b", ["b"]]) -@pytest.mark.parametrize( - ("keep", "expected"), - [ - ("first", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), - ("last", {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}), - ("any", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), - ("none", {"a": [2], "b": [6], "z": [9]}), - ], -) -def test_unique( - pyspark_constructor: Constructor, - subset: str | list[str] | None, - keep: str, - expected: dict[str, list[float]], -) -> None: - if keep == "any": - context: Any = does_not_raise() - elif keep == "none": - context = pytest.raises( - ValueError, - match=r"`LazyFrame.unique` with PySpark backend only supports `keep='any'`.", - ) - else: - context = pytest.raises(ValueError, match=f": {keep}") - - with context: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - - result = df.unique(subset, keep=keep).sort("z") # type: ignore[arg-type] - assert_equal_data(result, expected) - - -def test_unique_none(pyspark_constructor: Constructor) -> None: - data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - result = df.unique().sort("z") - assert_equal_data(result, data) - - -def test_inner_join_two_keys(pyspark_constructor: Constructor) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = df.join( - df_right, # type: ignore[arg-type] - left_on=["antananarivo", "bob"], - right_on=["antananarivo", "bob"], - how="inner", - ) - result = result.sort("idx").drop("idx_right") - - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - - result_on = df.join(df_right, on=["antananarivo", "bob"], how="inner") # type: ignore[arg-type] - result_on = result_on.sort("idx").drop("idx_right") - expected = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - "zorro_right": [7.0, 8, 9], - } - assert_equal_data(result, expected) - assert_equal_data(result_on, expected) - - -def test_inner_join_single_key(pyspark_constructor: Constructor) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - 
"idx": [0, 1, 2], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = ( - df.join( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="antananarivo", - how="inner", - ) - .sort("idx") - .drop("idx_right") - ) - - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result_on = ( - df.join( - df_right, # type: ignore[arg-type] - on="antananarivo", - how="inner", - ) - .sort("idx") - .drop("idx_right") - ) - - expected = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - "idx": [0, 1, 2], - "bob_right": [4, 4, 6], - "zorro_right": [7.0, 8, 9], - } - assert_equal_data(result, expected) - assert_equal_data(result_on, expected) - - -def test_cross_join(pyspark_constructor: Constructor) -> None: - data = {"antananarivo": [1, 3, 2]} - df = nw.from_native(pyspark_constructor(data)) - other = nw.from_native(pyspark_constructor(data)) - result = df.join(other, how="cross").sort("antananarivo", "antananarivo_right") # type: ignore[arg-type] - expected = { - "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], - "antananarivo_right": [1, 2, 3, 1, 2, 3, 1, 2, 3], - } - assert_equal_data(result, expected) - - with pytest.raises( - ValueError, - match="Can not pass `left_on`, `right_on` or `on` keys for cross join", - ): - df.join(other, how="cross", left_on="antananarivo") # type: ignore[arg-type] - - -@pytest.mark.parametrize("how", ["inner", "left"]) -@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_suffix(pyspark_constructor: Constructor, how: str, suffix: str) -> None: - data = { - "antananarivo": [1, 3, 2], - "bob": [4, 4, 6], - "zorro": [7.0, 8, 9], - } - df = nw.from_native(pyspark_constructor(data)) - df_right = nw.from_native(pyspark_constructor(data)) - result = df.join( - df_right, # type: ignore[arg-type] - left_on=["antananarivo", "bob"], - right_on=["antananarivo", "bob"], - how=how, # type: ignore[arg-type] - suffix=suffix, - ) - result_cols = result.collect_schema().names() - assert result_cols == ["antananarivo", "bob", "zorro", f"zorro{suffix}"] - - -@pytest.mark.parametrize("suffix", ["_right", "_custom_suffix"]) -def test_cross_join_suffix(pyspark_constructor: Constructor, suffix: str) -> None: - data = {"antananarivo": [1, 3, 2]} - df = nw.from_native(pyspark_constructor(data)) - other = nw.from_native(pyspark_constructor(data)) - result = df.join(other, how="cross", suffix=suffix).sort( # type: ignore[arg-type] - "antananarivo", f"antananarivo{suffix}" - ) - expected = { - "antananarivo": [1, 1, 1, 2, 2, 2, 3, 3, 3], - f"antananarivo{suffix}": [1, 2, 3, 1, 2, 3, 1, 2, 3], - } - assert_equal_data(result, expected) - - -@pytest.mark.parametrize( - ("join_key", "filter_expr", "expected"), - [ - ( - ["antananarivo", "bob"], - (nw.col("bob") < 5), - {"antananarivo": [2], "bob": [6], "zorro": [9]}, - ), - (["bob"], (nw.col("bob") < 5), {"antananarivo": [2], "bob": [6], "zorro": [9]}), - ( - ["bob"], - (nw.col("bob") > 5), - {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7.0, 8.0]}, - ), - ], -) -def test_anti_join( - pyspark_constructor: Constructor, - join_key: list[str], - filter_expr: nw.Expr, - expected: dict[str, list[Any]], -) -> None: - data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} - df = nw.from_native(pyspark_constructor(data)) - other = df.filter(filter_expr) - result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] - 
-    assert_equal_data(result, expected)
-
-
-@pytest.mark.parametrize(
-    ("join_key", "filter_expr", "expected"),
-    [
-        (
-            "antananarivo",
-            (nw.col("bob") > 5),
-            {"antananarivo": [2], "bob": [6], "zorro": [9]},
-        ),
-        (
-            ["antananarivo"],
-            (nw.col("bob") > 5),
-            {"antananarivo": [2], "bob": [6], "zorro": [9]},
-        ),
-        (
-            ["bob"],
-            (nw.col("bob") < 5),
-            {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]},
-        ),
-        (
-            ["antananarivo", "bob"],
-            (nw.col("bob") < 5),
-            {"antananarivo": [1, 3], "bob": [4, 4], "zorro": [7, 8]},
-        ),
-    ],
-)
-def test_semi_join(
-    pyspark_constructor: Constructor,
-    join_key: list[str],
-    filter_expr: nw.Expr,
-    expected: dict[str, list[Any]],
-) -> None:
-    data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]}
-    df = nw.from_native(pyspark_constructor(data))
-    other = df.filter(filter_expr)
-    result = df.join(other, how="semi", left_on=join_key, right_on=join_key).sort(  # type: ignore[arg-type]
-        "antananarivo"
-    )
-    assert_equal_data(result, expected)
-
-
-def test_left_join(pyspark_constructor: Constructor) -> None:
-    data_left = {
-        "antananarivo": [1.0, 2, 3],
-        "bob": [4.0, 5, 6],
-        "idx": [0.0, 1.0, 2.0],
-    }
-    data_right = {
-        "antananarivo": [1.0, 2, 3],
-        "co": [4.0, 5, 7],
-        "idx": [0.0, 1.0, 2.0],
-    }
-    df_left = nw.from_native(pyspark_constructor(data_left))
-    df_right = nw.from_native(pyspark_constructor(data_right))
-    result = (
-        df_left.join(df_right, left_on="bob", right_on="co", how="left")  # type: ignore[arg-type]
-        .sort("idx")
-        .drop("idx_right")
-    )
-    expected = {
-        "antananarivo": [1, 2, 3],
-        "bob": [4, 5, 6],
-        "idx": [0, 1, 2],
-        "antananarivo_right": [1, 2, None],
-    }
-    assert_equal_data(result, expected)
-
-    df_left = nw.from_native(pyspark_constructor(data_left))
-    df_right = nw.from_native(pyspark_constructor(data_right))
-    result_on_list = df_left.join(
-        df_right,  # type: ignore[arg-type]
-        on=["antananarivo", "idx"],
-        how="left",
-    )
-    result_on_list = result_on_list.sort("idx")
-    expected_on_list = {
-        "antananarivo": [1, 2, 3],
-        "bob": [4, 5, 6],
-        "idx": [0, 1, 2],
-        "co": [4, 5, 7],
-    }
-    assert_equal_data(result_on_list, expected_on_list)
-
-
-def test_left_join_multiple_column(pyspark_constructor: Constructor) -> None:
-    data_left = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]}
-    data_right = {"antananarivo": [1, 2, 3], "c": [4, 5, 6], "idx": [0, 1, 2]}
-    df_left = nw.from_native(pyspark_constructor(data_left))
-    df_right = nw.from_native(pyspark_constructor(data_right))
-    result = (
-        df_left.join(
-            df_right,  # type: ignore[arg-type]
-            left_on=["antananarivo", "bob"],
-            right_on=["antananarivo", "c"],
-            how="left",
-        )
-        .sort("idx")
-        .drop("idx_right")
-    )
-    expected = {"antananarivo": [1, 2, 3], "bob": [4, 5, 6], "idx": [0, 1, 2]}
-    assert_equal_data(result, expected)
-
-
-def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None:
-    data_left = {
-        "antananarivo": [1.0, 2, 3],
-        "bob": [4.0, 5, 6],
-        "d": [1.0, 4, 2],
-        "idx": [0.0, 1.0, 2.0],
-    }
-    data_right = {
-        "antananarivo": [1.0, 2, 3],
-        "c": [4.0, 5, 6],
-        "d": [1.0, 4, 2],
-        "idx": [0.0, 1.0, 2.0],
-    }
-    df_left = nw.from_native(pyspark_constructor(data_left))
-    df_right = nw.from_native(pyspark_constructor(data_right))
-    result = df_left.join(df_right, left_on="bob", right_on="c", how="left").sort("idx")  # type: ignore[arg-type]
-    result = result.drop("idx_right")
-    expected: dict[str, list[Any]] = {
-        "antananarivo": [1, 2, 3],
-        "bob": [4, 5, 6],
-        "d": [1, 4, 2],
-        "idx": [0, 1, 2],
"antananarivo_right": [1, 2, 3], - "d_right": [1, 4, 2], - } - assert_equal_data(result, expected) - - df_left = nw.from_native(pyspark_constructor(data_left)) - df_right = nw.from_native(pyspark_constructor(data_right)) - result = ( - df_left.join( - df_right, # type: ignore[arg-type] - left_on="antananarivo", - right_on="d", - how="left", - ) - .sort("idx") - .drop("idx_right") - ) - expected = { - "antananarivo": [1, 2, 3], - "bob": [4, 5, 6], - "d": [1, 4, 2], - "idx": [0, 1, 2], - "antananarivo_right": [1.0, 3.0, None], - "c": [4.0, 6.0, None], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/arithmetic_test.py -@pytest.mark.parametrize( - ("attr", "rhs", "expected"), - [ - ("__add__", 1, [2, 3, 4]), - ("__sub__", 1, [0, 1, 2]), - ("__mul__", 2, [2, 4, 6]), - ("__truediv__", 2.0, [0.5, 1.0, 1.5]), - ("__truediv__", 1, [1, 2, 3]), - ("__floordiv__", 2, [0, 1, 1]), - ("__mod__", 2, [1, 0, 1]), - ("__pow__", 2, [1, 4, 9]), - ], -) -def test_arithmetic_expr( - attr: str, rhs: Any, expected: list[Any], pyspark_constructor: Constructor -) -> None: - data = {"a": [1.0, 2, 3]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(getattr(nw.col("a"), attr)(rhs)) - assert_equal_data(result, {"a": expected}) - - -@pytest.mark.parametrize( - ("attr", "rhs", "expected"), - [ - ("__radd__", 1, [2, 3, 4]), - ("__rsub__", 1, [0, -1, -2]), - ("__rmul__", 2, [2, 4, 6]), - ("__rtruediv__", 2.0, [2, 1, 2 / 3]), - ("__rfloordiv__", 2, [2, 1, 0]), - ("__rmod__", 2, [0, 0, 2]), - ("__rpow__", 2, [2, 4, 8]), - ], -) -def test_right_arithmetic_expr( - attr: str, - rhs: Any, - expected: list[Any], - pyspark_constructor: Constructor, -) -> None: - data = {"a": [1, 2, 3]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(getattr(nw.col("a"), attr)(rhs)) - assert_equal_data(result, {"literal": expected}) - - -# Copied from tests/expr_and_series/median_test.py -def test_median(pyspark_constructor: Constructor) -> None: - data = {"a": [3, 8, 2, None], "b": [5, 5, None, 7], "z": [7.0, 8, 9, None]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select( - a=nw.col("a").median(), b=nw.col("b").median(), z=nw.col("z").median() - ) - expected = {"a": [3.0], "b": [5.0], "z": [8.0]} - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/clip_test.py -def test_clip(pyspark_constructor: Constructor) -> None: - df = nw.from_native(pyspark_constructor({"a": [1, 2, 3, -4, 5]})) - result = df.select( - lower_only=nw.col("a").clip(lower_bound=3), - upper_only=nw.col("a").clip(upper_bound=4), - both=nw.col("a").clip(3, 4), - ) - expected = { - "lower_only": [3, 3, 3, 3, 5], - "upper_only": [1, 2, 3, -4, 4], - "both": [3, 3, 3, 3, 4], - } - assert_equal_data(result, expected) - - -# copied from tests/expr_and_series/is_between_test.py -@pytest.mark.parametrize( - ("closed", "expected"), - [ - ("left", [True, True, True, False]), - ("right", [False, True, True, True]), - ("both", [True, True, True, True]), - ("none", [False, True, True, False]), - ], -) -def test_is_between( - pyspark_constructor: Constructor, - closed: Literal["left", "right", "none", "both"], - expected: list[bool], -) -> None: - data = {"a": [1, 4, 2, 5]} - df = nw.from_native(pyspark_constructor(data)) - result = df.select(nw.col("a").is_between(1, 5, closed=closed)) - expected_dict = {"a": expected} - assert_equal_data(result, expected_dict) - - -# copied from tests/expr_and_series/is_duplicated_test.py -def test_is_duplicated(pyspark_constructor: 
-    data = {"a": [1, 1, 2, None], "b": [1, 2, None, None], "level_0": [0, 1, 2, 3]}
-    df = nw.from_native(pyspark_constructor(data))
-    result = df.select(
-        a=nw.col("a").is_duplicated(),
-        b=nw.col("b").is_duplicated(),
-        level_0=nw.col("level_0"),
-    ).sort("level_0")
-    expected = {
-        "a": [True, True, False, False],
-        "b": [False, False, True, True],
-        "level_0": [0, 1, 2, 3],
-    }
-    assert_equal_data(result, expected)
-
-
-# copied from tests/expr_and_series/is_finite_test.py
-def test_is_finite(pyspark_constructor: Constructor) -> None:
-    data = {"a": [float("nan"), float("inf"), 2.0, None]}
-    df = nw.from_native(pyspark_constructor(data))
-    result = df.select(finite=nw.col("a").is_finite())
-    expected = {"finite": [False, False, True, False]}
-    assert_equal_data(result, expected)
-
-
-def test_is_in(pyspark_constructor: Constructor) -> None:
-    data = {"a": [1, 2, 3, 4, 5]}
-    df = nw.from_native(pyspark_constructor(data))
-    result = df.select(in_list=nw.col("a").is_in([2, 4]))
-    expected = {"in_list": [False, True, False, True, False]}
-    assert_equal_data(result, expected)
-
-
-# copied from tests/expr_and_series/is_unique_test.py
-def test_is_unique(pyspark_constructor: Constructor) -> None:
-    data = {"a": [1, 1, 2, None], "b": [1, 2, None, None], "level_0": [0, 1, 2, 3]}
-    df = nw.from_native(pyspark_constructor(data))
-    result = df.select(
-        a=nw.col("a").is_unique(),
-        b=nw.col("b").is_unique(),
-        level_0=nw.col("level_0"),
-    ).sort("level_0")
-    expected = {
-        "a": [False, False, True, True],
-        "b": [True, True, False, False],
-        "level_0": [0, 1, 2, 3],
-    }
-    assert_equal_data(result, expected)
-
-
-def test_len(pyspark_constructor: Constructor) -> None:
-    data = {"a": [1, 2, float("nan"), 4, None], "b": [None, 3, None, 5, None]}
-    df = nw.from_native(pyspark_constructor(data))
-    result = df.select(
-        a=nw.col("a").len(),
-        b=nw.col("b").len(),
-    )
-    expected = {"a": [5], "b": [5]}
-    assert_equal_data(result, expected)
-
-
-# Copied from tests/expr_and_series/round_test.py
-@pytest.mark.parametrize("decimals", [0, 1, 2])
-def test_round(pyspark_constructor: Constructor, decimals: int) -> None:
-    data = {"a": [2.12345, 2.56789, 3.901234]}
-    df = nw.from_native(pyspark_constructor(data))
-
-    expected_data = {k: [round(e, decimals) for e in v] for k, v in data.items()}
-    result_frame = df.select(nw.col("a").round(decimals))
-    assert_equal_data(result_frame, expected_data)
-
-
-# copied from tests/expr_and_series/skew_test.py
-@pytest.mark.parametrize(
-    ("data", "expected"),
-    [
-        pytest.param(
-            [],
-            None,
-            marks=pytest.mark.skip(
-                reason="PySpark cannot infer schema from empty datasets"
-            ),
-        ),
-        ([1], None),
-        ([1, 2], 0.0),
-        ([0.0, 0.0, 0.0], None),
-        ([1, 2, 3, 2, 1], 0.343622),
-    ],
-)
-def test_skew(
-    pyspark_constructor: Constructor, data: list[float], expected: float | None
-) -> None:
-    df = nw.from_native(pyspark_constructor({"a": data}))
-    result = df.select(skew=nw.col("a").skew())
-    assert_equal_data(result, {"skew": [expected]})
-
-
-# copied from tests/expr_and_series/lit_test.py
-@pytest.mark.parametrize(
-    ("dtype", "expected_lit"),
-    [(None, [2, 2, 2]), (nw.String, ["2", "2", "2"]), (nw.Float32, [2.0, 2.0, 2.0])],
-)
-def test_lit(
-    pyspark_constructor: Constructor,
-    dtype: DType | None,
-    expected_lit: list[Any],
-    request: pytest.FixtureRequest,
-) -> None:
-    if dtype is not None:
-        request.applymarker(pytest.mark.xfail)
-    data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
-    df_raw = pyspark_constructor(data)
-    df = nw.from_native(df_raw).lazy()
-    result = df.with_columns(nw.lit(2, dtype).alias("lit"))
-    expected = {
-        "a": [1, 3, 2],
-        "b": [4, 4, 6],
-        "z": [7.0, 8.0, 9.0],
-        "lit": expected_lit,
-    }
-    assert_equal_data(result, expected)
-
-
-@pytest.mark.parametrize(
-    ("col_name", "expr", "expected_result"),
-    [
-        ("left_lit", nw.lit(1) + nw.col("a"), [2, 4, 3]),
-        ("right_lit", nw.col("a") + nw.lit(1), [2, 4, 3]),
-        ("left_lit_with_agg", nw.lit(1) + nw.col("a").mean(), [3]),
-        ("right_lit_with_agg", nw.col("a").mean() - nw.lit(1), [1]),
-        ("left_scalar", 1 + nw.col("a"), [2, 4, 3]),
-        ("right_scalar", nw.col("a") + 1, [2, 4, 3]),
-        ("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]),
-        ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]),
-    ],
-)
-def test_lit_operation(
-    pyspark_constructor: Constructor,
-    col_name: str,
-    expr: nw.Expr,
-    expected_result: list[int],
-    request: pytest.FixtureRequest,
-) -> None:
-    if col_name in (
-        "left_scalar_with_agg",
-        "left_lit_with_agg",
-        "right_lit",
-        "right_lit_with_agg",
-    ):
-        request.applymarker(pytest.mark.xfail)
-
-    data = {"a": [1, 3, 2]}
-    df_raw = pyspark_constructor(data)
-    df = nw.from_native(df_raw).lazy()
-    result = df.select(expr.alias(col_name))
-    expected = {col_name: expected_result}
-    assert_equal_data(result, expected)
diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py
index c3d028563..862c5966f 100644
--- a/tests/stable_api_test.py
+++ b/tests/stable_api_test.py
@@ -16,7 +16,7 @@
 def test_renamed_taxicab_norm(
     constructor: Constructor, request: pytest.FixtureRequest
 ) -> None:
-    if "duckdb" in str(constructor):
+    if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
         request.applymarker(pytest.mark.xfail)
     # Suppose we need to rename `_l1_norm` to `_taxicab_norm`.
     # We need `narwhals.stable.v1` to stay stable. So, we
@@ -46,10 +46,15 @@ def test_renamed_taxicab_norm(
     assert_equal_data(result, expected)
 
 
-def test_renamed_taxicab_norm_dataframe(constructor: Constructor) -> None:
+def test_renamed_taxicab_norm_dataframe(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
     # Suppose we have `DataFrame._l1_norm` in `stable.v1`, but remove it
     # in the main namespace. Here, we check that it's still usable from
     # the stable api.
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     def func(df_any: Any) -> Any:
         df = nw_v1.from_native(df_any)
         df = df._l1_norm()
@@ -60,10 +65,16 @@ def func(df_any: Any) -> Any:
     assert_equal_data(result, expected)
 
 
-def test_renamed_taxicab_norm_dataframe_narwhalify(constructor: Constructor) -> None:
+def test_renamed_taxicab_norm_dataframe_narwhalify(
+    request: pytest.FixtureRequest, constructor: Constructor
+) -> None:
     # Suppose we have `DataFrame._l1_norm` in `stable.v1`, but remove it
     # in the main namespace. Here, we check that it's still usable from
     # the stable api when using `narwhalify`.
+
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     @nw_v1.narwhalify
     def func(df: Any) -> Any:
         return df._l1_norm()
@@ -136,7 +147,10 @@ def test_series_docstrings() -> None:
         ), item
 
 
-def test_dtypes(constructor: Constructor) -> None:
+def test_dtypes(request: pytest.FixtureRequest, constructor: Constructor) -> None:
+    if "pyspark" in str(constructor):
+        request.applymarker(pytest.mark.xfail)
+
     df = nw_v1.from_native(
         constructor({"a": [1], "b": [datetime(2020, 1, 1)], "c": [timedelta(1)]})
     )
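For reference, the tests removed above all go through a `pyspark_constructor` fixture and an `assert_equal_data` helper defined elsewhere in the test suite (not shown in this patch). A minimal sketch of what such a fixture might look like is below; the fixture body, session options, and the pandas round-trip are illustrative assumptions, not the repository's actual conftest.

# Hypothetical sketch of a `pyspark_constructor`-style fixture; the real
# definition lives in the repo's conftest and may differ. Assumes pyspark
# and pandas are installed.
from __future__ import annotations

from typing import Any, Callable, Iterator

import pandas as pd
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def pyspark_constructor() -> Iterator[Callable[[dict[str, list[Any]]], Any]]:
    # A small local session is enough here; fewer shuffle partitions keep
    # the many tiny joins and aggregations in these tests fast.
    session = (
        SparkSession.builder.appName("narwhals-pyspark-tests")
        .master("local[1]")
        .config("spark.sql.shuffle.partitions", "2")
        .getOrCreate()
    )

    def _constructor(obj: dict[str, list[Any]]) -> Any:
        # Round-tripping through pandas turns None into proper nulls in the
        # resulting Spark DataFrame.
        return session.createDataFrame(pd.DataFrame(obj))

    yield _constructor
    session.stop()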