Skip to content

Commit

Permalink
Replace cursor dicts with tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jlubken committed Mar 17, 2021
1 parent aea38d7 commit a552961
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 34 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

PYMONGO_REQUIRES = ("pymongo>=3.11.0",)

PYMSSQL_REQUIRES = ("cython>=0.29.21", "pymssql==2.1.5")
PYMSSQL_REQUIRES = ("cython>=0.29.21", "pymssql>=2.1.4")

PSYCOPG2_REQUIRES = ("psycopg2-binary>=2.8.6",)

Expand Down
7 changes: 0 additions & 7 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,6 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(as_dict=True) as cur:
yield cur


class Mixin(BaseMixin):
"""Mixin."""
Expand Down
27 changes: 21 additions & 6 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,14 @@ def df_from_query(
if parameters is None:
parameters = {}
cur.execute(query, parameters)
columns = (each[0] for each in cur.description)
rows = cur.fetchall()
return DataFrame(rows)
if rows:
df = DataFrame(rows)
df.columns = columns
else:
df = DataFrame(columns=columns)
return df

@classmethod
def df_from_query_by_ids( # pylint: disable=too-many-arguments
Expand All @@ -72,11 +78,19 @@ def df_from_query_by_ids( # pylint: disable=too-many-arguments
if parameters is None:
parameters = {}
dfs = []
columns = None
chunk = None
for chunk in chunks(ids, size):
cur.execute(query, {"ids": chunk, **parameters})
rows = cur.fetchall()
dfs.append(DataFrame(rows))
return concat(dfs, ignore_index=True)
if rows:
dfs.append(DataFrame(rows))
if chunk is None:
raise ValueError("Parameter ids must not be empty")
columns = (each[0] for each in cur.description)
df = concat(dfs, ignore_index=True)
df.columns = columns
return df

def __init__(self, sql: Namespace, tables: Tuple[str, ...]):
"""__init__."""
Expand All @@ -93,7 +107,7 @@ def check(self, cur, exceptions):
logger.info(self.EXTANT, statement)
cur.execute(statement)
for row in cur:
n = row["n"]
n, *_ = row
assert n == 1
continue
except exceptions:
Expand Down Expand Up @@ -125,10 +139,11 @@ def connect(self) -> Generator[Any, None, None]:
raise NotImplementedError()

@contextmanager
def cursor(self, con):
def cursor(self, con): # pylint: disable=no-self-use
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()
with con.cursor() as cursor:
yield cursor

def extant(self, table: str) -> str:
"""Return extant table sql."""
Expand Down
32 changes: 14 additions & 18 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@
OperationalError,
connect,
)
from psycopg2.extras import (
RealDictCursor,
execute_batch,
)
from psycopg2.extras import execute_batch
from psycopg2.extensions import (
register_adapter,
ISQLQuote,
Expand Down Expand Up @@ -138,13 +135,6 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(cursor_factory=RealDictCursor) as cur:
yield cur

@contextmanager
def open_run(self, parent: Any) -> Generator[Run, None, None]:
"""Open batch."""
Expand All @@ -155,15 +145,21 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]:
cur.execute(sql.schema)
cur.execute(sql.runs.open, columns)
for row in cur:
run = Run(
row["id"], row["microservice_id"], row["model_id"], parent,
)
parent.as_of = row["as_of"]
duration = row["duration"]
(
id_,
microservice_id,
model_id,
duration,
as_of,
time_zone,
*_,
) = row
run = Run(id_, microservice_id, model_id, parent,)
parent.as_of = as_of
parent.duration = Interval(
on=duration.lower, end=duration.upper
)
parent.time_zone = row["time_zone"]
parent.time_zone = time_zone
break

yield run
Expand All @@ -181,7 +177,7 @@ def open_run(self, parent: Any) -> Generator[Run, None, None]:
)
cur.execute(sql.runs.close, {"id": run.id})
for row in cur:
duration = row["duration"]
_, _, _, duration, _, _, *_ = row
run.duration = Interval(on=duration.lower, end=duration.upper)
break

Expand Down
17 changes: 16 additions & 1 deletion src/dsdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

from __future__ import annotations

from contextlib import contextmanager
from functools import wraps
from json import dump as json_dump
from json import load as json_load
from logging import INFO, Formatter, StreamHandler, getLogger
from pickle import dump as pickle_dump
from pickle import load as pickle_load
from sys import stdout
from time import perf_counter_ns
from time import sleep as default_sleep
from typing import Any, Callable, Sequence
from typing import Any, Callable, Generator, Sequence

logger = getLogger(__name__)

Expand Down Expand Up @@ -69,6 +71,19 @@ def load_pickle_file(path: str) -> object:
return pickle_load(fin)


@contextmanager
def profile(key: str) -> Generator[Any, None, None]:
"""Profile."""
# Replace return type with ContextManager[Any] when mypy is fixed.
begin = perf_counter_ns()
logger.info('{"key": "%s.begin", "ns": "%s"}', key, begin)
yield
end = perf_counter_ns()
logger.info(
'{"key": "%s.end", "ns": "%s", "elapsed": "%s"}', key, end, end - begin
)


def retry(
exceptions: Sequence[Exception],
retries: int = 60,
Expand Down
3 changes: 2 additions & 1 deletion test/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def test_cursor():
with persistor.rollback() as cur:
cur.execute("""select 1 as n""")
for row in cur.fetchall():
assert row["n"] == 1
n, *_ = row
assert n == 1


def test_open_run(
Expand Down

0 comments on commit a552961

Please sign in to comment.