Skip to content

Commit

Permalink
Updated to use SQLAlchemy 2.x
Browse files: browse the repository at this point in the history
  • Loading branch information
gherka committed Aug 27, 2024
1 parent 2ca4c62 commit 5f1b147
Showing 1 changed file with 53 additions and 41 deletions.
94 changes: 53 additions & 41 deletions exhibit/core/sql.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,17 +64,16 @@ def query_exhibit_database(
# define fully qualified table name, including schema if provided
table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name])

# create engine and connection
engine = create_engine(db_url)
conn = engine.connect()

# column can come in as a string or as an empty list or as ["string"]
if column and isinstance(column, list):
column = column[0]

# create engine
engine = create_engine(db_url)

# create metadata object of the DB, reflecting the required table
metadata = MetaData(bind=engine, schema=db_schema)
metadata.reflect(only=[table_name])
metadata = MetaData(schema=db_schema)
metadata.reflect(bind=engine, only=[table_name])

# get the table class from metadata
table = metadata.tables[table_full_name]
Expand All @@ -92,7 +91,8 @@ def query_exhibit_database(
stmt = stmt.order_by(text(order))

# get result object
result = conn.execute(stmt)
with engine.connect() as conn:
result = conn.execute(stmt)

# build a Pandas dataframe
column_names = [col[0] for col in result.cursor.description]
Expand Down Expand Up @@ -150,15 +150,8 @@ def create_temp_table(table_name, col_names, data, return_table=False, db_path=N
db_schema = os.environ.get("EXHIBIT_DB_SCHEMA", None)
db_path = db_path if db_path else package_dir(EXHIBIT_DB_LOCAL)

if db_url is None:
db_url = make_url("sqlite:///" + db_path + "?mode=rw")

# create engine and connection
engine = create_engine(db_url)
conn = engine.connect()

# tables are created from a metadata object
metadata = MetaData(bind=engine, schema=db_schema)
metadata = MetaData(schema=db_schema)

# to help with managing the data, convert tuples to a dataframe
data_df = pd.DataFrame(data) if not isinstance(data, pd.DataFrame) else data
Expand All @@ -185,26 +178,38 @@ def create_temp_table(table_name, col_names, data, return_table=False, db_path=N
# ensure the correct column data type for table creation
table = Table(
table_name, metadata, *[Column(c, t) for c, t in zip(col_names, data_types)])

# connection block
if db_url is None:
db_url = make_url("sqlite:///" + db_path + "?mode=rw")

# drop the table from DB if it exists and then create
conn.execute(DropTable(table, if_exists=True))
metadata.create_all(engine)

# insert the values (assuming tuples in the data follow the col_names order)
# sqlite has a limit on how many records can be inserted into a table at one time
# see: https://www.sqlite.org/limits.html #9
chunk = 32_000
if (engine.dialect.name == "sqlite") and (num_records:=len(data)) > chunk: #pragma: no cover
for i, _ in enumerate(range(0, num_records, chunk)):
from_i = i * chunk
to_i = (i+1) * chunk
conn.execute(table.insert().values(data[from_i:to_i]))
else:
conn.execute(table.insert().values(data))
# create engine and connection
engine = create_engine(db_url)

# save the table in case it's required
if return_table:
result = conn.execute(select(table)).fetchall()
with engine.connect() as conn:

with conn.begin():

# drop the table from DB if it exists and then create
conn.execute(DropTable(table, if_exists=True))
metadata.create_all(engine)

# insert the values (assuming tuples in the data follow the col_names order)
# sqlite has a limit on how many records can be inserted into a table at one time
# see: https://www.sqlite.org/limits.html #9
chunk = 32_000
#pragma: no cover
if (engine.dialect.name == "sqlite") and (num_records:=len(data)) > chunk:
for i, _ in enumerate(range(0, num_records, chunk)):
from_i = i * chunk
to_i = (i+1) * chunk
conn.execute(table.insert().values(data[from_i:to_i]))
else:
conn.execute(table.insert().values(data))

# save the table in case it's required
if return_table:
result = conn.execute(select(table)).fetchall()

# shut down the engine which closes all associated connections
engine.dispose()
Expand Down Expand Up @@ -247,17 +252,17 @@ def get_number_of_table_rows(table_name, column=None, db_path=None):
table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name])

# create metadata object of the DB, reflecting the required table
metadata = MetaData(bind=engine, schema=db_schema)
metadata.reflect(only=[table_name])
metadata = MetaData(schema=db_schema)
metadata.reflect(bind=engine, only=[table_name])

# get the table class from metadata
table = metadata.tables[table_full_name]

# either count distinct values of a given column, or just the size of the table
if column:
stmt = select(func.count(table.c[column].distinct()))
stmt = select(func.count(table.c[column].distinct())).select_from(table) # pylint: disable=E1102
else:
stmt = select([func.count()]).select_from(table)
stmt = select(func.count()).select_from(table) # pylint: disable=E1102

# get the count
result = conn.execute(stmt).fetchall()[0][0]
Expand Down Expand Up @@ -293,8 +298,8 @@ def get_number_of_table_columns(table_name, db_path=None):
table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name])

# create metadata object of the DB, reflecting the required table
metadata = MetaData(bind=engine, schema=db_schema)
metadata.reflect(only=[table_name])
metadata = MetaData(schema=db_schema)
metadata.reflect(bind=engine, only=[table_name])

# get the table class from metadata
table = metadata.tables[table_full_name]
Expand Down Expand Up @@ -329,7 +334,14 @@ def check_table_exists(table_name, db_path=None):

def execute_sql(sql, db_path=None):
'''
Doc string
Connect to a database and execute SQL passed in as text.
Parameters
----------
sql : str
SQL query
db_path : str
Optional. Mainly used for testing when creating temporary database.
'''

db_url = os.environ.get("EXHIBIT_DB_URL", None)
Expand All @@ -342,6 +354,6 @@ def execute_sql(sql, db_path=None):
engine = create_engine(db_url)

with engine.connect() as conn:
result = conn.execute(sql).fetchall()
result = conn.execute(text(sql)).fetchall()

return result

0 comments on commit 5f1b147

Please sign in to comment.