diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py
index 4355109b..093103fe 100644
--- a/python/gresearch/spark/__init__.py
+++ b/python/gresearch/spark/__init__.py
@@ -29,22 +29,27 @@
 from pyspark.files import SparkFiles
 from pyspark.sql import DataFrame, DataFrameReader, SQLContext
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
-from pyspark.sql.connect.session import SparkContext as ConnectSparkContext
-from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
 from pyspark.sql.context import SQLContext
 from pyspark.sql.functions import col, count, lit, when
 from pyspark.sql.session import SparkSession
 from pyspark.storagelevel import StorageLevel
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+    from pyspark.sql.connect.session import SparkContext as ConnectSparkContext
+    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+    has_connect = True
+except ImportError:
+    has_connect = False
+
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
 def _get_jvm(obj: Any) -> JVMView:
     # helper method to assert the JVM is accessible and provide a useful error message
-    if isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession, ConnectSparkContext)):
+    if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession, ConnectSparkContext)):
         raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
     if isinstance(obj, DataFrame):
         return _get_jvm(obj._sc)
@@ -293,7 +298,8 @@ def histogram(self: DataFrame,
 
 
 DataFrame.histogram = histogram
-ConnectDataFrame.histogram = histogram
+if has_connect:
+    ConnectDataFrame.histogram = histogram
 
 
 class UnpersistHandle:
@@ -341,7 +347,8 @@ def with_row_numbers(self: DataFrame,
 
 
 DataFrame.with_row_numbers = with_row_numbers
-ConnectDataFrame.with_row_numbers = with_row_numbers
+if has_connect:
+    ConnectDataFrame.with_row_numbers = with_row_numbers
 
 
 def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
@@ -349,7 +356,8 @@ def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
 
 
 DataFrame.session_or_ctx = session_or_ctx
-ConnectDataFrame.session_or_ctx = session_or_ctx
+if has_connect:
+    ConnectDataFrame.session_or_ctx = session_or_ctx
 
 
 def set_description(description: str, if_not_set: bool = False):
@@ -438,8 +446,9 @@ def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str)
 
 SparkSession.create_temporary_dir = create_temporary_dir
 SparkContext.create_temporary_dir = create_temporary_dir
-ConnectSparkSession.create_temporary_dir = create_temporary_dir
-ConnectSparkContext.create_temporary_dir = create_temporary_dir
+if has_connect:
+    ConnectSparkSession.create_temporary_dir = create_temporary_dir
+    ConnectSparkContext.create_temporary_dir = create_temporary_dir
 
 
 def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:
@@ -475,8 +484,9 @@ def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pi
 
 SparkSession.install_pip_package = install_pip_package
 SparkContext.install_pip_package = install_pip_package
-ConnectSparkSession.install_pip_package = install_pip_package
-ConnectSparkContext.install_pip_package = install_pip_package
+if has_connect:
+    ConnectSparkSession.install_pip_package = install_pip_package
+    ConnectSparkContext.install_pip_package = install_pip_package
 
 
 def install_poetry_project(spark: Union[SparkSession, SparkContext],
@@ -555,5 +565,6 @@ def build_wheel(project: Path) -> Path:
 
 SparkSession.install_poetry_project = install_poetry_project
 SparkContext.install_poetry_project = install_poetry_project
-ConnectSparkSession.install_poetry_project = install_poetry_project
-ConnectSparkContext.install_poetry_project = install_poetry_project
+if has_connect:
+    ConnectSparkSession.install_poetry_project = install_poetry_project
+    ConnectSparkContext.install_poetry_project = install_poetry_project
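
The guard above lets the package import cleanly where Spark Connect is not
available. A minimal sketch of the intended behaviour (assumes classic PySpark
plus this package installed, without the pyspark[connect] extras; the
attribute names are taken from the patches above):

    # Previously the unconditional 'from pyspark.sql.connect...' imports
    # failed on such installations; now the import succeeds and the outcome
    # is recorded in the module-level has_connect flag.
    import gresearch.spark
    print(gresearch.spark.has_connect)  # False without the connect extras

    # Classic classes are still patched unconditionally:
    from pyspark.sql import DataFrame, SparkSession
    assert hasattr(DataFrame, 'with_row_numbers')
    assert hasattr(SparkSession, 'install_pip_package')
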
diff --git a/python/gresearch/spark/diff/__init__.py b/python/gresearch/spark/diff/__init__.py
index 3bd3c412..f9ca8272 100644
--- a/python/gresearch/spark/diff/__init__.py
+++ b/python/gresearch/spark/diff/__init__.py
@@ -18,12 +18,17 @@
 
 from py4j.java_gateway import JavaObject, JVMView
 from pyspark.sql import DataFrame
-from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 from pyspark.sql.types import DataType
 
 from gresearch.spark import _get_jvm, _to_seq, _to_map
 from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    has_connect = True
+except ImportError:
+    has_connect = False
+
 
 class DiffMode(Enum):
     ColumnByColumn = "ColumnByColumn"
@@ -493,7 +498,8 @@ def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOption
 
 DataFrame.diff_with_options = diff_with_options
 DataFrame.diffwith_with_options = diffwith_with_options
-ConnectDataFrame.diff = diff
-ConnectDataFrame.diffwith = diffwith
-ConnectDataFrame.diff_with_options = diff_with_options
-ConnectDataFrame.diffwith_with_options = diffwith_with_options
+if has_connect:
+    ConnectDataFrame.diff = diff
+    ConnectDataFrame.diffwith = diffwith
+    ConnectDataFrame.diff_with_options = diff_with_options
+    ConnectDataFrame.diffwith_with_options = diffwith_with_options
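
With the same guard in the diff module, the classic DataFrame keeps its diff
API regardless of whether the connect import succeeds. A usage sketch (the
local session and toy data are illustrative, and the spark-extension jar is
assumed to be on the Spark classpath, e.g. via the pyspark-extension package):

    from pyspark.sql import SparkSession
    import gresearch.spark.diff  # patches DataFrame.diff and friends

    spark = SparkSession.builder.master('local[1]').getOrCreate()
    left = spark.createDataFrame([(1, 'one'), (2, 'two')], ['id', 'value'])
    right = spark.createDataFrame([(1, 'one'), (2, 'TWO')], ['id', 'value'])

    # Patched onto the classic DataFrame unconditionally; only the
    # ConnectDataFrame patches are now behind has_connect.
    left.diff(right, 'id').show()
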
diff --git a/python/gresearch/spark/parquet/__init__.py b/python/gresearch/spark/parquet/__init__.py
index c7a0d362..9aaa067b 100644
--- a/python/gresearch/spark/parquet/__init__.py
+++ b/python/gresearch/spark/parquet/__init__.py
@@ -16,10 +16,15 @@
 
 from py4j.java_gateway import JavaObject
 from pyspark.sql import DataFrameReader, DataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
 
 from gresearch.spark import _get_jvm, _to_seq
 
+try:
+    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+    has_connect = True
+except ImportError:
+    has_connect = False
+
 
 def _jreader(reader: DataFrameReader) -> JavaObject:
     jvm = _get_jvm(reader)
@@ -206,8 +211,9 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional
 
 DataFrameReader.parquet_block_columns = parquet_block_columns
 DataFrameReader.parquet_partitions = parquet_partitions
-ConnectDataFrameReader.parquet_metadata = parquet_metadata
-ConnectDataFrameReader.parquet_schema = parquet_schema
-ConnectDataFrameReader.parquet_blocks = parquet_blocks
-ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
-ConnectDataFrameReader.parquet_partitions = parquet_partitions
+if has_connect:
+    ConnectDataFrameReader.parquet_metadata = parquet_metadata
+    ConnectDataFrameReader.parquet_schema = parquet_schema
+    ConnectDataFrameReader.parquet_blocks = parquet_blocks
+    ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
+    ConnectDataFrameReader.parquet_partitions = parquet_partitions
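
Likewise, the Parquet metadata readers stay available on the classic
DataFrameReader. A sketch under the same assumptions as above (the path is a
placeholder):

    from pyspark.sql import SparkSession
    import gresearch.spark.parquet  # patches DataFrameReader.parquet_* methods

    spark = SparkSession.builder.master('local[1]').getOrCreate()

    # Inspect Parquet file metadata without reading the row data:
    spark.read.parquet_metadata('/path/to/data.parquet').show()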