
Commit 962ee00

fix(pyspark): unwind catalog/database settings in same order they were set (#9067)
1 parent fd35b66 commit 962ee00

2 files changed: +46 −9 lines changed


ibis/backends/pyspark/__init__.py

Lines changed: 27 additions & 9 deletions
@@ -229,16 +229,34 @@ def current_catalog(self) -> str:
         return catalog
 
     @contextlib.contextmanager
-    def _active_database(self, name: str | None):
-        if name is None:
+    def _active_catalog_database(self, catalog: str | None, db: str | None):
+        if catalog is None and db is None:
             yield
             return
-        current = self.current_database
+        if catalog is not None and PYSPARK_LT_34:
+            raise com.UnsupportedArgumentError(
+                "Catalogs are not supported in pyspark < 3.4"
+            )
+        current_catalog = self.current_catalog
+        current_db = self.current_database
+
+        # This little horrible bit of work is to avoid trying to set
+        # the `CurrentDatabase` inside of a catalog where we don't have permission
+        # to do so. We can't have the catalog and database context managers work
+        # separately because we need to:
+        # 1. set catalog
+        # 2. set database
+        # 3. set catalog to previous
+        # 4. set database to previous
         try:
-            self._session.catalog.setCurrentDatabase(name)
+            if catalog is not None:
+                self._session.catalog.setCurrentCatalog(catalog)
+            self._session.catalog.setCurrentDatabase(db)
             yield
         finally:
-            self._session.catalog.setCurrentDatabase(current)
+            if catalog is not None:
+                self._session.catalog.setCurrentCatalog(current_catalog)
+            self._session.catalog.setCurrentDatabase(current_db)
 
     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
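
Why the two context managers had to be merged: nested `with` blocks tear down in reverse order, so the old `_active_catalog(...)` / `_active_database(...)` pair restored the database while still inside the new catalog. A minimal standalone sketch of that teardown order (plain Python, not ibis code; the print calls stand in for the real Spark session calls):

# Minimal sketch, not ibis code: shows why nested `with` blocks cannot restore
# catalog and database in the order the fix needs. Names here are made up.
import contextlib


@contextlib.contextmanager
def active_catalog(name):
    print(f"1. set catalog -> {name}")
    try:
        yield
    finally:
        print("restore catalog")


@contextlib.contextmanager
def active_database(name):
    print(f"2. set database -> {name}")
    try:
        yield
    finally:
        print("restore database")


# Old shape: `with self._active_catalog(catalog), self._active_database(db):`
# Nested context managers unwind in reverse order, so the database is restored
# while still inside the new catalog -- exactly the case the commit avoids.
with active_catalog("other_catalog"), active_database("some_db"):
    pass
# prints:
#   1. set catalog -> other_catalog
#   2. set database -> some_db
#   restore database    <- still inside "other_catalog"
#   restore catalog
#
# The merged _active_catalog_database above instead follows steps 1-4 from its
# comment: set catalog, set database, restore catalog, restore database.
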
@@ -438,7 +456,7 @@ def get_schema(
 
         table_loc = self._to_sqlglot_table((catalog, database))
         catalog, db = self._to_catalog_db_tuple(table_loc)
-        with self._active_catalog(catalog), self._active_database(db):
+        with self._active_catalog_database(catalog, db):
             df = self._session.table(table_name)
             struct = PySparkType.to_ibis(df.schema)
 
@@ -500,18 +518,18 @@ def create_table(
             table = obj if isinstance(obj, ir.Expr) else ibis.memtable(obj)
             query = self.compile(table)
             mode = "overwrite" if overwrite else "error"
-            with self._active_catalog(catalog), self._active_database(db):
+            with self._active_catalog_database(catalog, db):
                 self._run_pre_execute_hooks(table)
                 df = self._session.sql(query)
                 df.write.saveAsTable(name, format=format, mode=mode)
         elif schema is not None:
             schema = PySparkSchema.from_ibis(schema)
-            with self._active_catalog(catalog), self._active_database(db):
+            with self._active_catalog_database(catalog, db):
                 self._session.catalog.createTable(name, schema=schema, format=format)
         else:
             raise com.IbisError("The schema or obj parameter is required")
 
-        return self.table(name, database=db)
+        return self.table(name, database=(catalog, db))
 
     def create_view(
         self,
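
The changed `return self.table(name, database=(catalog, db))` keeps the created table addressed by both parts rather than by the database alone. A caller-side sketch of the shape this supports, assuming an existing PySpark connection `con`; the catalog and table names mirror the tests below and are illustrative:

# Caller-side sketch (assumes a connected PySpark backend `con`; names are
# illustrative, taken from the test file changed in this commit).
import ibis

t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})

# `database` may be given as a (catalog, database) tuple; create_table now
# hands both parts back to `self.table(...)` when returning the expression.
con.create_table("t2", obj=t, database=("spark_catalog", "default"), overwrite=True)
assert "t2" in con.list_tables(database="default")

con.drop_table("t2", database="spark_catalog.default")
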

ibis/backends/pyspark/tests/test_client.py

Lines changed: 19 additions & 0 deletions
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import pytest
+
 import ibis
 
 
+@pytest.mark.xfail_version(pyspark=["pyspark<3.4"], reason="no catalog support")
 def test_catalog_db_args(con, monkeypatch):
     monkeypatch.setattr(ibis.options, "default_backend", con)
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
@@ -20,3 +23,19 @@ def test_catalog_db_args(con, monkeypatch):
     con.drop_table("t2", database="spark_catalog.default")
 
     assert "t2" not in con.list_tables(database="default")
+
+
+def test_create_table_no_catalog(con, monkeypatch):
+    monkeypatch.setattr(ibis.options, "default_backend", con)
+    t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
+
+    # create a table in specified catalog and db
+    con.create_table("t2", database=("default"), obj=t, overwrite=True)
+
+    assert "t2" not in con.list_tables()
+    assert "t2" in con.list_tables(database="default")
+    assert "t2" in con.list_tables(database=("default"))
+
+    con.drop_table("t2", database="default")
+
+    assert "t2" not in con.list_tables(database="default")
