diff --git a/setup.py b/setup.py index c545061..e584d54 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ PYMONGO_REQUIRES = ("pymongo>=3.11.0",) -PYMSSQL_REQUIRES = ("cython>=0.29.21", "pymssql==2.1.5") +PYMSSQL_REQUIRES = ("cython>=0.29.21", "pymssql>=2.1.4") PSYCOPG2_REQUIRES = ("psycopg2-binary>=2.8.6",) diff --git a/src/dsdk/mssql.py b/src/dsdk/mssql.py index f198164..9b25536 100644 --- a/src/dsdk/mssql.py +++ b/src/dsdk/mssql.py @@ -101,13 +101,6 @@ def connect(self) -> Generator[Any, None, None]: con.close() logger.info(self.CLOSE) - @contextmanager - def cursor(self, con) -> Generator[Any, None, None]: - """Yield cursor that provides dicts.""" - # Replace return type with ContextManager[Any] when mypy is fixed. - with con.cursor(as_dict=True) as cur: - yield cur - class Mixin(BaseMixin): """Mixin.""" diff --git a/src/dsdk/persistor.py b/src/dsdk/persistor.py index 4ed87b1..5d0bf65 100644 --- a/src/dsdk/persistor.py +++ b/src/dsdk/persistor.py @@ -56,8 +56,14 @@ def df_from_query( if parameters is None: parameters = {} cur.execute(query, parameters) + columns = (each[0] for each in cur.description) rows = cur.fetchall() - return DataFrame(rows) + if rows: + df = DataFrame(rows) + df.columns = columns + else: + df = DataFrame(columns=columns) + return df @classmethod def df_from_query_by_ids( # pylint: disable=too-many-arguments @@ -72,11 +78,19 @@ def df_from_query_by_ids( # pylint: disable=too-many-arguments if parameters is None: parameters = {} dfs = [] + columns = None + chunk = None for chunk in chunks(ids, size): cur.execute(query, {"ids": chunk, **parameters}) rows = cur.fetchall() - dfs.append(DataFrame(rows)) - return concat(dfs, ignore_index=True) + if rows: + dfs.append(DataFrame(rows)) + if chunk is None: + raise ValueError("Parameter ids must not be empty") + columns = (each[0] for each in cur.description) + df = concat(dfs, ignore_index=True) + df.columns = columns + return df def __init__(self, sql: Namespace, tables: Tuple[str, ...]): """__init__.""" @@ -93,7 +107,7 @@ def check(self, cur, exceptions): logger.info(self.EXTANT, statement) cur.execute(statement) for row in cur: - n = row["n"] + n, *_ = row assert n == 1 continue except exceptions: @@ -125,10 +139,11 @@ def connect(self) -> Generator[Any, None, None]: raise NotImplementedError() @contextmanager - def cursor(self, con): + def cursor(self, con): # pylint: disable=no-self-use """Yield a cursor that provides dicts.""" # Replace return type with ContextManager[Any] when mypy is fixed. - raise NotImplementedError() + with con.cursor() as cursor: + yield cursor def extant(self, table: str) -> str: """Return extant table sql.""" diff --git a/src/dsdk/postgres.py b/src/dsdk/postgres.py index 63886f4..66f75f7 100644 --- a/src/dsdk/postgres.py +++ b/src/dsdk/postgres.py @@ -29,10 +29,7 @@ OperationalError, connect, ) - from psycopg2.extras import ( - RealDictCursor, - execute_batch, - ) + from psycopg2.extras import execute_batch from psycopg2.extensions import ( register_adapter, ISQLQuote, @@ -138,13 +135,6 @@ def connect(self) -> Generator[Any, None, None]: con.close() logger.info(self.CLOSE) - @contextmanager - def cursor(self, con) -> Generator[Any, None, None]: - """Yield a cursor that provides dicts.""" - # Replace return type with ContextManager[Any] when mypy is fixed. - with con.cursor(cursor_factory=RealDictCursor) as cur: - yield cur - @contextmanager def open_run(self, parent: Any) -> Generator[Run, None, None]: """Open batch.""" @@ -155,15 +145,21 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: cur.execute(sql.schema) cur.execute(sql.runs.open, columns) for row in cur: - run = Run( - row["id"], row["microservice_id"], row["model_id"], parent, - ) - parent.as_of = row["as_of"] - duration = row["duration"] + ( + id_, + microservice_id, + model_id, + duration, + as_of, + time_zone, + *_, + ) = row + run = Run(id_, microservice_id, model_id, parent,) + parent.as_of = as_of parent.duration = Interval( on=duration.lower, end=duration.upper ) - parent.time_zone = row["time_zone"] + parent.time_zone = time_zone break yield run @@ -181,7 +177,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]: ) cur.execute(sql.runs.close, {"id": run.id}) for row in cur: - duration = row["duration"] + _, _, _, duration, _, _, *_ = row run.duration = Interval(on=duration.lower, end=duration.upper) break diff --git a/src/dsdk/utils.py b/src/dsdk/utils.py index 1dbdb79..a7bdbe2 100644 --- a/src/dsdk/utils.py +++ b/src/dsdk/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations +from contextlib import contextmanager from functools import wraps from json import dump as json_dump from json import load as json_load @@ -10,8 +11,9 @@ from pickle import dump as pickle_dump from pickle import load as pickle_load from sys import stdout +from time import perf_counter_ns from time import sleep as default_sleep -from typing import Any, Callable, Sequence +from typing import Any, Callable, Generator, Sequence logger = getLogger(__name__) @@ -69,6 +71,19 @@ def load_pickle_file(path: str) -> object: return pickle_load(fin) +@contextmanager +def profile(key: str) -> Generator[Any, None, None]: + """Profile.""" + # Replace return type with ContextManager[Any] when mypy is fixed. + begin = perf_counter_ns() + logger.info('{"key": "%s.begin", "ns": "%s"}', key, begin) + yield + end = perf_counter_ns() + logger.info( + '{"key": "%s.end", "ns": "%s", "elapsed": "%s"}', key, end, end - begin + ) + + def retry( exceptions: Sequence[Exception], retries: int = 60, diff --git a/test/test_postgres.py b/test/test_postgres.py index d1794be..23f7cbc 100644 --- a/test/test_postgres.py +++ b/test/test_postgres.py @@ -94,7 +94,8 @@ def test_cursor(): with persistor.rollback() as cur: cur.execute("""select 1 as n""") for row in cur.fetchall(): - assert row["n"] == 1 + n, *_ = row + assert n == 1 def test_open_run(