diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py
index 6ff299c5..45fa6d6d 100644
--- a/python/gresearch/spark/__init__.py
+++ b/python/gresearch/spark/__init__.py
@@ -45,26 +45,56 @@
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
+_java_pkg_is_installed: Optional[bool] = None
+
+
+def _check_java_pkg_is_installed(jvm: JVMView) -> bool:
+    """Check that the Java / Scala package is installed."""
+    try:
+        jvm.uk.co.gresearch.spark.__getattr__("package$").__getattr__("MODULE$").VersionString()
+        return True
+    except TypeError as e:
+        print(e.args)
+        return False
+    except:
+        # any other exception indicates some problem; be safe and do not fail fast here
+        return True
+
 
 def _get_jvm(obj: Any) -> JVMView:
+    """
+    Provides easy access to the JVMView provided by Spark, and raises a meaningful error message if that is not available.
+    Also checks that the Java / Scala package is accessible via this JVMView.
+    """
     if obj is None:
         if SparkContext._active_spark_context is None:
             raise RuntimeError("This method must be called inside an active Spark session")
         else:
             raise ValueError("Cannot provide access to JVM from None")
 
-    # helper method to assert the JVM is accessible and provide a useful error message
     if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)):
-        raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
+        raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. '
+                           'https://github.com/G-Research/spark-extension#spark-connect-server')
+
     if isinstance(obj, DataFrame):
-        return _get_jvm(obj._sc)
-    if isinstance(obj, DataFrameReader):
-        return _get_jvm(obj._spark)
-    if isinstance(obj, SparkSession):
-        return _get_jvm(obj.sparkContext)
-    if isinstance(obj, (SparkContext, SQLContext)):
-        return obj._jvm
-    raise RuntimeError(f'Unsupported class: {type(obj)}')
+        jvm = _get_jvm(obj._sc)
+    elif isinstance(obj, DataFrameReader):
+        jvm = _get_jvm(obj._spark)
+    elif isinstance(obj, SparkSession):
+        jvm = _get_jvm(obj.sparkContext)
+    elif isinstance(obj, (SparkContext, SQLContext)):
+        jvm = obj._jvm
+    else:
+        raise RuntimeError(f'Unsupported class: {type(obj)}')
+
+    global _java_pkg_is_installed
+    if _java_pkg_is_installed is None:
+        _java_pkg_is_installed = _check_java_pkg_is_installed(jvm)
+    if not _java_pkg_is_installed:
+        raise RuntimeError("Java / Scala package not found! You need to add the Maven spark-extension package "
+                           "to your PySpark environment: https://github.com/G-Research/spark-extension#python")
+
+    return jvm
 
 
 def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
diff --git a/python/test/test_jvm.py b/python/test/test_jvm.py
index 45ddc968..5ea02957 100644
--- a/python/test/test_jvm.py
+++ b/python/test/test_jvm.py
@@ -16,8 +16,10 @@
 
 from pyspark.sql.functions import sum
 
-from gresearch.spark import _get_jvm, dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
-    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, histogram, job_description, append_description
+from gresearch.spark import _get_jvm, \
+    dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
+    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, \
+    histogram, job_description, append_description
 from gresearch.spark.diff import *
 from gresearch.spark.parquet import *
 from spark_common import SparkTest
@@ -58,6 +60,21 @@ def test_get_jvm_connect(self):
             _get_jvm(object())
         self.assertEqual(("Unsupported class: <class 'object'>", ), e.exception.args)
 
+    @skipIf(SparkTest.is_spark_connect, "Spark classic client tests")
+    def test_get_jvm_check_java_pkg_is_installed(self):
+        from gresearch import spark
+
+        is_installed = spark._java_pkg_is_installed
+
+        try:
+            spark._java_pkg_is_installed = False
+            with self.assertRaises(RuntimeError) as e:
+                _get_jvm(self.spark)
+            self.assertEqual(("Java / Scala package not found! You need to add the Maven spark-extension package "
+                              "to your PySpark environment: https://github.com/G-Research/spark-extension#python", ), e.exception.args)
+        finally:
+            spark._java_pkg_is_installed = is_installed
+
     @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
     def test_diff(self):
         for label, func in {
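
Reviewer note: the probe works because py4j represents an unknown Java class
path as a JavaPackage object, and calling a JavaPackage raises TypeError; the
TypeError branch in _check_java_pkg_is_installed appears to rely on exactly
that. Below is a minimal, self-contained sketch of the once-per-process
caching pattern this patch introduces. All names here (_pkg_is_installed,
_check_pkg, get_jvm_handle, _missing_pkg_probe) are hypothetical stand-ins,
not part of the library:

    from typing import Callable, Optional

    # Module-level cache, mirroring _java_pkg_is_installed in the patch:
    # None = not probed yet, True/False = cached verdict.
    _pkg_is_installed: Optional[bool] = None

    def _check_pkg(probe: Callable[[], object]) -> bool:
        """Return False only when the probe raises TypeError."""
        try:
            probe()
            return True
        except TypeError:
            # like calling a py4j JavaPackage when the jar is missing
            return False
        except BaseException:
            # any other error is inconclusive, so do not fail fast
            return True

    def get_jvm_handle(probe: Callable[[], object]) -> str:
        """Probe at most once per process, then reuse the cached verdict."""
        global _pkg_is_installed
        if _pkg_is_installed is None:
            _pkg_is_installed = _check_pkg(probe)
        if not _pkg_is_installed:
            raise RuntimeError("Java / Scala package not found!")
        return "jvm"

    def _missing_pkg_probe() -> object:
        # simulates py4j behaviour when the package is absent
        raise TypeError("'JavaPackage' object is not callable")

    try:
        get_jvm_handle(_missing_pkg_probe)
    except RuntimeError as e:
        print(e)  # Java / Scala package not found!

Because the verdict lives in a module-level global, any test that forces a
value must restore it afterwards, which is why the new
test_get_jvm_check_java_pkg_is_installed saves spark._java_pkg_is_installed
and resets it in a finally block.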
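
To reproduce the new error end to end, run a classic (non-Connect) PySpark
session without the spark-extension jar on the classpath; a hedged example,
assuming the gresearch.spark Python package is pip-installed but the Maven
jar is not:

    from pyspark.sql import SparkSession
    from gresearch.spark import _get_jvm

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    # With the jar missing, the first _get_jvm call probes the JVM,
    # caches False, and raises:
    #   RuntimeError: Java / Scala package not found! You need to add the
    #   Maven spark-extension package to your PySpark environment: ...
    _get_jvm(spark)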