From c7ea505f7143b8bb2238f3fba4f820f4939a7708 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 1 Feb 2025 07:31:14 -0500 Subject: [PATCH] fix(postgres): do not use schema when renaming a table for overwrite purposes (#10771) --- ibis/backends/postgres/__init__.py | 24 ++++++++++----------- ibis/backends/postgres/tests/test_client.py | 14 ++++++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 483bf99c359f..e8f779c2cf1b 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -622,7 +622,6 @@ def create_table( overwrite If `True`, replace the table if it already exists, otherwise fail if the table exists - """ if obj is None and schema is None: raise ValueError("Either `obj` or `schema` must be specified") @@ -654,10 +653,11 @@ def create_table( if not schema: schema = table.schema() - table_expr = sg.table(temp_name, db=database, quoted=self.compiler.quoted) - target = sge.Schema( - this=table_expr, expressions=schema.to_sqlglot(self.dialect) - ) + quoted = self.compiler.quoted + dialect = self.dialect + + table_expr = sg.table(temp_name, db=database, quoted=quoted) + target = sge.Schema(this=table_expr, expressions=schema.to_sqlglot(dialect)) create_stmt = sge.Create( kind="TABLE", @@ -665,20 +665,18 @@ def create_table( properties=sge.Properties(expressions=properties), ) - this = sg.table(name, catalog=database, quoted=self.compiler.quoted) + this = sg.table(name, catalog=database, quoted=quoted) + this_no_catalog = sg.table(name, quoted=quoted) + with self._safe_raw_sql(create_stmt) as cur: if query is not None: - insert_stmt = sge.Insert(this=table_expr, expression=query).sql( - self.dialect - ) + insert_stmt = sge.Insert(this=table_expr, expression=query).sql(dialect) cur.execute(insert_stmt) if overwrite: + cur.execute(sge.Drop(kind="TABLE", this=this, exists=True).sql(dialect)) cur.execute( - sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect) - ) - cur.execute( - f"ALTER TABLE IF EXISTS {table_expr.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}" + f"ALTER TABLE IF EXISTS {table_expr.sql(dialect)} RENAME TO {this_no_catalog.sql(dialect)}" ) if schema is None: diff --git a/ibis/backends/postgres/tests/test_client.py b/ibis/backends/postgres/tests/test_client.py index 2f2cd9c6d8e2..1fb4a4f6a3a3 100644 --- a/ibis/backends/postgres/tests/test_client.py +++ b/ibis/backends/postgres/tests/test_client.py @@ -438,3 +438,17 @@ def test_parsing_oid_dtype(con): # Load a table that uses the OID type and check that we map it to Int64 t = con.table("pg_class", database="pg_catalog") assert t.oid.type() == ibis.dtype("int64") + + +@pytest.fixture +def tmp_db(con): + name = gen_name("tmp_db") + con.create_database(name) + yield name + con.drop_database(name, cascade=True) + + +def test_create_table_overwrite(con, tmp_db): + name = gen_name("overwrite_test") + t = con.create_table(name, schema={"id": "int32"}, database=tmp_db, overwrite=True) + assert t.schema() == ibis.schema({"id": dt.int32})