diff --git a/.github/actions/test-python/action.yml b/.github/actions/test-python/action.yml
index 62436136..1a1a4c36 100644
--- a/.github/actions/test-python/action.yml
+++ b/.github/actions/test-python/action.yml
@@ -15,6 +15,9 @@ inputs:
spark-compat-version:
description: Spark compatibility version, e.g. 3.4
required: true
+ hadoop-version:
+    description: Hadoop version, e.g. 2.7 or 3
+ required: true
scala-compat-version:
description: Scala compatibility version, e.g. 2.12
required: true
@@ -40,6 +43,26 @@ runs:
name: Binaries-${{ inputs.spark-compat-version }}-${{ inputs.scala-compat-version }}
path: .
+ - name: Cache Spark Binaries
+ uses: actions/cache@v4
+ if: inputs.scala-compat-version == '2.12' && ! contains(inputs.spark-version, '-SNAPSHOT')
+ with:
+ path: ~/spark
+ key: ${{ runner.os }}-spark-binaries-${{ inputs.spark-version }}-${{ inputs.scala-compat-version }}
+
+ - name: Setup Spark Binaries
+ if: inputs.scala-compat-version == '2.12' && ! contains(inputs.spark-version, '-SNAPSHOT')
+ env:
+ SPARK_PACKAGE: spark-${{ inputs.spark-version }}/spark-${{ inputs.spark-version }}-bin-hadoop${{ inputs.hadoop-version }}${{ inputs.scala-compat-version == '2.13' && '-scala2.13' || '' }}.tgz
+ run: |
+ if [[ ! -e ~/spark ]]
+ then
+ wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/${SPARK_PACKAGE}?action=download" -O - | tar -xzC "${{ runner.temp }}"
+        archive=$(basename "${SPARK_PACKAGE}") && mv -v "${{ runner.temp }}/${archive%.tgz}" ~/spark
+ fi
+ echo "SPARK_BIN_HOME=$(cd ~/spark; pwd)" >> $GITHUB_ENV
+ shell: bash
+
- name: Cache Maven packages
if: github.event_name != 'merge_group'
uses: actions/cache@v4
@@ -105,6 +128,34 @@ runs:
run: mvn --batch-mode --update-snapshots install -Dspotless.check.skip -DskipTests -Dmaven.test.skip=true -Dgpg.skip
shell: bash
+ - name: Start Spark Connect
+ id: spark-connect
+    if: (inputs.spark-compat-version == '3.4' || inputs.spark-compat-version == '3.5' || startsWith(inputs.spark-compat-version, '4.')) && inputs.scala-compat-version == '2.12' && ! contains(inputs.spark-version, '-SNAPSHOT')
+ run: |
+ $SPARK_BIN_HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_${{ inputs.scala-compat-version }}:${{ inputs.spark-version }}
+ shell: bash
+
+ - name: Python Unit Tests (Spark Connect)
+ if: steps.spark-connect.outcome == 'success'
+ env:
+ PYTHONPATH: python:python/test
+ TEST_SPARK_CONNECT_SERVER: sc://localhost:15002
+ run: |
+ pip install pyspark[connect]
+ python -m pytest python/test --junit-xml test-results-connect/pytest-$(date +%s.%N)-$RANDOM.xml
+ shell: bash
+
+ - name: Stop Spark Connect
+ if: always() && steps.spark-connect.outcome == 'success'
+ run: |
+ $SPARK_BIN_HOME/sbin/stop-connect-server.sh
+ echo "::group::Spark Connect server log"
+      # though started from $SPARK_BIN_HOME/sbin, logs go to $SPARK_HOME/logs
+ ls -lah $SPARK_HOME/logs || true
+ cat $SPARK_HOME/logs/spark-*-org.apache.spark.sql.connect.service.SparkConnectServer-*.out || true
+ echo "::endgroup::"
+ shell: bash
+
- name: Python Integration Tests
env:
PYTHONPATH: python:python/test
@@ -112,7 +163,7 @@ runs:
find python/test -name 'test*.py' > tests
while read test
do
- if ! $SPARK_HOME/bin/spark-submit --master "local[2]" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results
+ if ! $SPARK_HOME/bin/spark-submit --master "local[2]" --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results-submit
then
state="fail"
fi
@@ -135,7 +186,10 @@ runs:
uses: actions/upload-artifact@v4
with:
name: Python Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})
- path: test-results/*.xml
+ path: |
+ test-results/*.xml
+ test-results-submit/*.xml
+ test-results-connect/*.xml
branding:
icon: 'check-circle'
diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml
index 444e80f5..a83d5599 100644
--- a/.github/workflows/test-python.yml
+++ b/.github/workflows/test-python.yml
@@ -19,28 +19,34 @@ jobs:
include:
- spark-compat-version: '3.0'
spark-version: '3.0.3'
+ hadoop-version: '2.7'
scala-compat-version: '2.12'
scala-version: '2.12.10'
python-version: '3.8'
- spark-compat-version: '3.1'
spark-version: '3.1.3'
+ hadoop-version: '2.7'
scala-compat-version: '2.12'
scala-version: '2.12.10'
python-version: '3.8'
- spark-compat-version: '3.2'
spark-version: '3.2.4'
+ hadoop-version: '2.7'
scala-compat-version: '2.12'
scala-version: '2.12.15'
- spark-compat-version: '3.3'
spark-version: '3.3.4'
+ hadoop-version: '3'
scala-compat-version: '2.12'
scala-version: '2.12.15'
- spark-compat-version: '3.4'
spark-version: '3.4.2'
+ hadoop-version: '3'
scala-compat-version: '2.12'
scala-version: '2.12.17'
- spark-compat-version: '3.5'
spark-version: '3.5.1'
+ hadoop-version: '3'
scala-compat-version: '2.12'
scala-version: '2.12.18'
@@ -55,4 +61,5 @@ jobs:
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}
scala-compat-version: ${{ matrix.scala-compat-version }}
+ hadoop-version: ${{ matrix.hadoop-version }}
python-version: ${{ matrix.python-version }}
diff --git a/DIFF.md b/DIFF.md
index c5674c76..5fd04fd3 100644
--- a/DIFF.md
+++ b/DIFF.md
@@ -404,6 +404,8 @@ The latter variant is prefixed with `_with_options`.
* `def diff(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame`
* `def diffwith(self: DataFrame, other: DataFrame, *id_columns: str) -> DataFrame:`
+Note that this feature is not supported in Python when connected to a [Spark Connect server](README.md#spark-connect-server).
+
## Diff Spark application
There is also a Spark application that can be used to create a diff DataFrame. The application reads two DataFrames
diff --git a/HISTOGRAM.md b/HISTOGRAM.md
index b8f2948d..8a383498 100644
--- a/HISTOGRAM.md
+++ b/HISTOGRAM.md
@@ -55,3 +55,5 @@ In Python, call:
import gresearch.spark
df.histogram([100, 200], 'user').orderBy('user')
+
+Note that this feature is not supported in Python when connected to a [Spark Connect server](README.md#spark-connect-server).
diff --git a/PARQUET.md b/PARQUET.md
index 75f5b0df..aad88310 100644
--- a/PARQUET.md
+++ b/PARQUET.md
@@ -254,3 +254,7 @@ spark.read.parquet_blocks("/path/to/parquet", parallelism=100)
spark.read.parquet_block_columns("/path/to/parquet", parallelism=100)
spark.read.parquet_partitions("/path/to/parquet", parallelism=100)
```
+
+## Known issues
+
+Note that this feature is not supported in Python when connected to a [Spark Connect server](README.md#spark-connect-server).
diff --git a/PYSPARK-DEPS.md b/PYSPARK-DEPS.md
index 28b227fc..bfc3fb6a 100644
--- a/PYSPARK-DEPS.md
+++ b/PYSPARK-DEPS.md
@@ -129,3 +129,7 @@ Finally, shutdown the example cluster:
```shell
docker compose -f docker-compose.yml down
```
+
+## Known issues
+
+Note that this feature is not supported in Python when connected to a [Spark Connect server](README.md#spark-connect-server).
diff --git a/README.md b/README.md
index 27f24aef..8e0052f3 100644
--- a/README.md
+++ b/README.md
@@ -2,27 +2,27 @@
This project provides extensions to the [Apache Spark project](https://spark.apache.org/) in Scala and Python:
-**[Diff](DIFF.md):** A `diff` transformation and application for `Dataset`s that computes the differences between
+**[Diff](DIFF.md) [[*]](#spark-connect-server):** A `diff` transformation and application for `Dataset`s that computes the differences between
two datasets, i.e. which rows to _add_, _delete_ or _change_ to get from one dataset to the other.
**[SortedGroups](GROUPS.md):** A `groupByKey` transformation that groups rows by a key while providing
a **sorted** iterator for each group. Similar to `Dataset.groupByKey.flatMapGroups`, but with order guarantees
for the iterator.
-**[Histogram](HISTOGRAM.md):** A `histogram` transformation that computes the histogram DataFrame for a value column.
+**[Histogram](HISTOGRAM.md) [[*]](#spark-connect-server):** A `histogram` transformation that computes the histogram DataFrame for a value column.
-**[Global Row Number](ROW_NUMBER.md):** A `withRowNumbers` transformation that provides the global row number w.r.t.
+**[Global Row Number](ROW_NUMBER.md) [[*]](#spark-connect-server):** A `withRowNumbers` transformation that provides the global row number w.r.t.
the current order of the Dataset, or any given order. In contrast to the existing SQL function `row_number`, which
requires a window spec, this transformation provides the row number across the entire Dataset without scaling problems.
**[Partitioned Writing](PARTITIONING.md):** The `writePartitionedBy` action writes your `Dataset` partitioned and
efficiently laid out with a single operation.
-**[Inspect Parquet files](PARQUET.md):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)
+**[Inspect Parquet files](PARQUET.md) [[*]](#spark-connect-server):** The structure of Parquet files (the metadata, not the data stored in Parquet) can be inspected similar to [parquet-tools](https://pypi.org/project/parquet-tools/)
or [parquet-cli](https://pypi.org/project/parquet-cli/) by reading from a simple Spark data source.
This simplifies identifying why some Parquet files cannot be split by Spark into scalable partitions.
-**[Install Python packages into PySpark job](PYSPARK-DEPS.md):** Install Python dependencies via PIP or Poetry programatically into your running PySpark job (PySpark ≥ 3.1.0):
+**[Install Python packages into PySpark job](PYSPARK-DEPS.md) [[*]](#spark-connect-server):** Install Python dependencies via PIP or Poetry programmatically into your running PySpark job (PySpark ≥ 3.1.0):
```python
# noinspection PyUnresolvedReferences
@@ -84,7 +84,7 @@ This is a handy way to ensure column names with special characters like dots (`.
**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.
This is equivalent to calling `count(when(e.isNull, lit(1)))`.
-**.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
+**.Net DateTime.Ticks[[*]](#spark-connect-server):** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
Available methods:
@@ -117,7 +117,7 @@ unix_epoch_nanos_to_dotnet_ticks(column_or_name)
```
-**Spark temporary directory**: Create a temporary directory that will be removed on Spark application shutdown.
+**Spark temporary directory[[*]](#spark-connect-server)**: Create a temporary directory that will be removed on Spark application shutdown.
Examples:
@@ -138,7 +138,7 @@ dir = spark.create_temporary_dir("prefix")
```
-**Spark job description:** Set Spark job description for all Spark jobs within a context.
+**Spark job description[[*]](#spark-connect-server):** Set Spark job description for all Spark jobs within a context.
Examples:
@@ -306,6 +306,18 @@ on a filesystem where it is accessible by the notebook, and reference that jar f
Check the documentation of your favorite notebook to learn how to add jars to your Spark environment.
+## Known issues
+### Spark Connect Server
+
+Most features are not supported **in Python** in conjunction with a [Spark Connect server](https://spark.apache.org/docs/latest/spark-connect-overview.html).
+This also holds for Databricks Runtime 13.x and above. Details can be found [in this blog post](https://semyonsinchenko.github.io/ssinchenko/post/how-databricks-14x-breaks-3dparty-compatibility/).
+
+Calling any of those features when connected to a Spark Connect server will raise this error:
+
+ This feature is not supported for Spark Connect.
+
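+For example (a minimal sketch; the connect address and the Parquet path are placeholders):
+
+```python
+from pyspark.sql import SparkSession
+import gresearch.spark.parquet  # registers the Parquet metadata methods on DataFrameReader
+
+spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
+spark.read.parquet_metadata("/path/to/parquet")
+# RuntimeError: This feature is not supported for Spark Connect. Please use a classic Spark client.
+```
+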
+Use a classic connection to a Spark cluster instead.
+
## Build
You can build this project against different versions of Spark and Scala.
diff --git a/ROW_NUMBER.md b/ROW_NUMBER.md
index 02832367..2bd31f47 100644
--- a/ROW_NUMBER.md
+++ b/ROW_NUMBER.md
@@ -216,3 +216,7 @@ WindowExec: No Partition Defined for Window operation! Moving all data to a sing
```
This warning is unavoidable, because `withRowNumbers` has to pull information about the initial partitions into a single partition.
Fortunately, there are only 12 Bytes per input partition required, so this amount of data usually fits into a single partition and the warning can safely be ignored.
+
+## Known issues
+
+Note that this feature is not supported in Python when connected to a [Spark Connect server](README.md#spark-connect-server).
diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py
index 15cb4bed..6ff299c5 100644
--- a/python/gresearch/spark/__init__.py
+++ b/python/gresearch/spark/__init__.py
@@ -27,17 +27,46 @@
from pyspark import __version__
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, DataFrameReader, SQLContext
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.context import SQLContext
from pyspark.sql.functions import col, count, lit, when
from pyspark.sql.session import SparkSession
from pyspark.storagelevel import StorageLevel
+try:
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+ from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+ from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+ has_connect = True
+except ImportError:
+ has_connect = False
+
if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName
+def _get_jvm(obj: Any) -> JVMView:
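+    """Return the JVM view of the given Spark object, or raise a useful error if the JVM is not accessible (e.g. Spark Connect)."""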
+ if obj is None:
+ if SparkContext._active_spark_context is None:
+ raise RuntimeError("This method must be called inside an active Spark session")
+ else:
+ raise ValueError("Cannot provide access to JVM from None")
+
+    # Spark Connect clients have no access to the JVM, so raise a useful error
+ if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)):
+ raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
+ if isinstance(obj, DataFrame):
+ return _get_jvm(obj._sc)
+ if isinstance(obj, DataFrameReader):
+ return _get_jvm(obj._spark)
+ if isinstance(obj, SparkSession):
+ return _get_jvm(obj.sparkContext)
+ if isinstance(obj, (SparkContext, SQLContext)):
+ return obj._jvm
+ raise RuntimeError(f'Unsupported class: {type(obj)}')
+
+
def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
array = jvm.java.util.ArrayList(list)
return jvm.scala.collection.JavaConverters.asScalaIteratorConverter(array.iterator()).asScala().toSeq()
@@ -72,11 +101,8 @@ def dotnet_ticks_to_timestamp(tick_column: Union[str, Column]) -> Column:
if not isinstance(tick_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToTimestamp
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToTimestamp
return Column(func(_to_java_column(tick_column)))
@@ -105,11 +131,8 @@ def dotnet_ticks_to_unix_epoch(tick_column: Union[str, Column]) -> Column:
if not isinstance(tick_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpoch
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpoch
return Column(func(_to_java_column(tick_column)))
@@ -138,11 +161,8 @@ def dotnet_ticks_to_unix_epoch_nanos(tick_column: Union[str, Column]) -> Column:
if not isinstance(tick_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(tick_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpochNanos
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").dotNetTicksToUnixEpochNanos
return Column(func(_to_java_column(tick_column)))
@@ -170,11 +190,8 @@ def timestamp_to_dotnet_ticks(timestamp_column: Union[str, Column]) -> Column:
if not isinstance(timestamp_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(timestamp_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").timestampToDotNetTicks
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").timestampToDotNetTicks
return Column(func(_to_java_column(timestamp_column)))
@@ -204,11 +221,8 @@ def unix_epoch_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
if not isinstance(unix_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochToDotNetTicks
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochToDotNetTicks
return Column(func(_to_java_column(unix_column)))
@@ -239,11 +253,8 @@ def unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
if not isinstance(unix_column, (str, Column)):
raise ValueError(f"Given column must be a column name (str) or column instance (Column): {type(unix_column)}")
- sc = SparkContext._active_spark_context
- if sc is None or sc._jvm is None:
- raise RuntimeError("This method must be called inside an active Spark session")
-
- func = sc._jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochNanosToDotNetTicks
+ jvm = _get_jvm(SparkContext._active_spark_context)
+ func = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").unixEpochNanosToDotNetTicks
return Column(func(_to_java_column(unix_column)))
@@ -281,7 +292,7 @@ def histogram(self: DataFrame,
else:
raise ValueError('thresholds must be int or floats: {}'.format(t))
- jvm = self._sc._jvm
+ jvm = _get_jvm(self)
col = jvm.org.apache.spark.sql.functions.col
value_column = col(value_column)
aggregate_columns = [col(column) for column in aggregate_columns]
@@ -292,6 +303,8 @@ def histogram(self: DataFrame,
DataFrame.histogram = histogram
+if has_connect:
+ ConnectDataFrame.histogram = histogram
class UnpersistHandle:
@@ -307,7 +320,7 @@ def __call__(self, blocking: Optional[bool] = None):
def unpersist_handle(self: SparkSession) -> UnpersistHandle:
- jvm = self._sc._jvm
+ jvm = _get_jvm(self)
handle = jvm.uk.co.gresearch.spark.UnpersistHandle()
return UnpersistHandle(handle)
@@ -321,7 +334,7 @@ def with_row_numbers(self: DataFrame,
row_number_column_name: str = "row_number",
order: Union[str, Column, List[Union[str, Column]]] = [],
ascending: Union[bool, List[bool]] = True) -> DataFrame:
- jvm = self._sc._jvm
+ jvm = _get_jvm(self)
jsl = self._sc._getJavaStorageLevel(storage_level)
juho = jvm.uk.co.gresearch.spark.UnpersistHandle
juh = unpersist_handle._handle if unpersist_handle else juho.Noop()
@@ -338,17 +351,23 @@ def with_row_numbers(self: DataFrame,
return DataFrame(jdf, self.session_or_ctx())
+DataFrame.with_row_numbers = with_row_numbers
+if has_connect:
+ ConnectDataFrame.with_row_numbers = with_row_numbers
+
+
def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx
-DataFrame.with_row_numbers = with_row_numbers
DataFrame.session_or_ctx = session_or_ctx
+if has_connect:
+ ConnectDataFrame.session_or_ctx = session_or_ctx
def set_description(description: str, if_not_set: bool = False):
context = SparkContext._active_spark_context
- jvm = context._jvm
+ jvm = _get_jvm(context)
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
return spark_package.setJobDescription(description, if_not_set, context._jsc.sc())
@@ -384,7 +403,7 @@ def job_description(description: str, if_not_set: bool = False):
def append_description(extra_description: str, separator: str = " - "):
context = SparkContext._active_spark_context
- jvm = context._jvm
+ jvm = _get_jvm(context)
spark_package = jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$")
return spark_package.appendJobDescription(extra_description, separator, context._jsc.sc())
@@ -424,21 +443,25 @@ def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str)
:param prefix: prefix string of temporary directory name
:return: absolute path of temporary directory
"""
- if isinstance(spark, SparkSession):
- spark = spark.sparkContext
-
- root_dir = spark._jvm.org.apache.spark.SparkFiles.getRootDirectory()
+ jvm = _get_jvm(spark)
+ root_dir = jvm.org.apache.spark.SparkFiles.getRootDirectory()
return tempfile.mkdtemp(prefix=prefix, dir=root_dir)
SparkSession.create_temporary_dir = create_temporary_dir
SparkContext.create_temporary_dir = create_temporary_dir
+if has_connect:
+ ConnectSparkSession.create_temporary_dir = create_temporary_dir
+
def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:
if __version__.startswith('2.') or __version__.startswith('3.0.'):
raise NotImplementedError(f'Not supported for PySpark __version__')
+ # just here to assert JVM is accessible
+ _get_jvm(spark)
+
if isinstance(spark, SparkSession):
spark = spark.sparkContext
@@ -465,6 +488,9 @@ def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pi
SparkSession.install_pip_package = install_pip_package
SparkContext.install_pip_package = install_pip_package
+if has_connect:
+ ConnectSparkSession.install_pip_package = install_pip_package
+
def install_poetry_project(spark: Union[SparkSession, SparkContext],
*project: str,
@@ -478,6 +504,9 @@ def install_poetry_project(spark: Union[SparkSession, SparkContext],
if __version__.startswith('2.') or __version__.startswith('3.0.'):
raise NotImplementedError(f'Not supported for PySpark __version__')
+ # just here to assert JVM is accessible
+ _get_jvm(spark)
+
if isinstance(spark, SparkSession):
spark = spark.sparkContext
if poetry_python is None:
@@ -538,3 +567,6 @@ def build_wheel(project: Path) -> Path:
SparkSession.install_poetry_project = install_poetry_project
SparkContext.install_poetry_project = install_poetry_project
+
+if has_connect:
+ ConnectSparkSession.install_poetry_project = install_poetry_project
diff --git a/python/gresearch/spark/diff/__init__.py b/python/gresearch/spark/diff/__init__.py
index 7776a743..351dc587 100644
--- a/python/gresearch/spark/diff/__init__.py
+++ b/python/gresearch/spark/diff/__init__.py
@@ -20,9 +20,15 @@
from pyspark.sql import DataFrame
from pyspark.sql.types import DataType
-from gresearch.spark import _to_seq, _to_map
+from gresearch.spark import _get_jvm, _to_seq, _to_map
from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator
+try:
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+ has_connect = True
+except ImportError:
+ has_connect = False
+
class DiffMode(Enum):
ColumnByColumn = "ColumnByColumn"
@@ -333,7 +339,7 @@ def diff(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataFrame
:return: the diff DataFrame
:rtype DataFrame
"""
- jvm = left._sc._jvm
+ jvm = _get_jvm(left)
jdiffer = self._to_java(jvm)
jdf = jdiffer.diff(left._jdf, right._jdf, _to_seq(jvm, list(id_columns)))
return DataFrame(jdf, left.session_or_ctx())
@@ -354,7 +360,7 @@ def diffwith(self, left: DataFrame, right: DataFrame, *id_columns: str) -> DataF
:return: the diff DataFrame
:rtype DataFrame
"""
- jvm = left._sc._jvm
+ jvm = _get_jvm(left)
jdiffer = self._to_java(jvm)
jdf = jdiffer.diffWith(left._jdf, right._jdf, _to_seq(jvm, list(id_columns)))
df = DataFrame(jdf, left.sql_ctx)
@@ -489,6 +495,11 @@ def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOption
DataFrame.diff = diff
DataFrame.diffwith = diffwith
-
DataFrame.diff_with_options = diff_with_options
DataFrame.diffwith_with_options = diffwith_with_options
+
+if has_connect:
+ ConnectDataFrame.diff = diff
+ ConnectDataFrame.diffwith = diffwith
+ ConnectDataFrame.diff_with_options = diff_with_options
+ ConnectDataFrame.diffwith_with_options = diffwith_with_options
diff --git a/python/gresearch/spark/parquet/__init__.py b/python/gresearch/spark/parquet/__init__.py
index f6dd137f..890f22f0 100644
--- a/python/gresearch/spark/parquet/__init__.py
+++ b/python/gresearch/spark/parquet/__init__.py
@@ -17,11 +17,17 @@
from py4j.java_gateway import JavaObject
from pyspark.sql import DataFrameReader, DataFrame
-from gresearch.spark import _to_seq
+from gresearch.spark import _get_jvm, _to_seq
+
+try:
+ from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+ has_connect = True
+except ImportError:
+ has_connect = False
def _jreader(reader: DataFrameReader) -> JavaObject:
- jvm = reader._spark._jvm
+ jvm = _get_jvm(reader)
return jvm.uk.co.gresearch.spark.parquet.__getattr__("package$").__getattr__("MODULE$").ExtendedDataFrameReader(reader._jreader)
@@ -52,7 +58,7 @@ def parquet_metadata(self: DataFrameReader, *paths: str, parallelism: Optional[i
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
- jvm = self._spark._jvm
+ jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetMetadata(_to_seq(jvm, list(paths)))
else:
@@ -87,7 +93,7 @@ def parquet_schema(self: DataFrameReader, *paths: str, parallelism: Optional[int
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
- jvm = self._spark._jvm
+ jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetSchema(_to_seq(jvm, list(paths)))
else:
@@ -119,7 +125,7 @@ def parquet_blocks(self: DataFrameReader, *paths: str, parallelism: Optional[int
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
- jvm = self._spark._jvm
+ jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetBlocks(_to_seq(jvm, list(paths)))
else:
@@ -155,7 +161,7 @@ def parquet_block_columns(self: DataFrameReader, *paths: str, parallelism: Optio
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
- jvm = self._spark._jvm
+ jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetBlockColumns(_to_seq(jvm, list(paths)))
else:
@@ -191,7 +197,7 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional
:param parallelism: number of partitions of returned DataFrame
:return: dataframe with Parquet metadata
"""
- jvm = self._spark._jvm
+ jvm = _get_jvm(self)
if parallelism is None:
jdf = _jreader(self).parquetPartitions(_to_seq(jvm, list(paths)))
else:
@@ -204,3 +210,10 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional
DataFrameReader.parquet_blocks = parquet_blocks
DataFrameReader.parquet_block_columns = parquet_block_columns
DataFrameReader.parquet_partitions = parquet_partitions
+
+if has_connect:
+ ConnectDataFrameReader.parquet_metadata = parquet_metadata
+ ConnectDataFrameReader.parquet_schema = parquet_schema
+ ConnectDataFrameReader.parquet_blocks = parquet_blocks
+ ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
+ ConnectDataFrameReader.parquet_partitions = parquet_partitions
diff --git a/python/test/requirements.txt b/python/test/requirements.txt
index 7056a990..cefc8f91 100644
--- a/python/test/requirements.txt
+++ b/python/test/requirements.txt
@@ -1,4 +1,5 @@
-pandas
-pyarrow
+grpcio>=1.48.1
+pandas>=1.0.5
+pyarrow>=4.0.0
pytest
unittest-xml-reporting
diff --git a/python/test/spark_common.py b/python/test/spark_common.py
index 019cce45..76735beb 100644
--- a/python/test/spark_common.py
+++ b/python/test/spark_common.py
@@ -80,7 +80,9 @@ def get_spark_config(path) -> SparkConf:
def get_spark_session(cls) -> SparkSession:
builder = SparkSession.builder
- if 'PYSPARK_GATEWAY_PORT' in os.environ:
+ if 'TEST_SPARK_CONNECT_SERVER' in os.environ:
+ builder.remote(os.environ['TEST_SPARK_CONNECT_SERVER'])
+ elif 'PYSPARK_GATEWAY_PORT' in os.environ:
logging.info('Running inside existing Spark environment')
else:
logging.info('Setting up Spark environment')
@@ -91,6 +93,7 @@ def get_spark_session(cls) -> SparkSession:
return builder.getOrCreate()
spark: SparkSession = None
+ is_spark_connect: bool = 'TEST_SPARK_CONNECT_SERVER' in os.environ
@classmethod
def setUpClass(cls):
diff --git a/python/test/test_diff.py b/python/test/test_diff.py
index f34607a9..328612e8 100644
--- a/python/test/test_diff.py
+++ b/python/test/test_diff.py
@@ -18,11 +18,13 @@
from pyspark.sql import Row
from pyspark.sql.functions import col, when
from pyspark.sql.types import IntegerType, LongType, StringType, DateType
+from unittest import skipIf
from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators
from spark_common import SparkTest
+@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Diff")
class DiffTest(SparkTest):
expected_diff = None
diff --git a/python/test/test_histogram.py b/python/test/test_histogram.py
index 0f58d848..01f92196 100644
--- a/python/test/test_histogram.py
+++ b/python/test/test_histogram.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import skipIf
+
from spark_common import SparkTest
import gresearch.spark
+@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Histogram")
class HistogramTest(SparkTest):
@classmethod
diff --git a/python/test/test_job_description.py b/python/test/test_job_description.py
index 15630a43..df72faee 100644
--- a/python/test/test_job_description.py
+++ b/python/test/test_job_description.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import skipIf
+
from pyspark import TaskContext, SparkContext
from typing import Optional
@@ -19,6 +21,7 @@
from gresearch.spark import job_description, append_job_description
+@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by JobDescription")
class JobDescriptionTest(SparkTest):
def _assert_job_description(self, expected: Optional[str]):
diff --git a/python/test/test_jvm.py b/python/test/test_jvm.py
new file mode 100644
index 00000000..45ddc968
--- /dev/null
+++ b/python/test/test_jvm.py
@@ -0,0 +1,143 @@
+# Copyright 2024 G-Research
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import skipIf, skipUnless
+
+from pyspark.sql.functions import sum
+
+from gresearch.spark import _get_jvm, dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
+ timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, histogram, job_description, append_description
+from gresearch.spark.diff import *
+from gresearch.spark.parquet import *
+from spark_common import SparkTest
+
+EXPECTED_UNSUPPORTED_MESSAGE = "This feature is not supported for Spark Connect. Please use a classic Spark client. " \
+ "https://github.com/G-Research/spark-extension#spark-connect-server"
+
+
+class PackageTest(SparkTest):
+ df = None
+
+ @classmethod
+ def setUpClass(cls):
+ super(PackageTest, cls).setUpClass()
+ cls.df = cls.spark.createDataFrame([(1, "one"), (2, "two"), (3, "three")], ["id", "value"])
+
+ @skipIf(SparkTest.is_spark_connect, "Spark classic client tests")
+ def test_get_jvm_classic(self):
+ for obj in [self.spark, self.spark.sparkContext, self.df, self.spark.read]:
+ with self.subTest(type(obj).__name__):
+ self.assertIsNotNone(_get_jvm(obj))
+
+ with self.subTest("Unsupported"):
+ with self.assertRaises(RuntimeError) as e:
+ _get_jvm(object())
+            self.assertEqual((f"Unsupported class: {type(object())}", ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_get_jvm_connect(self):
+ for obj in [self.spark, self.df, self.spark.read]:
+ with self.subTest(type(obj).__name__):
+ with self.assertRaises(RuntimeError) as e:
+ _get_jvm(obj)
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ with self.subTest("Unsupported"):
+ with self.assertRaises(RuntimeError) as e:
+ _get_jvm(object())
+            self.assertEqual((f"Unsupported class: {type(object())}", ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_diff(self):
+ for label, func in {
+ 'diff': lambda: self.df.diff(self.df),
+ 'diff_with_options': lambda: self.df.diff_with_options(self.df, DiffOptions()),
+ 'diffwith': lambda: self.df.diffwith(self.df),
+ 'diffwith_with_options': lambda: self.df.diffwith_with_options(self.df, DiffOptions()),
+ }.items():
+ with self.subTest(label):
+ with self.assertRaises(RuntimeError) as e:
+ func()
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_dotnet_ticks(self):
+ for label, func in {
+ 'dotnet_ticks_to_timestamp': dotnet_ticks_to_timestamp,
+ 'dotnet_ticks_to_unix_epoch': dotnet_ticks_to_unix_epoch,
+ 'dotnet_ticks_to_unix_epoch_nanos': dotnet_ticks_to_unix_epoch_nanos,
+ 'timestamp_to_dotnet_ticks': timestamp_to_dotnet_ticks,
+ 'unix_epoch_to_dotnet_ticks': unix_epoch_to_dotnet_ticks,
+ 'unix_epoch_nanos_to_dotnet_ticks': unix_epoch_nanos_to_dotnet_ticks,
+ }.items():
+ with self.subTest(label):
+ with self.assertRaises(RuntimeError) as e:
+ func("id")
+ self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_histogram(self):
+ with self.assertRaises(RuntimeError) as e:
+ self.df.histogram([1, 10, 100], "bin", sum)
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_with_row_numbers(self):
+ with self.assertRaises(RuntimeError) as e:
+ self.df.with_row_numbers()
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_job_description(self):
+ with self.assertRaises(RuntimeError) as e:
+ with job_description("job description"):
+ pass
+ self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
+
+ with self.assertRaises(RuntimeError) as e:
+ with append_description("job description"):
+ pass
+ self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_create_temp_dir(self):
+ with self.assertRaises(RuntimeError) as e:
+ self.spark.create_temporary_dir("prefix")
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_install_pip_package(self):
+ with self.assertRaises(RuntimeError) as e:
+ self.spark.install_pip_package("pytest")
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_install_poetry_project(self):
+ with self.assertRaises(RuntimeError) as e:
+ self.spark.install_poetry_project("./poetry-project")
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
+
+ @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
+ def test_parquet(self):
+ for label, func in {
+ 'parquet_metadata': lambda dr: dr.parquet_metadata("file.parquet"),
+ 'parquet_schema': lambda dr: dr.parquet_schema("file.parquet"),
+ 'parquet_blocks': lambda dr: dr.parquet_blocks("file.parquet"),
+ 'parquet_block_columns': lambda dr: dr.parquet_block_columns("file.parquet"),
+ 'parquet_partitions': lambda dr: dr.parquet_partitions("file.parquet"),
+ }.items():
+ with self.subTest(label):
+ with self.assertRaises(RuntimeError) as e:
+ func(self.spark.read)
+ self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
diff --git a/python/test/test_package.py b/python/test/test_package.py
index 2a496fd3..96d46c06 100644
--- a/python/test/test_package.py
+++ b/python/test/test_package.py
@@ -105,6 +105,7 @@ def compare_dfs(self, expected, actual):
[row.asDict() for row in expected.collect()]
)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_timestamp(self):
for column in ["tick", self.ticks.tick]:
with self.subTest(column=column):
@@ -112,6 +113,7 @@ def test_dotnet_ticks_to_timestamp(self):
expected = self.ticks.join(self.timestamps, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_unix_epoch(self):
for column in ["tick", self.ticks.tick]:
with self.subTest(column=column):
@@ -119,6 +121,7 @@ def test_dotnet_ticks_to_unix_epoch(self):
expected = self.ticks.join(self.unix, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_dotnet_ticks_to_unix_epoch_nanos(self):
self.maxDiff = None
for column in ["tick", self.ticks.tick]:
@@ -127,6 +130,7 @@ def test_dotnet_ticks_to_unix_epoch_nanos(self):
expected = self.ticks.join(self.unix_nanos, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_timestamp_to_dotnet_ticks(self):
if self.spark.version.startswith('3.0.'):
self.skipTest('timestamp_to_dotnet_ticks not supported by Spark 3.0')
@@ -136,6 +140,7 @@ def test_timestamp_to_dotnet_ticks(self):
expected = self.timestamps.join(self.ticks_from_timestamp, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_unix_epoch_dotnet_ticks(self):
for column in ["unix", self.unix.unix]:
with self.subTest(column=column):
@@ -143,6 +148,7 @@ def test_unix_epoch_dotnet_ticks(self):
expected = self.unix.join(self.ticks, "id").orderBy('id')
self.compare_dfs(expected, timestamps)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by dotnet ticks")
def test_unix_epoch_nanos_to_dotnet_ticks(self):
for column in ["unix_nanos", self.unix_nanos.unix_nanos]:
with self.subTest(column=column):
@@ -159,6 +165,7 @@ def test_count_null(self):
).collect()
self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by create_temp_dir")
def test_create_temp_dir(self):
from pyspark import SparkFiles
@@ -166,6 +173,7 @@ def test_create_temp_dir(self):
self.assertTrue(dir.startswith(SparkFiles.getRootDirectory()))
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package(self):
self.spark.sparkContext.setLogLevel("INFO")
with self.assertRaises(ImportError):
@@ -189,16 +197,19 @@ def test_install_pip_package(self):
self.assertEqual(expected, actual)
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_unknown_argument(self):
with self.assertRaises(CalledProcessError):
self.spark.install_pip_package("--unknown", "argument")
@skipIf(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_package_not_found(self):
with self.assertRaises(CalledProcessError):
self.spark.install_pip_package("pyspark-extension==abc")
@skipUnless(__version__.startswith('3.0.'), 'install_pip_package not supported for Spark 3.0')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_pip_package")
def test_install_pip_package_not_supported(self):
with self.assertRaises(NotImplementedError):
self.spark.install_pip_package("emoji")
@@ -209,6 +220,7 @@ def test_install_pip_package_not_supported(self):
f'virtual env python with poetry required')
@skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '
f'rich project sources required')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project(self):
self.spark.sparkContext.setLogLevel("INFO")
with self.assertRaises(ImportError):
@@ -244,6 +256,7 @@ def test_install_poetry_project(self):
f'virtual env python with poetry required')
@skipIf(RICH_SOURCES_ENV not in os.environ, f'Environment variable {RICH_SOURCES_ENV} pointing to '
f'rich project sources required')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project_wrong_arguments(self):
rich_path = os.environ[RICH_SOURCES_ENV]
poetry_python = os.environ[POETRY_PYTHON_ENV]
@@ -254,6 +267,7 @@ def test_install_poetry_project_wrong_arguments(self):
self.spark.install_poetry_project(rich_path, poetry_python="non-existing-python")
@skipUnless(__version__.startswith('3.0.'), 'install_poetry_project not supported for Spark 3.0')
+ @skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by install_poetry_project")
def test_install_poetry_project_not_supported(self):
with self.assertRaises(NotImplementedError):
self.spark.install_poetry_project("./rich")
diff --git a/python/test/test_parquet.py b/python/test/test_parquet.py
index f81f6215..5e12acaa 100644
--- a/python/test/test_parquet.py
+++ b/python/test/test_parquet.py
@@ -13,11 +13,13 @@
# limitations under the License.
from pathlib import Path
+from unittest import skipIf
from spark_common import SparkTest
import gresearch.spark.parquet
+@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by Parquet")
class ParquetTest(SparkTest):
test_file = str((Path(__file__).parent.parent.parent / "src" / "test" / "files" / "test.parquet").resolve())
diff --git a/python/test/test_row_number.py b/python/test/test_row_number.py
index 8a99f937..16d620a0 100644
--- a/python/test/test_row_number.py
+++ b/python/test/test_row_number.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import skipIf
+
from pyspark.storagelevel import StorageLevel
from spark_common import SparkTest
import gresearch.spark
+@skipIf(SparkTest.is_spark_connect, "Spark Connect does not provide access to the JVM, required by RowNumber")
class RowNumberTest(SparkTest):
@classmethod