From 5863f3db6696da59bbf65096426e9ea8e0505efa Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Thu, 11 Sep 2025 14:05:13 +0000
Subject: [PATCH 1/3] Refactor out new abstraction creation

---
 src/ldlite/__init__.py  | 23 +++++++++++++++++------
 src/ldlite/_database.py | 29 +++++++++++++----------------
 src/ldlite/_sqlx.py     |  6 +++---
 3 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 61c6081..c41485d 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -82,6 +82,7 @@ def __init__(self) -> None:
         self._quiet = False
         self.dbtype: DBType = DBType.UNDEFINED
         self.db: dbapi.DBAPIConnection | None = None
+        self._db: DBTypeDatabase | None = None
         self._folio: FolioClient | None = None
         self.page_size = 1000
         self._okapi_timeout = 60
@@ -132,6 +133,11 @@ def _connect_db_duckdb(
         fn = filename if filename is not None else ":memory:"
         db = duckdb.connect(database=fn)
         self.db = cast("dbapi.DBAPIConnection", db.cursor())
+        self._db = DBTypeDatabase(
+            DBType.DUCKDB,
+            lambda: cast("dbapi.DBAPIConnection", db.cursor()),
+        )
+
         return db.cursor()
 
     def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
@@ -150,7 +156,10 @@ def connect_db_postgresql(self, dsn: str) -> psycopg2.extensions.connection:
         self.dbtype = DBType.POSTGRES
         db = psycopg.connect(dsn)
         self.db = cast("dbapi.DBAPIConnection", db)
-        autocommit(self.db, self.dbtype, True)
+        self._db = DBTypeDatabase(
+            DBType.POSTGRES,
+            lambda: cast("dbapi.DBAPIConnection", psycopg.connect(dsn)),
+        )
 
         ret_db = psycopg2.connect(dsn)
         ret_db.rollback()
@@ -180,6 +189,10 @@ def experimental_connect_db_sqlite(
         self.dbtype = DBType.SQLITE
         fn = filename if filename is not None else "file::memory:?cache=shared"
         self.db = sqlite3.connect(fn)
+        self._db = DBTypeDatabase(
+            DBType.SQLITE,
+            lambda: cast("dbapi.DBAPIConnection", sqlite3.connect(fn)),
+        )
 
         db = sqlite3.connect(fn)
         autocommit(db, self.dbtype, True)
@@ -338,7 +351,7 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
         if self._folio is None:
             self._check_folio()
             return []
-        if self.db is None:
+        if self.db is None or self._db is None:
             self._check_db()
             return []
         if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
@@ -346,6 +359,7 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
             schema_table = [table]
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
+        autocommit(self.db, self.dbtype, True)
         drop_json_tables(self.db, table)
         autocommit(self.db, self.dbtype, False)
         try:
@@ -381,10 +395,7 @@ def on_processed() -> bool:
                 p = next(processed)
                 return limit is None or p >= limit
 
-            cur = self.db.cursor()
-            db = DBTypeDatabase(self.dbtype, self.db)
-            db.ingest_records(self.db, Prefix(table), on_processed, records)
-            self.db.commit()
+            self._db.ingest_records(Prefix(table), on_processed, records)
 
             if pbar is not None:
                 pbar.close()
diff --git a/src/ldlite/_database.py b/src/ldlite/_database.py
index f1988de..a9bbdc5 100644
--- a/src/ldlite/_database.py
+++ b/src/ldlite/_database.py
@@ -76,23 +76,20 @@ def _prepare_raw_table(
 
     def ingest_records(
         self,
-        conn: DB,
         prefix: Prefix,
         on_processed: Callable[[], bool],
         records: Iterator[tuple[int, str | bytes]],
     ) -> None:
-        # the only implementation right now is a hack
-        # the db connection is managed outside of the factory
-        # for now it's taken as a parameter
-        # with self._conn_factory() as conn:
-        self._prepare_raw_table(conn, prefix)
-        with closing(conn.cursor()) as cur:
-            for pkey, d in records:
-                cur.execute(
-                    self._insert_record_sql.format(
-                        table=prefix.raw_table_name,
-                    ).as_string(),
-                    [pkey, d if isinstance(d, str) else d.decode("utf-8")],
-                )
-                if not on_processed():
-                    return
+        with closing(self._conn_factory()) as conn:
+            self._prepare_raw_table(conn, prefix)
+            with closing(conn.cursor()) as cur:
+                for pkey, d in records:
+                    cur.execute(
+                        self._insert_record_sql.format(
+                            table=prefix.raw_table_name,
+                        ).as_string(),
+                        [pkey, d if isinstance(d, str) else d.decode("utf-8")],
+                    )
+                    if not on_processed():
+                        return
+            conn.commit()
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index 6277fda..fe23819 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -2,7 +2,7 @@
 
 import secrets
 from enum import Enum
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Callable, cast
 
 from psycopg import sql
 
@@ -26,9 +26,9 @@ class DBType(Enum):
 
 
 class DBTypeDatabase(Database["dbapi.DBAPIConnection"]):
-    def __init__(self, dbtype: DBType, db: dbapi.DBAPIConnection):
+    def __init__(self, dbtype: DBType, factory: Callable[[], dbapi.DBAPIConnection]):
         self._dbtype = dbtype
-        super().__init__(lambda: db)
+        super().__init__(factory)
 
     @property
     def _create_raw_table_sql(self) -> sql.SQL:

From 7b8cdcc7ceb9f81c5c952d4f40094b583c3889b1 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Thu, 11 Sep 2025 15:48:42 +0000
Subject: [PATCH 2/3] Refactor dropping tables into shared class

---
 src/ldlite/__init__.py  |  36 ++++-------
 src/ldlite/_database.py | 129 ++++++++++++++++++++++++++++++++++------
 src/ldlite/_jsonx.py    |  83 --------------------------
 src/ldlite/_sqlx.py     |  21 +++++--
 4 files changed, 141 insertions(+), 128 deletions(-)

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index c41485d..93180a8 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -50,7 +50,7 @@
 from ._csv import to_csv
 from ._database import Prefix
 from ._folio import FolioClient
-from ._jsonx import Attr, drop_json_tables, transform_json
+from ._jsonx import Attr, transform_json
 from ._select import select
 from ._sqlx import (
     DBType,
@@ -236,22 +236,16 @@ def drop_tables(self, table: str) -> None:
             ld.drop_tables('g')
 
         """
-        if self.db is None:
+        if self.db is None or self._db is None:
             self._check_db()
             return
-        autocommit(self.db, self.dbtype, True)
         schema_table = table.strip().split(".")
-        if len(schema_table) < 1 or len(schema_table) > 2:
+        if len(schema_table) != 1 and len(schema_table) != 2:
             raise ValueError("invalid table name: " + table)
-        self._check_db()
-        cur = self.db.cursor()
-        try:
-            cur.execute("DROP TABLE IF EXISTS " + sqlid(table))
-        except (RuntimeError, psycopg2.Error):
-            pass
-        finally:
-            cur.close()
-        drop_json_tables(self.db, table)
+        if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
+            table = schema_table[0] + "_" + schema_table[1]
+        prefix = Prefix(table)
+        self._db.drop_prefix(prefix)
 
     def set_folio_max_retries(self, max_retries: int) -> None:
         """Sets the maximum number of retries for FOLIO requests.
@@ -356,12 +350,9 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
             return []
         if len(schema_table) == 2 and self.dbtype == DBType.SQLITE:
             table = schema_table[0] + "_" + schema_table[1]
-            schema_table = [table]
+        prefix = Prefix(table)
         if not self._quiet:
             print("ldlite: querying: " + path, file=sys.stderr)
-        autocommit(self.db, self.dbtype, True)
-        drop_json_tables(self.db, table)
-        autocommit(self.db, self.dbtype, False)
         try:
             # First get total number of records
             records = self._folio.iterate_records(
@@ -395,13 +386,15 @@ def on_processed() -> bool:
                 p = next(processed)
                 return limit is None or p >= limit
 
-            self._db.ingest_records(Prefix(table), on_processed, records)
+            self._db.ingest_records(prefix, on_processed, records)
             if pbar is not None:
                 pbar.close()
+            self._db.drop_extracted_tables(prefix)
 
             newtables = [table]
             newattrs = {}
             if json_depth > 0:
+                autocommit(self.db, self.dbtype, False)
                 jsontables, jsonattrs = transform_json(
                     self.db,
                     self.dbtype,
@@ -417,12 +410,7 @@ def on_processed() -> bool:
                 newattrs[table] = {"__id": Attr("__id", "bigint")}
 
             if not keep_raw:
-                cur = self.db.cursor()
-                try:
-                    cur.execute("DROP TABLE " + sqlid(table))
-                    self.db.commit()
-                finally:
-                    cur.close()
+                self._db.drop_raw_table(prefix)
         finally:
             autocommit(self.db, self.dbtype, True)
 
diff --git a/src/ldlite/_database.py b/src/ldlite/_database.py
index a9bbdc5..0024e4f 100644
--- a/src/ldlite/_database.py
+++ b/src/ldlite/_database.py
@@ -2,12 +2,12 @@
 
 from abc import ABC, abstractmethod
 from contextlib import closing
-from typing import TYPE_CHECKING, Callable, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast
 
 from psycopg import sql
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Iterator, Sequence
 
     from _typeshed import dbapi
 
@@ -27,29 +27,125 @@ def __init__(self, table: str):
     @property
     def schema_name(self) -> sql.Identifier | None:
         return None if self._schema is None else sql.Identifier(self._schema)
 
+    def identifier(self, table: str) -> sql.Identifier:
+        if self._schema is None:
+            return sql.Identifier(table)
+        return sql.Identifier(self._schema, table)
+
     @property
     def raw_table_name(self) -> sql.Identifier:
-        return (
-            sql.Identifier(self._schema, self._prefix)
-            if self._schema is not None
-            else sql.Identifier(self._prefix)
-        )
+        return self.identifier(self._prefix)
+
+    @property
+    def catalog_table_name(self) -> sql.Identifier:
+        return self.identifier(f"{self._prefix}__tcatalog")
+
+    @property
+    def legacy_jtable(self) -> sql.Identifier:
+        return self.identifier(f"{self._prefix}_jtable")
 
 
 class Database(ABC, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
 
+    @abstractmethod
+    def _rollback(self, conn: DB) -> None: ...
+
+    def drop_prefix(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_extracted_tables(conn, prefix)
+            self._drop_raw_table(conn, prefix)
+            conn.commit()
+
+    def drop_raw_table(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_raw_table(conn, prefix)
+            conn.commit()
+
+    def _drop_raw_table(
+        self,
+        conn: DB,
+        prefix: Prefix,
+    ) -> None:
+        with closing(conn.cursor()) as cur:
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {table};")
+                .format(table=prefix.raw_table_name)
+                .as_string(),
+            )
+
+    def drop_extracted_tables(
+        self,
+        prefix: Prefix,
+    ) -> None:
+        with closing(self._conn_factory()) as conn:
+            self._drop_extracted_tables(conn, prefix)
+            conn.commit()
+
     @property
     @abstractmethod
-    def _truncate_raw_table_sql(self) -> sql.SQL: ...
+    def _missing_table_error(self) -> tuple[type[Exception], ...]: ...
 
+    def _drop_extracted_tables(
+        self,
+        conn: DB,
+        prefix: Prefix,
+    ) -> None:
+        tables: list[Sequence[Sequence[Any]]] = []
+        with closing(conn.cursor()) as cur:
+            try:
+                cur.execute(
+                    sql.SQL("SELECT table_name FROM {catalog};")
+                    .format(catalog=prefix.catalog_table_name)
+                    .as_string(),
+                )
+            except self._missing_table_error:
+                self._rollback(conn)
+            else:
+                tables.extend(cur.fetchall())
+
+        with closing(conn.cursor()) as cur:
+            try:
+                cur.execute(
+                    sql.SQL("SELECT table_name FROM {catalog};")
+                    .format(catalog=prefix.legacy_jtable)
+                    .as_string(),
+                )
+            except self._missing_table_error:
+                self._rollback(conn)
+            else:
+                tables.extend(cur.fetchall())
+
+        with closing(conn.cursor()) as cur:
+            for (et,) in tables:
+                cur.execute(
+                    sql.SQL("DROP TABLE IF EXISTS {table};")
+                    .format(table=sql.Identifier(cast("str", et)))
+                    .as_string(),
+                )
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {catalog};")
+                .format(catalog=prefix.catalog_table_name)
+                .as_string(),
+            )
+            cur.execute(
+                sql.SQL("DROP TABLE IF EXISTS {catalog};")
+                .format(catalog=prefix.legacy_jtable)
+                .as_string(),
+            )
+
     @property
     @abstractmethod
-    def _create_raw_table_sql(self) -> sql.SQL: ...
+    def _truncate_raw_table_sql(self) -> sql.SQL: ...
 
     @property
     @abstractmethod
-    def _insert_record_sql(self) -> sql.SQL: ...
-
+    def _create_raw_table_sql(self) -> sql.SQL: ...
     def _prepare_raw_table(
         self,
         conn: DB,
@@ -62,18 +158,17 @@ def _prepare_raw_table(
                     .format(schema=prefix.schema_name)
                     .as_string(),
                 )
-
+        self._drop_raw_table(conn, prefix)
+
         with closing(conn.cursor()) as cur:
             cur.execute(
                 self._create_raw_table_sql.format(
                     table=prefix.raw_table_name,
                 ).as_string(),
             )
-            cur.execute(
-                self._truncate_raw_table_sql.format(
-                    table=prefix.raw_table_name,
-                ).as_string(),
-            )
+
+    @property
+    @abstractmethod
+    def _insert_record_sql(self) -> sql.SQL: ...
 
     def ingest_records(
         self,
         prefix: Prefix,
diff --git a/src/ldlite/_jsonx.py b/src/ldlite/_jsonx.py
index efce0d3..4983e94 100644
--- a/src/ldlite/_jsonx.py
+++ b/src/ldlite/_jsonx.py
@@ -86,93 +86,10 @@ def __repr__(self) -> str:
         )
 
 
-def _old_jtable(table: str) -> str:
-    return table + "_jtable"
-
-
 def _tcatalog(table: str) -> str:
     return table + "__tcatalog"
 
 
-# noinspection DuplicatedCode
-def _old_drop_json_tables(db: dbapi.DBAPIConnection, table: str) -> None:
-    jtable_sql = sqlid(_old_jtable(table))
-    cur = db.cursor()
-    try:
-        cur.execute("SELECT table_name FROM " + jtable_sql)
-        rows = list(cur.fetchall())
-        for row in rows:
-            t = row[0]
-            cur2 = db.cursor()
-            try:
-                cur2.execute("DROP TABLE " + sqlid(t))
-            except (psycopg.Error, duckdb.CatalogException, sqlite3.OperationalError):
-                continue
-            finally:
-                cur2.close()
-    except (
-        psycopg.errors.UndefinedTable,
-        sqlite3.OperationalError,
-        duckdb.CatalogException,
-    ):
-        pass
-    finally:
-        cur.close()
-    cur = db.cursor()
-    try:
-        cur.execute("DROP TABLE " + jtable_sql)
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-
-
-# noinspection DuplicatedCode
-def drop_json_tables(db: dbapi.DBAPIConnection, table: str) -> None:
-    tcatalog_sql = sqlid(_tcatalog(table))
-    cur = db.cursor()
-    try:
-        cur.execute("SELECT table_name FROM " + tcatalog_sql)
-        rows = list(cur.fetchall())
-        for row in rows:
-            t = row[0]
-            cur2 = db.cursor()
-            try:
-                cur2.execute("DROP TABLE " + sqlid(t))
-            except (
-                psycopg.errors.UndefinedTable,
-                duckdb.CatalogException,
-                sqlite3.OperationalError,
-            ):
-                continue
-            finally:
-                cur2.close()
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-    cur = db.cursor()
-    try:
-        cur.execute("DROP TABLE " + tcatalog_sql)
-    except (
-        psycopg.errors.UndefinedTable,
-        duckdb.CatalogException,
-        sqlite3.OperationalError,
-    ):
-        pass
-    finally:
-        cur.close()
-    _old_drop_json_tables(db, table)
-
-
 def _table_name(parents: list[tuple[int, str]]) -> str:
     j = len(parents)
     while j > 0 and parents[j - 1][0] == 0:
diff --git a/src/ldlite/_sqlx.py b/src/ldlite/_sqlx.py
index fe23819..e87a38f 100644
--- a/src/ldlite/_sqlx.py
+++ b/src/ldlite/_sqlx.py
@@ -1,18 +1,17 @@
 from __future__ import annotations
 
 import secrets
+import sqlite3
 from enum import Enum
 from typing import TYPE_CHECKING, Callable, cast
 
+import duckdb
+import psycopg
 from psycopg import sql
 
 from ._database import Database
 
 if TYPE_CHECKING:
-    import sqlite3
-
-    import duckdb
-    import psycopg
     from _typeshed import dbapi
 
     from ._jsonx import JsonValue
@@ -30,6 +29,20 @@ def __init__(self, dbtype: DBType, factory: Callable[[], dbapi.DBAPIConnection])
         self._dbtype = dbtype
         super().__init__(factory)
 
+    @property
+    def _missing_table_error(self) -> tuple[type[Exception], ...]:
+        return (
+            psycopg.errors.UndefinedTable,
+            sqlite3.OperationalError,
+            duckdb.CatalogException,
+        )
+
+    def _rollback(self, conn: dbapi.DBAPIConnection) -> None:
+        if sql3db := as_sqlite(conn, self._dbtype):
+            sql3db.rollback()
+        if pgdb := as_postgres(conn, self._dbtype):
+            pgdb.rollback()
+
     @property
     def _create_raw_table_sql(self) -> sql.SQL:
         create_sql = "CREATE TABLE IF NOT EXISTS {table} (__id integer, jsonb text);"

From 365601883ad96a758624181892c20af91e2c4677 Mon Sep 17 00:00:00 2001
From: Katherine Bargar
Date: Thu, 11 Sep 2025 17:48:10 +0000
Subject: [PATCH 3/3] Refactor on_progress to do less work in a tight loop

---
 src/ldlite/__init__.py | 45 ++++++++++++++++++++++++++++++---------------
 1 file changed, 30 insertions(+), 15 deletions(-)

diff --git a/src/ldlite/__init__.py b/src/ldlite/__init__.py
index 93180a8..063bea9 100644
--- a/src/ldlite/__init__.py
+++ b/src/ldlite/__init__.py
@@ -367,8 +367,9 @@ def query(  # noqa: C901, PLR0912, PLR0913, PLR0915
             if self._verbose:
                 print("ldlite: estimated row count: " + str(total), file=sys.stderr)
-            processed = count(0)
-            pbar = None
+            p_count = count(0)
+            processed = 0
+            pbar: tqdm | PbarNoop  # type:ignore[type-arg]
             if not self._quiet:
                 pbar = tqdm(
                     desc="reading",
                     total=total,
                     leave=False,
                     mininterval=3,
                     colour="#A9A9A9",
                     bar_format="{desc} {bar}{postfix}",
                 )
+            else:
 
-            def on_processed() -> bool:
-                if pbar is not None:
-                    pbar.update(1)
-                p = next(processed)
-                return limit is None or p >= limit
+                class PbarNoop:
+                    def update(self, _: int) -> None: ...
+                    def close(self) -> None: ...
+
+                pbar = PbarNoop()
 
-            self._db.ingest_records(prefix, on_processed, records)
-            if pbar is not None:
-                pbar.close()
+            def on_processed() -> bool:
+                pbar.update(1)
+                nonlocal processed
+                processed = next(p_count)
+                return True
+
+            def on_processed_limit() -> bool:
+                pbar.update(1)
+                nonlocal processed
+                processed = next(p_count)
+                return limit is None or processed >= limit
+
+            self._db.ingest_records(
+                prefix,
+                on_processed_limit if limit is not None else on_processed,
+                records,
+            )
+            pbar.close()
             self._db.drop_extracted_tables(prefix)
 
             newtables = [table]
             newattrs = {}
@@ -399,7 +416,7 @@ def on_processed() -> bool:
                     self.db,
                     self.dbtype,
                     table,
-                    next(processed) - 1,
+                    processed,
                     self._quiet,
                     json_depth,
                 )
@@ -445,10 +462,8 @@ def on_processed() -> bool:
                         pass
                     finally:
                         cur.close()
-                    if pbar is not None:
-                        pbar.update(1)
-        if pbar is not None:
-            pbar.close()
+                    pbar.update(1)
+        pbar.close()
 
         # Return table names
         if not self._quiet:
            print("ldlite: created tables: " + ", ".join(newtables), file=sys.stderr)
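
Illustrative sketch (not part of the patches): the series moves LDLite from hand-managed cursors to a Database abstraction that owns its connections through a factory. The snippet below shows roughly how the new pieces fit together once the patches are applied; the SQLite file name, the prefix "g", and the sample records are hypothetical, and the internal modules (ldlite._sqlx, ldlite._database) are imported directly here only for illustration.

    import sqlite3

    from ldlite._database import Prefix
    from ldlite._sqlx import DBType, DBTypeDatabase

    # Patch 1: the abstraction is constructed from a connection factory rather
    # than a live connection, so it can open and close connections itself.
    db = DBTypeDatabase(DBType.SQLITE, lambda: sqlite3.connect("example.db"))

    # Prefix derives the raw table name ("g") and the extracted-table catalog
    # name ("g__tcatalog") from a single table prefix.
    prefix = Prefix("g")
    records = iter([(1, '{"id": "a"}'), (2, '{"id": "b"}')])

    # Patch 1/3: ingest_records drives the insert loop and commits on its own;
    # the callback only reports progress, and a falsy return stops the load.
    db.ingest_records(prefix, lambda: True, records)

    # Patch 2: drop_prefix removes the raw table plus any extracted and catalog
    # tables behind it, replacing the old drop_json_tables helpers.
    db.drop_prefix(prefix)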