@@ -229,16 +229,34 @@ def current_catalog(self) -> str:
         return catalog

     @contextlib.contextmanager
-    def _active_database(self, name: str | None):
-        if name is None:
+    def _active_catalog_database(self, catalog: str | None, db: str | None):
+        if catalog is None and db is None:
             yield
             return
-        current = self.current_database
+        if catalog is not None and PYSPARK_LT_34:
+            raise com.UnsupportedArgumentError(
+                "Catalogs are not supported in pyspark < 3.4"
+            )
+        current_catalog = self.current_catalog
+        current_db = self.current_database
+
+        # This little horrible bit of work is to avoid trying to set
+        # the `CurrentDatabase` inside of a catalog where we don't have permission
+        # to do so. We can't have the catalog and database context managers work
+        # separately because we need to:
+        # 1. set catalog
+        # 2. set database
+        # 3. set catalog to previous
+        # 4. set database to previous
         try:
-            self._session.catalog.setCurrentDatabase(name)
+            if catalog is not None:
+                self._session.catalog.setCurrentCatalog(catalog)
+            self._session.catalog.setCurrentDatabase(db)
             yield
         finally:
-            self._session.catalog.setCurrentDatabase(current)
+            if catalog is not None:
+                self._session.catalog.setCurrentCatalog(current_catalog)
+            self._session.catalog.setCurrentDatabase(current_db)

     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
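The numbered comment above is about `with` exit ordering: stacking two independent context managers unwinds them LIFO, so the database would be restored while the new catalog is still current. A minimal standalone sketch (not ibis code) that demonstrates the ordering:

import contextlib

@contextlib.contextmanager
def set_and_restore(label: str, log: list):
    # Stand-in for "set X on entry, restore X on exit".
    log.append(f"set {label}")
    try:
        yield
    finally:
        log.append(f"restore {label}")

log = []
with set_and_restore("catalog", log), set_and_restore("database", log):
    pass
# LIFO unwinding restores the database *before* the catalog, i.e. while the
# new catalog is still active -- exactly the order the diff's comment rules out.
print(log)  # ['set catalog', 'set database', 'restore database', 'restore catalog']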
@@ -438,7 +456,7 @@ def get_schema(

         table_loc = self._to_sqlglot_table((catalog, database))
         catalog, db = self._to_catalog_db_tuple(table_loc)
-        with self._active_catalog(catalog), self._active_database(db):
+        with self._active_catalog_database(catalog, db):
             df = self._session.table(table_name)
             struct = PySparkType.to_ibis(df.schema)

@@ -500,18 +518,18 @@ def create_table(
             table = obj if isinstance(obj, ir.Expr) else ibis.memtable(obj)
             query = self.compile(table)
             mode = "overwrite" if overwrite else "error"
-            with self._active_catalog(catalog), self._active_database(db):
+            with self._active_catalog_database(catalog, db):
                 self._run_pre_execute_hooks(table)
                 df = self._session.sql(query)
                 df.write.saveAsTable(name, format=format, mode=mode)
         elif schema is not None:
             schema = PySparkSchema.from_ibis(schema)
-            with self._active_catalog(catalog), self._active_database(db):
+            with self._active_catalog_database(catalog, db):
                 self._session.catalog.createTable(name, schema=schema, format=format)
         else:
             raise com.IbisError("The schema or obj parameter is required")

-        return self.table(name, database=db)
+        return self.table(name, database=(catalog, db))

     def create_view(
         self,
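After this change, a `(catalog, database)` pair flows through `create_table` end to end. A hedged usage sketch (the session setup and catalog/database names are illustrative, and passing a catalog requires pyspark >= 3.4 per the guard above):

import ibis

# Assumes an existing SparkSession named `spark`.
con = ibis.pyspark.connect(session=spark)

# `database` as a (catalog, database) tuple: the backend sets both around the
# write and restores them catalog-first, as implemented above. The returned
# table expression is likewise catalog-qualified.
t = con.create_table(
    "example_table",
    obj=ibis.memtable({"a": [1, 2, 3]}),
    database=("spark_catalog", "default"),
)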