Fix imports for Spark <3.4
EnricoMi committed Aug 15, 2024
1 parent 3379b55 · commit 50f395d
Showing 3 changed files with 48 additions and 25 deletions.
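All three files apply the same guard: pyspark.sql.connect only exists in PySpark 3.4 and later, so the unconditional imports are wrapped in a try/except ModuleNotFoundError, availability is recorded in a has_connect flag, and the Spark Connect classes are monkey-patched only when that flag is set. A minimal sketch of the pattern, assuming PySpark is installed (the histogram stub is a hypothetical stand-in for the real helpers defined in the diff below):

from pyspark.sql import DataFrame

# pyspark.sql.connect is only available in PySpark 3.4+; import it defensively.
try:
    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
    has_connect = True
except ModuleNotFoundError:
    has_connect = False


def histogram(self: DataFrame, *args, **kwargs):
    # Hypothetical stand-in for a JVM-backed helper from this package.
    ...


# Always patch the classic DataFrame; patch the Connect DataFrame only when it exists.
DataFrame.histogram = histogram
if has_connect:
    ConnectDataFrame.histogram = histogram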
39 changes: 25 additions & 14 deletions python/gresearch/spark/__init__.py
@@ -29,22 +29,27 @@
 from pyspark.files import SparkFiles
 from pyspark.sql import DataFrame, DataFrameReader, SQLContext
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
-from pyspark.sql.connect.session import SparkContext as ConnectSparkContext
-from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
 from pyspark.sql.context import SQLContext
 from pyspark.sql.functions import col, count, lit, when
 from pyspark.sql.session import SparkSession
 from pyspark.storagelevel import StorageLevel
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+    from pyspark.sql.connect.session import SparkContext as ConnectSparkContext
+    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+    has_connect = True
+except ModuleNotFoundError:
+    has_connect = False
+
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
 def _get_jvm(obj: Any) -> JVMView:
     # helper method to assert the JVM is accessible and provide a useful error message
-    if isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession, ConnectSparkContext)):
+    if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession, ConnectSparkContext)):
         raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
     if isinstance(obj, DataFrame):
         return _get_jvm(obj._sc)
@@ -293,7 +298,8 @@ def histogram(self: DataFrame,
 
 
 DataFrame.histogram = histogram
-ConnectDataFrame.histogram = histogram
+if has_connect:
+    ConnectDataFrame.histogram = histogram
 
 
 class UnpersistHandle:
@@ -341,15 +347,17 @@ def with_row_numbers(self: DataFrame,
 
 
 DataFrame.with_row_numbers = with_row_numbers
-ConnectDataFrame.with_row_numbers = with_row_numbers
+if has_connect:
+    ConnectDataFrame.with_row_numbers = with_row_numbers
 
 
 def session_or_ctx(self: DataFrame) -> Union[SparkSession, SQLContext]:
     return self.sparkSession if hasattr(self, 'sparkSession') else self.sql_ctx
 
 
 DataFrame.session_or_ctx = session_or_ctx
-ConnectDataFrame.session_or_ctx = session_or_ctx
+if has_connect:
+    ConnectDataFrame.session_or_ctx = session_or_ctx
 
 
 def set_description(description: str, if_not_set: bool = False):
@@ -438,8 +446,9 @@ def create_temporary_dir(spark: Union[SparkSession, SparkContext], prefix: str)
 SparkSession.create_temporary_dir = create_temporary_dir
 SparkContext.create_temporary_dir = create_temporary_dir
 
-ConnectSparkSession.create_temporary_dir = create_temporary_dir
-ConnectSparkContext.create_temporary_dir = create_temporary_dir
+if has_connect:
+    ConnectSparkSession.create_temporary_dir = create_temporary_dir
+    ConnectSparkContext.create_temporary_dir = create_temporary_dir
 
 
 def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pip_option: str) -> None:
@@ -475,8 +484,9 @@ def install_pip_package(spark: Union[SparkSession, SparkContext], *package_or_pi
 SparkSession.install_pip_package = install_pip_package
 SparkContext.install_pip_package = install_pip_package
 
-ConnectSparkSession.install_pip_package = install_pip_package
-ConnectSparkContext.install_pip_package = install_pip_package
+if has_connect:
+    ConnectSparkSession.install_pip_package = install_pip_package
+    ConnectSparkContext.install_pip_package = install_pip_package
 
 
 def install_poetry_project(spark: Union[SparkSession, SparkContext],
@@ -555,5 +565,6 @@ def build_wheel(project: Path) -> Path:
 SparkSession.install_poetry_project = install_poetry_project
 SparkContext.install_poetry_project = install_poetry_project
 
-ConnectSparkSession.install_poetry_project = install_poetry_project
-ConnectSparkContext.install_poetry_project = install_poetry_project
+if has_connect:
+    ConnectSparkSession.install_poetry_project = install_poetry_project
+    ConnectSparkContext.install_poetry_project = install_poetry_project
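With these changes, importing the package no longer fails on PySpark below 3.4, where pyspark.sql.connect does not exist, and the module-level has_connect flag records whether the Connect classes were patched. A rough check, assuming PySpark and this package are installed:

import pyspark
import gresearch.spark  # raised ModuleNotFoundError on PySpark < 3.4 before this commit

print(pyspark.__version__)
# False on PySpark < 3.4 (or when Spark Connect is unavailable), True otherwise
print(gresearch.spark.has_connect)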
16 changes: 11 additions & 5 deletions python/gresearch/spark/diff/__init__.py
@@ -18,12 +18,17 @@
 
 from py4j.java_gateway import JavaObject, JVMView
 from pyspark.sql import DataFrame
-from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
 from pyspark.sql.types import DataType
 
 from gresearch.spark import _get_jvm, _to_seq, _to_map
 from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator
 
+try:
+    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    has_connect = True
+except ModuleNotFoundError:
+    has_connect = False
+
 
 class DiffMode(Enum):
     ColumnByColumn = "ColumnByColumn"
@@ -493,7 +498,8 @@ def diffwith_with_options(self: DataFrame, other: DataFrame, options: DiffOption
 DataFrame.diff_with_options = diff_with_options
 DataFrame.diffwith_with_options = diffwith_with_options
 
-ConnectDataFrame.diff = diff
-ConnectDataFrame.diffwith = diffwith
-ConnectDataFrame.diff_with_options = diff_with_options
-ConnectDataFrame.diffwith_with_options = diffwith_with_options
+if has_connect:
+    ConnectDataFrame.diff = diff
+    ConnectDataFrame.diffwith = diffwith
+    ConnectDataFrame.diff_with_options = diff_with_options
+    ConnectDataFrame.diffwith_with_options = diffwith_with_options
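The diff module gets the same guard, so the diff helpers stay available on classic DataFrames even when Spark Connect cannot be imported. A hedged usage sketch, based on the assignments above (the id column and sample data are illustrative):

from pyspark.sql import SparkSession
import gresearch.spark.diff  # safe to import on PySpark < 3.4 after this change

spark = SparkSession.builder.getOrCreate()
left = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
right = spark.createDataFrame([(1, "a"), (2, "c")], ["id", "value"])

# diff is monkey-patched onto the classic DataFrame above; Connect DataFrames
# only receive it when has_connect is True.
left.diff(right, "id").show()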
18 changes: 12 additions & 6 deletions python/gresearch/spark/parquet/__init__.py
@@ -16,10 +16,15 @@
 
 from py4j.java_gateway import JavaObject
 from pyspark.sql import DataFrameReader, DataFrame
-from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
 
 from gresearch.spark import _get_jvm, _to_seq
 
+try:
+    from pyspark.sql.connect.readwriter import DataFrameReader as ConnectDataFrameReader
+    has_connect = True
+except ModuleNotFoundError:
+    has_connect = False
+
 
 def _jreader(reader: DataFrameReader) -> JavaObject:
     jvm = _get_jvm(reader)
@@ -206,8 +211,9 @@ def parquet_partitions(self: DataFrameReader, *paths: str, parallelism: Optional
 DataFrameReader.parquet_block_columns = parquet_block_columns
 DataFrameReader.parquet_partitions = parquet_partitions
 
-ConnectDataFrameReader.parquet_metadata = parquet_metadata
-ConnectDataFrameReader.parquet_schema = parquet_schema
-ConnectDataFrameReader.parquet_blocks = parquet_blocks
-ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
-ConnectDataFrameReader.parquet_partitions = parquet_partitions
+if has_connect:
+    ConnectDataFrameReader.parquet_metadata = parquet_metadata
+    ConnectDataFrameReader.parquet_schema = parquet_schema
+    ConnectDataFrameReader.parquet_blocks = parquet_blocks
+    ConnectDataFrameReader.parquet_block_columns = parquet_block_columns
+    ConnectDataFrameReader.parquet_partitions = parquet_partitions
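The parquet reader helpers follow the same pattern: the classic DataFrameReader is always patched, the Connect reader only when available. A small sketch, assuming a classic session (the parquet path is a placeholder):

from pyspark.sql import SparkSession
import gresearch.spark.parquet  # safe to import on PySpark < 3.4 after this change

spark = SparkSession.builder.getOrCreate()
# parquet_metadata is one of the readers patched above; the path is a placeholder.
spark.read.parquet_metadata("/tmp/data.parquet").show()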
