diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py index 21f0d3a7..75eeb996 100644 --- a/python/gresearch/spark/__init__.py +++ b/python/gresearch/spark/__init__.py @@ -47,6 +47,12 @@ def _get_jvm(obj: Any) -> JVMView: + if obj is None: + if SparkContext._active_spark_context is None: + raise RuntimeError("This method must be called inside an active Spark session") + else: + raise ValueError("Cannot provide access to JVM from None") + # helper method to assert the JVM is accessible and provide a useful error message if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)): raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server') diff --git a/python/test/test_jvm.py b/python/test/test_jvm.py index e58778ad..45ddc968 100644 --- a/python/test/test_jvm.py +++ b/python/test/test_jvm.py @@ -84,7 +84,7 @@ def test_dotnet_ticks(self): with self.subTest(label): with self.assertRaises(RuntimeError) as e: func("id") - self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") def test_histogram(self): @@ -103,12 +103,12 @@ def test_job_description(self): with self.assertRaises(RuntimeError) as e: with job_description("job description"): pass - self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) with self.assertRaises(RuntimeError) as e: with append_description("job description"): pass - self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args) + self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args) @skipUnless(SparkTest.is_spark_connect, "Spark connect client tests") def test_create_temp_dir(self):