diff --git a/exhibit/core/sql.py b/exhibit/core/sql.py index 5a8d657..6ae8556 100644 --- a/exhibit/core/sql.py +++ b/exhibit/core/sql.py @@ -64,17 +64,16 @@ def query_exhibit_database( # define fully qualified table name, including schema if provided table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name]) - # create engine and connection - engine = create_engine(db_url) - conn = engine.connect() - # column can come in as a string or as an empty list or as ["string"] if column and isinstance(column, list): column = column[0] + # create engine + engine = create_engine(db_url) + # create metadata object of the DB, reflecting the required table - metadata = MetaData(bind=engine, schema=db_schema) - metadata.reflect(only=[table_name]) + metadata = MetaData(schema=db_schema) + metadata.reflect(bind=engine, only=[table_name]) # get the table class from metadata table = metadata.tables[table_full_name] @@ -92,7 +91,8 @@ def query_exhibit_database( stmt = stmt.order_by(text(order)) # get result object - result = conn.execute(stmt) + with engine.connect() as conn: + result = conn.execute(stmt) # build a Pandas dataframe column_names = [col[0] for col in result.cursor.description] @@ -150,15 +150,8 @@ def create_temp_table(table_name, col_names, data, return_table=False, db_path=N db_schema = os.environ.get("EXHIBIT_DB_SCHEMA", None) db_path = db_path if db_path else package_dir(EXHIBIT_DB_LOCAL) - if db_url is None: - db_url = make_url("sqlite:///" + db_path + "?mode=rw") - - # create engine and connection - engine = create_engine(db_url) - conn = engine.connect() - # tables are created from a metadata object - metadata = MetaData(bind=engine, schema=db_schema) + metadata = MetaData(schema=db_schema) # to help with managing the data, convert tuples to a dataframe data_df = pd.DataFrame(data) if not isinstance(data, pd.DataFrame) else data @@ -185,26 +178,38 @@ def create_temp_table(table_name, col_names, data, return_table=False, db_path=N # ensure the correct column data type for table creation table = Table( table_name, metadata, *[Column(c, t) for c, t in zip(col_names, data_types)]) + + # connection block + if db_url is None: + db_url = make_url("sqlite:///" + db_path + "?mode=rw") - # drop the table from DB if it exists and then create - conn.execute(DropTable(table, if_exists=True)) - metadata.create_all(engine) - - # insert the values (assuming tuples in the data follow the col_names order) - # sqlite has a limit on how many records can be inserted into a table at one time - # see: https://www.sqlite.org/limits.html #9 - chunk = 32_000 - if (engine.dialect.name == "sqlite") and (num_records:=len(data)) > chunk: #pragma: no cover - for i, _ in enumerate(range(0, num_records, chunk)): - from_i = i * chunk - to_i = (i+1) * chunk - conn.execute(table.insert().values(data[from_i:to_i])) - else: - conn.execute(table.insert().values(data)) + # create engine and connection + engine = create_engine(db_url) - # save the table in case it's required - if return_table: - result = conn.execute(select(table)).fetchall() + with engine.connect() as conn: + + with conn.begin(): + + # drop the table from DB if it exists and then create + conn.execute(DropTable(table, if_exists=True)) + metadata.create_all(engine) + + # insert the values (assuming tuples in the data follow the col_names order) + # sqlite has a limit on how many records can be inserted into a table at one time + # see: https://www.sqlite.org/limits.html #9 + chunk = 32_000 + #pragma: no cover + if (engine.dialect.name == "sqlite") and (num_records:=len(data)) > chunk: + for i, _ in enumerate(range(0, num_records, chunk)): + from_i = i * chunk + to_i = (i+1) * chunk + conn.execute(table.insert().values(data[from_i:to_i])) + else: + conn.execute(table.insert().values(data)) + + # save the table in case it's required + if return_table: + result = conn.execute(select(table)).fetchall() # shut down the engine which closes all associated connections engine.dispose() @@ -247,17 +252,17 @@ def get_number_of_table_rows(table_name, column=None, db_path=None): table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name]) # create metadata object of the DB, reflecting the required table - metadata = MetaData(bind=engine, schema=db_schema) - metadata.reflect(only=[table_name]) + metadata = MetaData(schema=db_schema) + metadata.reflect(bind=engine, only=[table_name]) # get the table class from metadata table = metadata.tables[table_full_name] # either count distinct values of a given column, or just the size of the table if column: - stmt = select(func.count(table.c[column].distinct())) + stmt = select(func.count(table.c[column].distinct())).select_from(table) # pylint: disable=E1102 else: - stmt = select([func.count()]).select_from(table) + stmt = select(func.count()).select_from(table) # pylint: disable=E1102 # get the count result = conn.execute(stmt).fetchall()[0][0] @@ -293,8 +298,8 @@ def get_number_of_table_columns(table_name, db_path=None): table_full_name = table_name if db_schema is None else ".".join([db_schema, table_name]) # create metadata object of the DB, reflecting the required table - metadata = MetaData(bind=engine, schema=db_schema) - metadata.reflect(only=[table_name]) + metadata = MetaData(schema=db_schema) + metadata.reflect(bind=engine, only=[table_name]) # get the table class from metadata table = metadata.tables[table_full_name] @@ -329,7 +334,14 @@ def check_table_exists(table_name, db_path=None): def execute_sql(sql, db_path=None): ''' - Doc string + Connect to a database and execute SQL passed in as text. + + Parameters + ---------- + sql : str + SQL query + db_path : str + Optional. Mainly used for testing when creating temporary database. ''' db_url = os.environ.get("EXHIBIT_DB_URL", None) @@ -342,6 +354,6 @@ def execute_sql(sql, db_path=None): engine = create_engine(db_url) with engine.connect() as conn: - result = conn.execute(sql).fetchall() + result = conn.execute(text(sql)).fetchall() return result