Skip to content

Commit

Permalink
Fix dotnet and job description methods / tests
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Aug 15, 2024
1 parent 082cc7b commit 4cbd9b0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
6 changes: 6 additions & 0 deletions python/gresearch/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@


def _get_jvm(obj: Any) -> JVMView:
if obj is None:
if SparkContext._active_spark_context is None:
raise RuntimeError("This method must be called inside an active Spark session")
else:
raise ValueError("Cannot provide access to JVM from None")

# helper method to assert the JVM is accessible and provide a useful error message
if has_connect and isinstance(obj, (ConnectDataFrame, ConnectDataFrameReader, ConnectSparkSession)):
raise RuntimeError('This feature is not supported for Spark Connect. Please use a classic Spark client. https://github.com/G-Research/spark-extension#spark-connect-server')
Expand Down
6 changes: 3 additions & 3 deletions python/test/test_jvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_dotnet_ticks(self):
with self.subTest(label):
with self.assertRaises(RuntimeError) as e:
func("id")
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)

@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_histogram(self):
Expand All @@ -103,12 +103,12 @@ def test_job_description(self):
with self.assertRaises(RuntimeError) as e:
with job_description("job description"):
pass
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)

with self.assertRaises(RuntimeError) as e:
with append_description("job description"):
pass
self.assertEqual((EXPECTED_UNSUPPORTED_MESSAGE, ), e.exception.args)
self.assertEqual(("This method must be called inside an active Spark session", ), e.exception.args)

@skipUnless(SparkTest.is_spark_connect, "Spark connect client tests")
def test_create_temp_dir(self):
Expand Down

0 comments on commit 4cbd9b0

Please sign in to comment.