From 3a046ca5c9addf95a306c4e0c1cf2362a7fee56a Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Fri, 16 Aug 2024 19:53:46 +0200 Subject: [PATCH] Detect and test Spark Connect server (#247) Most features are not supported by PySpark with a Spark Connect server. This adds readable error message to the package and test capability to the CI. --- .github/actions/test-python/action.yml | 58 ++++++++- .github/workflows/test-python.yml | 7 + DIFF.md | 2 + HISTOGRAM.md | 2 + PARQUET.md | 4 + PYSPARK-DEPS.md | 4 + README.md | 28 ++-- ROW_NUMBER.md | 4 + python/gresearch/spark/__init__.py | 114 ++++++++++------ python/gresearch/spark/diff/__init__.py | 19 ++- python/gresearch/spark/parquet/__init__.py | 27 +++- python/test/requirements.txt | 5 +- python/test/spark_common.py | 5 +- python/test/test_diff.py | 2 + python/test/test_histogram.py | 3 + python/test/test_job_description.py | 3 + python/test/test_jvm.py | 143 +++++++++++++++++++++ python/test/test_package.py | 14 ++ python/test/test_parquet.py | 2 + python/test/test_row_number.py | 3 + 20 files changed, 384 insertions(+), 65 deletions(-) create mode 100644 python/test/test_jvm.py diff --git a/.github/actions/test-python/action.yml b/.github/actions/test-python/action.yml index 62436136..1a1a4c36 100644 --- a/.github/actions/test-python/action.yml +++ b/.github/actions/test-python/action.yml @@ -15,6 +15,9 @@ inputs: spark-compat-version: description: Spark compatibility version, e.g. 3.4 required: true + hadoop-version: + description: Hadoop version, e.g. 2.7 or 2 + required: true scala-compat-version: description: Scala compatibility version, e.g. 2.12 required: true @@ -40,6 +43,26 @@ runs: name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }} path: . + - name: Cache Spark Binaries + uses: actions/cache@v4 + if: inputs.scala-compat-version == '2.12' && ! contains(inputs.spark-version, '-SNAPSHOT') + with: + path: ~/spark + key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }} + + - name: Setup Spark Binaries + if: inputs.scala-compat-version == '2.12' && ! contains(inputs.spark-version, '-SNAPSHOT') + env: + SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz + run: | + if [[ ! -e ~/spark ]] + then + wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download" -O - | tar -xzC "${{ runner.temp }}" + archive=$(basename "${SPARK_PACKAGE}") bash -c "mv -v "${{ runner.temp }}/\${archive/%.tgz/}" ~/spark" + fi + echo "SPARK_BIN_HOME=$(cd ~/spark; pwd)" >> $GITHUB_ENV + shell: bash + - name: Cache Maven packages if: github.event_name != 'merge_group' uses: actions/cache@v4 @@ -105,6 +128,34 @@ runs: run: mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip shell: bash + - name: Start Spark Connect + id: spark-connect + if: (inputs.spark-compat-version == '3.4' || inputs.spark-compat-version == '3.5' || startsWith('4.', inputs.spark-compat-version)) && inputs.scala-compat-version == '2.12' && ! 
contains(inputs.spark-version, '-SNAPSHOT') + run: | + $SPARK_BIN_HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }} + shell: bash + + - name: Python Unit Tests (Spark Connect) + if: steps.spark-connect.outcome == 'success' + env: + PYTHONPATH: python:python/test + TEST_SPARK_CONNECT_SERVER: sc://localhost:15002 + run: | + pip install pyspark[connect] + python -m pytest python/test --junit-xml test-results-connect/pytest-$(date +%s.%N)-$RANDOM.xml + shell: bash + + - name: Stop Spark Connect + if: always() && steps.spark-connect.outcome == 'success' + run: | + $SPARK_BIN_HOME/sbin/stop-connect-server.sh + echo "::group::Spark Connect server log" + # thoughs started in $SPARK_BIN_HOME/sbin, logs go to $SPARK_HOME/logs + ls -lah $SPARK_HOME/logs || true + cat $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out || true + echo "::endgroup::" + shell: bash + - name: Python Integration Tests env: PYTHONPATH: python:python/test @@ -112,7 +163,7 @@ runs: find python/test -name 'test*.py' > tests while read test do - if ! $SPARK_HOME/bin/spark-submit --master "local[2]" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results + if ! $SPARK_HOME/bin/spark-submit --master "local[2]" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results-submit then state="fail" fi @@ -135,7 +186,10 @@ runs: uses: actions/upload-artifact@v4 with: name: Python Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }}) - path: test-results/*.xml + path: | + test-results/*.xml + test-results-submit/*.xml + test-results-connect/*.xml branding: icon: 'check-circle' diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 444e80f5..a83d5599 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -19,28 +19,34 @@ jobs: include: - spark-compat-version: '3.0' spark-version: '3.0.3' + hadoop-version: '2.7' scala-compat-version: '2.12' scala-version: '2.12.10' python-version: '3.8' - spark-compat-version: '3.1' spark-version: '3.1.3' + hadoop-version: '2.7' scala-compat-version: '2.12' scala-version: '2.12.10' python-version: '3.8' - spark-compat-version: '3.2' spark-version: '3.2.4' + hadoop-version: '2.7' scala-compat-version: '2.12' scala-version: '2.12.15' - spark-compat-version: '3.3' spark-version: '3.3.4' + hadoop-version: '3' scala-compat-version: '2.12' scala-version: '2.12.15' - spark-compat-version: '3.4' spark-version: '3.4.2' + hadoop-version: '3' scala-compat-version: '2.12' scala-version: '2.12.17' - spark-compat-version: '3.5' spark-version: '3.5.1' + hadoop-version: '3' scala-compat-version: '2.12' scala-version: '2.12.18' @@ -55,4 +61,5 @@ jobs: scala-version: ${{ matrix.scala-version }} spark-compat-version: ${{ matrix.spark-compat-version }} scala-compat-version: ${{ matrix.scala-compat-version }} + hadoop-version: ${{ matrix.hadoop-version }} python-version: ${{ matrix.python-version }} diff --git a/DIFF.md b/DIFF.md index c5674c76..5fd04fd3 100644 --- a/DIFF.md +++ b/DIFF.md @@ -404,6 +404,8 @@ The latter variant is prefixed with `_with_options`. 
* `def diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame` * `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame:` +Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). + ## Diff Spark application There is also a Spark application that can be used to create a diff DataFrame. The application reads two DataFrames diff --git a/HISTOGRAM.md b/HISTOGRAM.md index b8f2948d..8a383498 100644 --- a/HISTOGRAM.md +++ b/HISTOGRAM.md @@ -55,3 +55,5 @@ In Python, call: import gresearch.spark df.histogram([100, 200], 'user').orderBy('user') + +Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). diff --git a/PARQUET.md b/PARQUET.md index 75f5b0df..aad88310 100644 --- a/PARQUET.md +++ b/PARQUET.md @@ -254,3 +254,7 @@ spark.read.parquet_blocks("/path/to/parquet", parallelism=100) spark.read.parquet_block_columns("/path/to/parquet", parallelism=100) spark.read.parquet_partitions("/path/to/parquet", parallelism=100) ``` + +## Known Issues + +Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). diff --git a/PYSPARK-DEPS.md b/PYSPARK-DEPS.md index 28b227fc..bfc3fb6a 100644 --- a/PYSPARK-DEPS.md +++ b/PYSPARK-DEPS.md @@ -129,3 +129,7 @@ Finally, shutdown the example cluster: ```shell docker compose -f docker-compose.yml down ``` + +## Known Issues + +Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). diff --git a/README.md b/README.md index 27f24aef..8e0052f3 100644 --- a/README.md +++ b/README.md @@ -2,27 +2,27 @@ This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python: -**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between +**[Diff](DIFF.md) [[*]](#spark-connect-server):** A `diff` transformation and application for `Dataset`s that computes the differences between two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other. **[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing a **sorted** iterator for each group. Similar to `Dataset.groupByKey.flatMapGroups`, but with order guarantees for the iterator. -**[Histogram](HISTOGRAM.md):** A `histogram` transformation that computes the histogram DataFrame for a value column. +**[Histogram](HISTOGRAM.md) [[*]](#spark-connect-server):** A `histogram` transformation that computes the histogram DataFrame for a value column. -**[Global Row Number](ROW_NUMBER.md):** A `withRowNumbers` transformation that provides the global row number w.r.t. +**[Global Row Number](ROW_NUMBER.md) [[*]](#spark-connect-server):** A `withRowNumbers` transformation that provides the global row number w.r.t. the current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which requires a window spec, this transformation provides the row number across the entire Dataset without scaling problems. **[Partitioned Writing](PARTITIONING.md):** The `writePartitionedBy` action writes your `Dataset` partitioned and efficiently laid out with a single operation. 
-**[Inspect Parquet files](PARQUET.md):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)
+**[Inspect Parquet files](PARQUET.md) [[*]](#spark-connect-server):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similarly to [parquet-tools](https://pypi.org/project/parquet-tools/)
 or [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.
 This simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.
 
-**[Install Python packages into PySpark job](PYSPARK-DEPS.md):** Install Python dependencies via PIP or Poetry programatically into your running PySpark job (PySpark ≥ 3.1.0):
+**[Install Python packages into PySpark job](PYSPARK-DEPS.md) [[*]](#spark-connect-server):** Install Python dependencies via PIP or Poetry programmatically into your running PySpark job (PySpark ≥ 3.1.0):
 
 ```python
 # noinspection PyUnresolvedReferences
@@ -84,7 +84,7 @@ This is a handy way to ensure column names with special characters like dots (`.
 
 **Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.
 This is equivalent to calling `count(when(e.isNull, lit(1)))`.
 
-**.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
+**.Net DateTime.Ticks[[*]](#spark-connect-server):** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
Available methods: @@ -117,7 +117,7 @@ unix_epoch_nanos_to_dotnet_ticks(column_or_name) ```
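For illustration only (this sketch is not part of the patch): one of the tick conversions called from PySpark, assuming a classic (non-Connect) local session with the spark-extension package on the classpath, e.g. via `spark-submit --packages uk.co.gresearch.spark:spark-extension_2.12:<version>`. Against a Spark Connect server the same call raises the `RuntimeError` described in the new Known issues section of the README. The sample tick value is an assumption, not taken from the patch.

```python
from pyspark.sql import SparkSession

from gresearch.spark import dotnet_ticks_to_timestamp

# Classic (non-Connect) session; the spark-extension jar must be on the classpath.
spark = SparkSession.builder.master("local[2]").getOrCreate()

# Sample value: .Net DateTime.Ticks for 2023-01-01 00:00:00 UTC
# (ticks are 100 ns intervals since 0001-01-01).
df = spark.createDataFrame([(1, 638081280000000000)], ["id", "tick"])

df.select("id", dotnet_ticks_to_timestamp("tick").alias("timestamp")).show()
```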
-**Spark temporary directory**: Create a temporary directory that will be removed on Spark application shutdown. +**Spark temporary directory[[*]](#spark-connect-server)**: Create a temporary directory that will be removed on Spark application shutdown.
Examples: @@ -138,7 +138,7 @@ dir = spark.create_temporary_dir("prefix") ```
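As an aside (not part of the patch), a minimal usage sketch: the directory is created under the `SparkFiles` root and removed on Spark application shutdown, which makes it a convenient place for scratch output. A classic (non-Connect) session is assumed, since `create_temporary_dir` needs JVM access; the output path below is made up for illustration.

```python
from pyspark.sql import SparkSession

import gresearch.spark  # noqa: F401  attaches create_temporary_dir to SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Created under SparkFiles.getRootDirectory(), cleaned up when the application stops.
scratch_dir = spark.create_temporary_dir("scratch")
spark.range(10).write.parquet(f"{scratch_dir}/numbers.parquet")
```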
-**Spark job description:** Set Spark job description for all Spark jobs within a context. +**Spark job description[[*]](#spark-connect-server):** Set Spark job description for all Spark jobs within a context.
Examples: @@ -306,6 +306,18 @@ on a filesystem where it is accessible by the notebook, and reference that jar f Check the documentation of your favorite notebook to learn how to add jars to your Spark environment. +## Known issues +### Spark Connect Server + +Most features are not supported **in Python** in conjunction with a [Spark Connect server](https://spark.apache.org/docs/latest/spark-connect-overview.html). +This also holds for Databricks Runtime environment 13.x and above. Details can be found [in this blog](https://semyonsinchenko.github.io/ssinchenko/post/how-databricks-14x-breaks-3dparty-compatibility/). + +Calling any of those features when connected to a Spark Connect server will raise this error: + + This feature is not supported for Spark Connect. + +Use a classic connection to a Spark cluster instead. + ## Build You can build this project against different versions of Spark and Scala. diff --git a/ROW_NUMBER.md b/ROW_NUMBER.md index 02832367..2bd31f47 100644 --- a/ROW_NUMBER.md +++ b/ROW_NUMBER.md @@ -216,3 +216,7 @@ WindowExec: No Partition Defined for Window operation! Moving all data to a sing ``` This warning is unavoidable, because `withRowNumbers` has to pull information about the initial partitions into a single partition. Fortunately, there are only 12 Bytes per input partition required, so this amount of data usually fits into a single partition and the warning can safely be ignored. + +## Known issues + +Note that this feature is not supported in Python when connected with a [Spark Connect server](README.md#spark-connect-server). diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py index 15cb4bed..6ff299c5 100644 --- a/python/gresearch/spark/__init__.py +++ b/python/gresearch/spark/__init__.py @@ -27,17 +27,46 @@ from pyspark import __version__ from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, DataFrameReader, SQLContext from pyspark.sql.column import Column, _to_java_column from pyspark.sql.context import SQLContext from pyspark.sql.functions import col, count, lit, when from pyspark.sql.session import SparkSession from pyspark.storagelevel import StorageLevel +try: + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + has_connect = True +except ImportError: + has_connect = False + if TYPE_CHECKING: from pyspark.sql._typing import ColumnOrName +def _get_jvm(obj: Any) -> JVMView: + if obj is None: + if SparkContext._active_spark_context is None: + raise RuntimeError("This method must be called inside an active Spark session") + else: + raise ValueError("Cannot provide access to JVM from None") + + # helper method to assert the JVM is accessible and provide a useful error message + if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)): + raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. 
https://github.com/G-Research/spark-extension#spark-connect-server') + if isinstance(obj, DataFrame): + return _get_jvm(obj._sc) + if isinstance(obj, DataFrameReader): + return _get_jvm(obj._spark) + if isinstance(obj, SparkSession): + return _get_jvm(obj.sparkContext) + if isinstance(obj, (SparkContext, SQLContext)): + return obj._jvm + raise RuntimeError(f'Unsupported class: {type(obj)}') + + def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject: array = jvm.java.util.ArrayList(list) return jvm.scala.collection.JavaConverters.asScalaIteratorConverter(array.iterator()).asScala().toSeq() @@ -72,11 +101,8 @@ def dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column: if not isinstance(tick_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToTimestamp + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToTimestamp return Column(func(_to_java_column(tick_column))) @@ -105,11 +131,8 @@ def dotnet_ticks_to_unix_epoch(tick_column: Union[str, Column]) -> Column: if not isinstance(tick_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpoch + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpoch return Column(func(_to_java_column(tick_column))) @@ -138,11 +161,8 @@ def dotnet_ticks_to_unix_epoch_nanos(tick_column: Union[str, Column]) -> Column: if not isinstance(tick_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpochNanos + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpochNanos return Column(func(_to_java_column(tick_column))) @@ -170,11 +190,8 @@ def timestamp_to_dotnet_ticks(timestamp_column: Union[str, Column]) -> Column: if not isinstance(timestamp_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(timestamp_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").timestampToDotNetTicks + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").timestampToDotNetTicks return 
Column(func(_to_java_column(timestamp_column))) @@ -204,11 +221,8 @@ def unix_epoch_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column: if not isinstance(unix_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochToDotNetTicks + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochToDotNetTicks return Column(func(_to_java_column(unix_column))) @@ -239,11 +253,8 @@ def unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column: if not isinstance(unix_column, (str, Column)): raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}") - sc = SparkContext._active_spark_context - if sc is None or sc._jvm is None: - raise RuntimeError("This method must be called inside an active Spark session") - - func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochNanosToDotNetTicks + jvm = _get_jvm(SparkContext._active_spark_context) + func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochNanosToDotNetTicks return Column(func(_to_java_column(unix_column))) @@ -281,7 +292,7 @@ def histogram(self: DataFrame, else: raise ValueError('thresholds must be int or floats: {}'.format(t)) - jvm = self._sc._jvm + jvm = _get_jvm(self) col = jvm.org.apache.spark.sql.functions.col value_column = col(value_column) aggregate_columns = [col(column) for column in aggregate_columns] @@ -292,6 +303,8 @@ def histogram(self: DataFrame, DataFrame.histogram = histogram +if has_connect: + ConnectDataFrame.histogram = histogram class UnpersistHandle: @@ -307,7 +320,7 @@ def __call__(self, blocking: Optional[bool] = None): def unpersist_handle(self: SparkSession) -> UnpersistHandle: - jvm = self._sc._jvm + jvm = _get_jvm(self) handle = jvm.uk.co.gresearch.spark.UnpersistHandle() return UnpersistHandle(handle) @@ -321,7 +334,7 @@ def with_row_numbers(self: DataFrame, row_number_column_name: str = "row_number", order: Union[str, Column, List[Union[str, Column]]] = [], ascending: Union[bool, List[bool]] = True) -> DataFrame: - jvm = self._sc._jvm + jvm = _get_jvm(self) jsl = self._sc._getJavaStorageLevel(storage_level) juho = jvm.uk.co.gresearch.spark.UnpersistHandle juh = unpersist_handle._handle if unpersist_handle else juho.Noop() @@ -338,17 +351,23 @@ def with_row_numbers(self: DataFrame, return DataFrame(jdf, self.session_or_ctx()) +DataFrame.with_row_numbers = with_row_numbers +if has_connect: + ConnectDataFrame.with_row_numbers = with_row_numbers + + def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]: return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx -DataFrame.with_row_numbers = with_row_numbers DataFrame.session_or_ctx = session_or_ctx +if has_connect: + ConnectDataFrame.session_or_ctx = session_or_ctx def set_description(description: str, if_not_set: bool = False): context = SparkContext._active_spark_context - jvm = context._jvm + jvm = _get_jvm(context) spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$") return spark_package.setJobDescription(description, if_not_set, 
context._jsc.sc()) @@ -384,7 +403,7 @@ def job_description(description: str, if_not_set: bool = False): def append_description(extra_description: str, separator: str = " - "): context = SparkContext._active_spark_context - jvm = context._jvm + jvm = _get_jvm(context) spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$") return spark_package.appendJobDescription(extra_description, separator, context._jsc.sc()) @@ -424,21 +443,25 @@ def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str) :param prefix: prefix string of temporary directory name :return: absolute path of temporary directory """ - if isinstance(spark, SparkSession): - spark = spark.sparkContext - - root_dir = spark._jvm.org.apache.spark.SparkFiles.getRootDirectory() + jvm = _get_jvm(spark) + root_dir = jvm.org.apache.spark.SparkFiles.getRootDirectory() return tempfile.mkdtemp(prefix=prefix, dir=root_dir) SparkSession.create_temporary_dir = create_temporary_dir SparkContext.create_temporary_dir = create_temporary_dir +if has_connect: + ConnectSparkSession.create_temporary_dir = create_temporary_dir + def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None: if __version__.startswith('2.') or __version__.startswith('3.0.'): raise NotImplementedError(f'Not supported for PySpark __version__') + # just here to assert JVM is accessible + _get_jvm(spark) + if isinstance(spark, SparkSession): spark = spark.sparkContext @@ -465,6 +488,9 @@ def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pi SparkSession.install_pip_package = install_pip_package SparkContext.install_pip_package = install_pip_package +if has_connect: + ConnectSparkSession.install_pip_package = install_pip_package + def install_poetry_project(spark: Union[SparkSession, SparkContext], *project: str, @@ -478,6 +504,9 @@ def install_poetry_project(spark: Union[SparkSession, SparkContext], if __version__.startswith('2.') or __version__.startswith('3.0.'): raise NotImplementedError(f'Not supported for PySpark __version__') + # just here to assert JVM is accessible + _get_jvm(spark) + if isinstance(spark, SparkSession): spark = spark.sparkContext if poetry_python is None: @@ -538,3 +567,6 @@ def build_wheel(project: Path) -> Path: SparkSession.install_poetry_project = install_poetry_project SparkContext.install_poetry_project = install_poetry_project + +if has_connect: + ConnectSparkSession.install_poetry_project = install_poetry_project diff --git a/python/gresearch/spark/diff/__init__.py b/python/gresearch/spark/diff/__init__.py index 7776a743..351dc587 100644 --- a/python/gresearch/spark/diff/__init__.py +++ b/python/gresearch/spark/diff/__init__.py @@ -20,9 +20,15 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType -from gresearch.spark import _to_seq, _to_map +from gresearch.spark import _get_jvm, _to_seq, _to_map from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator +try: + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + has_connect = True +except ImportError: + has_connect = False + class DiffMode(Enum): ColumnByColumn = "ColumnByColumn" @@ -333,7 +339,7 @@ def diff(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame :return: the diff DataFrame :rtype DataFrame """ - jvm = left._sc._jvm + jvm = _get_jvm(left) jdiffer = self._to_java(jvm) jdf = jdiffer.diff(left._jdf, right._jdf, _to_seq(jvm, list(id_columns))) return 
DataFrame(jdf, left.session_or_ctx()) @@ -354,7 +360,7 @@ def diffwith(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataF :return: the diff DataFrame :rtype DataFrame """ - jvm = left._sc._jvm + jvm = _get_jvm(left) jdiffer = self._to_java(jvm) jdf = jdiffer.diffWith(left._jdf, right._jdf, _to_seq(jvm, list(id_columns))) df = DataFrame(jdf, left.sql_ctx) @@ -489,6 +495,11 @@ def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOption DataFrame.diff = diff DataFrame.diffwith = diffwith - DataFrame.diff_with_options = diff_with_options DataFrame.diffwith_with_options = diffwith_with_options + +if has_connect: + ConnectDataFrame.diff = diff + ConnectDataFrame.diffwith = diffwith + ConnectDataFrame.diff_with_options = diff_with_options + ConnectDataFrame.diffwith_with_options = diffwith_with_options diff --git a/python/gresearch/spark/parquet/__init__.py b/python/gresearch/spark/parquet/__init__.py index f6dd137f..890f22f0 100644 --- a/python/gresearch/spark/parquet/__init__.py +++ b/python/gresearch/spark/parquet/__init__.py @@ -17,11 +17,17 @@ from py4j.java_gateway import JavaObject from pyspark.sql import DataFrameReader, DataFrame -from gresearch.spark import _to_seq +from gresearch.spark import _get_jvm, _to_seq + +try: + from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader + has_connect = True +except ImportError: + has_connect = False def _jreader(reader: DataFrameReader) -> JavaObject: - jvm = reader._spark._jvm + jvm = _get_jvm(reader) return jvm.uk.co.gresearch.spark.parquet.__getattr__("package$").__getattr__("MODULE$").ExtendedDataFrameReader(reader._jreader) @@ -52,7 +58,7 @@ def parquet_metadata(self: DataFrameReader, *paths: str, parallelism: Optional[i :param parallelism: number of partitions of returned DataFrame :return: dataframe with Parquet metadata """ - jvm = self._spark._jvm + jvm = _get_jvm(self) if parallelism is None: jdf = _jreader(self).parquetMetadata(_to_seq(jvm, list(paths))) else: @@ -87,7 +93,7 @@ def parquet_schema(self: DataFrameReader, *paths: str, parallelism: Optional[int :param parallelism: number of partitions of returned DataFrame :return: dataframe with Parquet metadata """ - jvm = self._spark._jvm + jvm = _get_jvm(self) if parallelism is None: jdf = _jreader(self).parquetSchema(_to_seq(jvm, list(paths))) else: @@ -119,7 +125,7 @@ def parquet_blocks(self: DataFrameReader, *paths: str, parallelism: Optional[int :param parallelism: number of partitions of returned DataFrame :return: dataframe with Parquet metadata """ - jvm = self._spark._jvm + jvm = _get_jvm(self) if parallelism is None: jdf = _jreader(self).parquetBlocks(_to_seq(jvm, list(paths))) else: @@ -155,7 +161,7 @@ def parquet_block_columns(self: DataFrameReader, *paths: str, parallelism: Optio :param parallelism: number of partitions of returned DataFrame :return: dataframe with Parquet metadata """ - jvm = self._spark._jvm + jvm = _get_jvm(self) if parallelism is None: jdf = _jreader(self).parquetBlockColumns(_to_seq(jvm, list(paths))) else: @@ -191,7 +197,7 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional :param parallelism: number of partitions of returned DataFrame :return: dataframe with Parquet metadata """ - jvm = self._spark._jvm + jvm = _get_jvm(self) if parallelism is None: jdf = _jreader(self).parquetPartitions(_to_seq(jvm, list(paths))) else: @@ -204,3 +210,10 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional DataFrameReader.parquet_blocks = 
parquet_blocks DataFrameReader.parquet_block_columns = parquet_block_columns DataFrameReader.parquet_partitions = parquet_partitions + +if has_connect: + ConnectDataFrameReader.parquet_metadata = parquet_metadata + ConnectDataFrameReader.parquet_schema = parquet_schema + ConnectDataFrameReader.parquet_blocks = parquet_blocks + ConnectDataFrameReader.parquet_block_columns = parquet_block_columns + ConnectDataFrameReader.parquet_partitions = parquet_partitions diff --git a/python/test/requirements.txt b/python/test/requirements.txt index 7056a990..cefc8f91 100644 --- a/python/test/requirements.txt +++ b/python/test/requirements.txt @@ -1,4 +1,5 @@ -pandas -pyarrow +grpcio>=1.48.1 +pandas>=1.0.5 +pyarrow>=4.0.0 pytest unittest-xml-reporting diff --git a/python/test/spark_common.py b/python/test/spark_common.py index 019cce45..76735beb 100644 --- a/python/test/spark_common.py +++ b/python/test/spark_common.py @@ -80,7 +80,9 @@ def get_spark_config(path) -> SparkConf: def get_spark_session(cls) -> SparkSession: builder = SparkSession.builder - if 'PYSPARK_GATEWAY_PORT' in os.environ: + if 'TEST_SPARK_CONNECT_SERVER' in os.environ: + builder.remote(os.environ['TEST_SPARK_CONNECT_SERVER']) + elif 'PYSPARK_GATEWAY_PORT' in os.environ: logging.info('Running inside existing Spark environment') else: logging.info('Setting up Spark environment') @@ -91,6 +93,7 @@ def get_spark_session(cls) -> SparkSession: return builder.getOrCreate() spark: SparkSession = None + is_spark_connect: bool = 'TEST_SPARK_CONNECT_SERVER' in os.environ @classmethod def setUpClass(cls): diff --git a/python/test/test_diff.py b/python/test/test_diff.py index f34607a9..328612e8 100644 --- a/python/test/test_diff.py +++ b/python/test/test_diff.py @@ -18,11 +18,13 @@ from pyspark.sql import Row from pyspark.sql.functions import col, when from pyspark.sql.types import IntegerType, LongType, StringType, DateType +from unittest import skipIf from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators from spark_common import SparkTest +@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Diff") class DiffTest(SparkTest): expected_diff = None diff --git a/python/test/test_histogram.py b/python/test/test_histogram.py index 0f58d848..01f92196 100644 --- a/python/test/test_histogram.py +++ b/python/test/test_histogram.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import skipIf + from spark_common import SparkTest import gresearch.spark +@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Historgam") class HistogramTest(SparkTest): @classmethod diff --git a/python/test/test_job_description.py b/python/test/test_job_description.py index 15630a43..df72faee 100644 --- a/python/test/test_job_description.py +++ b/python/test/test_job_description.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest import skipIf + from pyspark import TaskContext, SparkContext from typing import Optional @@ -19,6 +21,7 @@ from gresearch.spark import job_description, append_job_description +@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by JobDescription") class JobDescriptionTest(SparkTest): def _assert_job_description(self, expected: Optional[str]): diff --git a/python/test/test_jvm.py b/python/test/test_jvm.py new file mode 100644 index 00000000..45ddc968 --- /dev/null +++ b/python/test/test_jvm.py @@ -0,0 +1,143 @@ +# Copyright 2024 G-Research +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import skipIf, skipUnless + +from pyspark.sql.functions import sum + +from gresearch.spark import _get_jvm, dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \ + timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, histogram, job_description, append_description +from gresearch.spark.diff import * +from gresearch.spark.parquet import * +from spark_common import SparkTest + +EXPECTED_UNSUPPORTED_MESSAGE = "This feature is not supported for Spark Connect. Please use a classic Spark client. " \ + "https://github.com/G-Research/spark-extension#spark-connect-server" + + +class PackageTest(SparkTest): + df = None + + @classmethod + def setUpClass(cls): + super(PackageTest, cls).setUpClass() + cls.df = cls.spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"]) + + @skipIf(SparkTest.is_spark_connect, "Spark classic client tests") + def test_get_jvm_classic(self): + for obj in [self.spark, self.spark.sparkContext, self.df, self.spark.read]: + with self.subTest(type(obj).__name__): + self.assertIsNotNone(_get_jvm(obj)) + + with self.subTest("Unsupported"): + with self.assertRaises(RuntimeError) as e: + _get_jvm(object()) + self.assertEqual(("Unsupported class: ", ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_get_jvm_connect(self): + for obj in [self.spark, self.df, self.spark.read]: + with self.subTest(type(obj).__name__): + with self.assertRaises(RuntimeError) as e: + _get_jvm(obj) + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + with self.subTest("Unsupported"): + with self.assertRaises(RuntimeError) as e: + _get_jvm(object()) + self.assertEqual(("Unsupported class: ", ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_diff(self): + for label, func in { + 'diff': lambda: self.df.diff(self.df), + 'diff_with_options': lambda: self.df.diff_with_options(self.df, DiffOptions()), + 'diffwith': lambda: self.df.diffwith(self.df), + 'diffwith_with_options': lambda: self.df.diffwith_with_options(self.df, DiffOptions()), + }.items(): + with self.subTest(label): + with self.assertRaises(RuntimeError) as e: + func() + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + 
@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_dotnet_ticks(self): + for label, func in { + 'dotnet_ticks_to_timestamp': dotnet_ticks_to_timestamp, + 'dotnet_ticks_to_unix_epoch': dotnet_ticks_to_unix_epoch, + 'dotnet_ticks_to_unix_epoch_nanos': dotnet_ticks_to_unix_epoch_nanos, + 'timestamp_to_dotnet_ticks': timestamp_to_dotnet_ticks, + 'unix_epoch_to_dotnet_ticks': unix_epoch_to_dotnet_ticks, + 'unix_epoch_nanos_to_dotnet_ticks': unix_epoch_nanos_to_dotnet_ticks, + }.items(): + with self.subTest(label): + with self.assertRaises(RuntimeError) as e: + func("id") + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_histogram(self): + with self.assertRaises(RuntimeError) as e: + self.df.histogram([1, 10, 100], "bin", sum) + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_with_row_numbers(self): + with self.assertRaises(RuntimeError) as e: + self.df.with_row_numbers() + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_job_description(self): + with self.assertRaises(RuntimeError) as e: + with job_description("job description"): + pass + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) + + with self.assertRaises(RuntimeError) as e: + with append_description("job description"): + pass + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_create_temp_dir(self): + with self.assertRaises(RuntimeError) as e: + self.spark.create_temporary_dir("prefix") + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_install_pip_package(self): + with self.assertRaises(RuntimeError) as e: + self.spark.install_pip_package("pytest") + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_install_poetry_project(self): + with self.assertRaises(RuntimeError) as e: + self.spark.install_poetry_project("./poetry-project") + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + + @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") + def test_parquet(self): + for label, func in { + 'parquet_metadata': lambda dr: dr.parquet_metadata("file.parquet"), + 'parquet_schema': lambda dr: dr.parquet_schema("file.parquet"), + 'parquet_blocks': lambda dr: dr.parquet_blocks("file.parquet"), + 'parquet_block_columns': lambda dr: dr.parquet_block_columns("file.parquet"), + 'parquet_partitions': lambda dr: dr.parquet_partitions("file.parquet"), + }.items(): + with self.subTest(label): + with self.assertRaises(RuntimeError) as e: + func(self.spark.read) + self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) diff --git a/python/test/test_package.py b/python/test/test_package.py index 2a496fd3..96d46c06 100644 --- a/python/test/test_package.py +++ b/python/test/test_package.py @@ -105,6 +105,7 @@ def compare_dfs(self, expected, actual): [row.asDict() for row in expected.collect()] ) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does 
not provide access to the JVM, required by dotnet ticks") def test_dotnet_ticks_to_timestamp(self): for column in ["tick", self.ticks.tick]: with self.subTest(column=column): @@ -112,6 +113,7 @@ def test_dotnet_ticks_to_timestamp(self): expected = self.ticks.join(self.timestamps, "id").orderBy('id') self.compare_dfs(expected, timestamps) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_dotnet_ticks_to_unix_epoch(self): for column in ["tick", self.ticks.tick]: with self.subTest(column=column): @@ -119,6 +121,7 @@ def test_dotnet_ticks_to_unix_epoch(self): expected = self.ticks.join(self.unix, "id").orderBy('id') self.compare_dfs(expected, timestamps) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_dotnet_ticks_to_unix_epoch_nanos(self): self.maxDiff = None for column in ["tick", self.ticks.tick]: @@ -127,6 +130,7 @@ def test_dotnet_ticks_to_unix_epoch_nanos(self): expected = self.ticks.join(self.unix_nanos, "id").orderBy('id') self.compare_dfs(expected, timestamps) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_timestamp_to_dotnet_ticks(self): if self.spark.version.startswith('3.0.'): self.skipTest('timestamp_to_dotnet_ticks not supported by Spark 3.0') @@ -136,6 +140,7 @@ def test_timestamp_to_dotnet_ticks(self): expected = self.timestamps.join(self.ticks_from_timestamp, "id").orderBy('id') self.compare_dfs(expected, timestamps) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_unix_epoch_dotnet_ticks(self): for column in ["unix", self.unix.unix]: with self.subTest(column=column): @@ -143,6 +148,7 @@ def test_unix_epoch_dotnet_ticks(self): expected = self.unix.join(self.ticks, "id").orderBy('id') self.compare_dfs(expected, timestamps) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks") def test_unix_epoch_nanos_to_dotnet_ticks(self): for column in ["unix_nanos", self.unix_nanos.unix_nanos]: with self.subTest(column=column): @@ -159,6 +165,7 @@ def test_count_null(self): ).collect() self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual) + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by create_temp_dir") def test_create_temp_dir(self): from pyspark import SparkFiles @@ -166,6 +173,7 @@ def test_create_temp_dir(self): self.assertTrue(dir.startswith(SparkFiles.getRootDirectory())) @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package") def test_install_pip_package(self): self.spark.sparkContext.setLogLevel("INFO") with self.assertRaises(ImportError): @@ -189,16 +197,19 @@ def test_install_pip_package(self): self.assertEqual(expected, actual) @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package") def test_install_pip_package_unknown_argument(self): with self.assertRaises(CalledProcessError): self.spark.install_pip_package("--unknown", "argument") @skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0') + 
@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package") def test_install_pip_package_package_not_found(self): with self.assertRaises(CalledProcessError): self.spark.install_pip_package("pyspark-extension==abc") @skipUnless(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package") def test_install_pip_package_not_supported(self): with self.assertRaises(NotImplementedError): self.spark.install_pip_package("emoji") @@ -209,6 +220,7 @@ def test_install_pip_package_not_supported(self): f'virtual env python with poetry required') @skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to ' f'rich project sources required') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project") def test_install_poetry_project(self): self.spark.sparkContext.setLogLevel("INFO") with self.assertRaises(ImportError): @@ -244,6 +256,7 @@ def test_install_poetry_project(self): f'virtual env python with poetry required') @skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to ' f'rich project sources required') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project") def test_install_poetry_project_wrong_arguments(self): rich_path = os.environ[RICH_SOURCES_ENV] poetry_python = os.environ[POETRY_PYTHON_ENV] @@ -254,6 +267,7 @@ def test_install_poetry_project_wrong_arguments(self): self.spark.install_poetry_project(rich_path, poetry_python="non-existing-python") @skipUnless(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0') + @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project") def test_install_poetry_project_not_supported(self): with self.assertRaises(NotImplementedError): self.spark.install_poetry_project("./rich") diff --git a/python/test/test_parquet.py b/python/test/test_parquet.py index f81f6215..5e12acaa 100644 --- a/python/test/test_parquet.py +++ b/python/test/test_parquet.py @@ -13,11 +13,13 @@ # limitations under the License. from pathlib import Path +from unittest import skipIf from spark_common import SparkTest import gresearch.spark.parquet +@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Parquet") class ParquetTest(SparkTest): test_file = str((Path(__file__).parent.parent.parent / "src" / "test" / "files" / "test.parquet").resolve()) diff --git a/python/test/test_row_number.py b/python/test/test_row_number.py index 8a99f937..16d620a0 100644 --- a/python/test/test_row_number.py +++ b/python/test/test_row_number.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import skipIf + from pyspark.storagelevel import StorageLevel from spark_common import SparkTest import gresearch.spark +@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by RowNumber") class RowNumberTest(SparkTest): @classmethod