Skip to content

Commit

Permalink
Add tqdm to sql db
Browse files Browse the repository at this point in the history
  • Loading branch information
moobeck committed Feb 12, 2025
1 parent 3a698b4 commit d127ec9
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 9 deletions.
106 changes: 106 additions & 0 deletions inventory_foundation_sdk/db_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,109 @@ def insert_multi_rows(
conn.commit() # Commit all changes after processing

return None


# %% ../nbs/10_db_mgmt.ipynb 6
from tqdm import tqdm
import psycopg2
from psycopg2.extras import execute_values

class SQLDatabase:
def __init__(self, autocommit=True):
self._credentials = get_db_credentials()["con"]
self.connection = None
self.autocommit = autocommit

def connect(self):
if not self.connection:
self.connection = psycopg2.connect(self._credentials)
self.connection.autocommit = self.autocommit

def close(self):
if self.connection:
self.connection.close()
self.connection = None

def __enter__(self):
self.connect()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self.connection.rollback()
elif not self.autocommit:
self.connection.commit()
self.close()

def execute_query(self, query: str, params: tuple = None, fetchall: bool = False, fetchone: bool = False, commit: bool = False):
if fetchall and fetchone:
raise ValueError("Both fetchall and fetchone cannot be True")
if not self.connection:
self.connect()
with self.connection.cursor() as cur:
cur.execute(query, params)
result = cur.fetchall() if fetchall else cur.fetchone() if fetchone else None
if commit and self.autocommit:
self.connection.commit()
return result

def execute_multiple_queries(self, queries: list | str, params: list = None, fetchrows: bool = False, commit: bool = False):
if not self.connection:
self.connect()
results = []
with self.connection.cursor() as cur:
if fetchrows:
if isinstance(queries, str):
queries = [queries] * len(params)
for query, par in tqdm(zip(queries, params), total=len(params), desc="Executing queries"):
cur.execute(query, par)
results.append(cur.fetchone())
else:
if not isinstance(queries, str):
raise ValueError("For batch execution use a single query with multiple params (set fetchrows=True otherwise)")
cur.executemany(queries, tqdm(params, desc="Executing batch queries"))
if commit and self.autocommit:
self.connection.commit()
return results if fetchrows else None

def fetch_ids_bulk(self, table_name: str, id_column, column_names: list, rows: list[tuple]) -> list:
"""
Retrieve IDs in one bulk query using the VALUES construct.
'id_column' can be a string or a list/tuple of column names.
"""
if not rows:
return []
columns_str = ", ".join(column_names)
join_clause = " AND ".join([f"t.{col} = v.{col}" for col in column_names])

# Build the SELECT part based on whether id_column is a single column or multiple.
if isinstance(id_column, (list, tuple)):
id_columns_str = ", ".join([f"t.{col}" for col in id_column])
else:
id_columns_str = f"t.{id_column}"

query = f"""
SELECT {id_columns_str}
FROM {table_name} t
JOIN (
VALUES %s
) AS v({columns_str})
ON {join_clause}
"""
all_ids = []
chunk_size = 100
if not self.connection:
self.connect()
with self.connection.cursor() as cur:
for i in tqdm(range(0, len(rows), chunk_size), desc="Fetching IDs", unit="chunk"):
chunk = rows[i:i + chunk_size]
execute_values(cur, query, chunk, page_size=len(chunk))
results = cur.fetchall()
if isinstance(id_column, (list, tuple)):
# Each row is a tuple of id values; convert each value to int.
for row in results:
all_ids.append(tuple(int(x) for x in row))
else:
all_ids.extend(int(row[0]) for row in results)
return all_ids

33 changes: 24 additions & 9 deletions nbs/10_db_mgmt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[02/11/25 13:46:50] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Using <a href=\"file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">__init__.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py#270\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">270</span></a>\n",
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[02/12/25 08:10:44] </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO </span> Using <a href=\"file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">__init__.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py#270\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">270</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">'/Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-p</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #008000; text-decoration-color: #008000\">ackages/kedro/framework/project/rich_logging.yml'</span> as logging <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> configuration. <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[2;36m[02/11/25 13:46:50]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Using \u001b]8;id=850623;file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py\u001b\\\u001b[2m__init__.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=932813;file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py#270\u001b\\\u001b[2m270\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m[02/12/25 08:10:44]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Using \u001b]8;id=164157;file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py\u001b\\\u001b[2m__init__.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=546820;file:///Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-packages/kedro/framework/project/__init__.py#270\u001b\\\u001b[2m270\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m \u001b[32m'/Users/moritzbeckmail.de/miniconda3/envs/if_sdk/lib/python3.11/site-p\u001b[0m \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[32mackages/kedro/framework/project/rich_logging.yml'\u001b[0m as logging \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m configuration. \u001b[2m \u001b[0m\n"
Expand Down Expand Up @@ -214,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -269,24 +269,35 @@
" if fetchrows:\n",
" if isinstance(queries, str):\n",
" queries = [queries] * len(params)\n",
" for query, par in zip(queries, params):\n",
" for query, par in tqdm(zip(queries, params), total=len(params), desc=\"Executing queries\"):\n",
" cur.execute(query, par)\n",
" results.append(cur.fetchone())\n",
" else:\n",
" if not isinstance(queries, str):\n",
" raise ValueError(\"For batch execution use a single query with multiple params (set fetchrows=True otherwise)\")\n",
" cur.executemany(queries, params)\n",
" cur.executemany(queries, tqdm(params, desc=\"Executing batch queries\"))\n",
" if commit and self.autocommit:\n",
" self.connection.commit()\n",
" return results if fetchrows else None\n",
"\n",
" def fetch_ids_bulk(self, table_name: str, id_column: str, column_names: list, rows: list[tuple]) -> list:\n",
" def fetch_ids_bulk(self, table_name: str, id_column, column_names: list, rows: list[tuple]) -> list:\n",
" \"\"\"\n",
" Retrieve IDs in one bulk query using the VALUES construct.\n",
" 'id_column' can be a string or a list/tuple of column names.\n",
" \"\"\"\n",
" if not rows:\n",
" return []\n",
" columns_str = \", \".join(column_names)\n",
" join_clause = \" AND \".join([f\"t.{col} = v.{col}\" for col in column_names])\n",
" \n",
" # Build the SELECT part based on whether id_column is a single column or multiple.\n",
" if isinstance(id_column, (list, tuple)):\n",
" id_columns_str = \", \".join([f\"t.{col}\" for col in id_column])\n",
" else:\n",
" id_columns_str = f\"t.{id_column}\"\n",
" \n",
" query = f\"\"\"\n",
" SELECT t.{id_column}\n",
" SELECT {id_columns_str}\n",
" FROM {table_name} t\n",
" JOIN (\n",
" VALUES %s\n",
Expand All @@ -298,12 +309,16 @@
" if not self.connection:\n",
" self.connect()\n",
" with self.connection.cursor() as cur:\n",
" # Use tqdm to track progress through chunks.\n",
" for i in tqdm(range(0, len(rows), chunk_size), desc=\"Fetching IDs\", unit=\"chunk\"):\n",
" chunk = rows[i:i + chunk_size]\n",
" execute_values(cur, query, chunk, page_size=len(chunk))\n",
" results = cur.fetchall()\n",
" all_ids.extend(int(row[0]) for row in results)\n",
" if isinstance(id_column, (list, tuple)):\n",
" # Each row is a tuple of id values; convert each value to int.\n",
" for row in results:\n",
" all_ids.append(tuple(int(x) for x in row))\n",
" else:\n",
" all_ids.extend(int(row[0]) for row in results)\n",
" return all_ids\n",
"\n"
]
Expand Down

0 comments on commit d127ec9

Please sign in to comment.