Skip to content

Commit

Permalink
Showing 3 changed files with 42 additions and 50 deletions.
40 changes: 17 additions & 23 deletions src/dsdk/mssql.py
Original file line number Diff line number Diff line change
@@ -101,23 +101,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
raise RuntimeError(self.ERRORS, errors)
logger.info(self.END)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
with self.connect() as con:
try:
with con.cursor(as_dict=True) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
except BaseException:
con.rollback()
logger.info(self.ROLLBACK)
raise

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
con = connect(
server=self.host,
user=self.username,
@@ -134,15 +121,11 @@ def connect(self) -> Generator[Any, None, None]:
logger.info(self.CLOSE)

@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
with self.connect() as con:
try:
with con.cursor(as_dict=True) as cur:
yield cur
finally:
con.rollback()
logger.info(self.ROLLBACK)
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(as_dict=True) as cur:
yield cur


class AlchemyPersistor(Messages, BaseAbstractPersistor):
@@ -154,6 +137,7 @@ def configure(
cls, service: Service, parser
) -> Generator[None, None, None]:
"""Dependencies."""
# Replace return type with ContextManager[None] when mypy is fixed.
kwargs: Dict[str, Any] = {}

for key, help_, inject in (
@@ -230,6 +214,7 @@ def check(

@contextmanager
def connect(self) -> Generator[Any, None, None]:
# Replace return type with ContextManager[Any] when mypy is fixed.
"""Connect."""
con = self.engine.connect()
logger.info(self.OPEN)
@@ -239,6 +224,13 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
# Replace return type with ContextManager[Any] when mypy is fixed.
"""Yield a cursor that provides dicts."""
with con.cursor() as cur:
yield cur


class Mixin(BaseMixin):
"""Mixin."""
@@ -254,6 +246,7 @@ def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
"""Inject arguments."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.mssql_cls.configure(self, parser):
with super().inject_arguments(parser):
yield
@@ -274,6 +267,7 @@ def __init__(
def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
# Replace return type with ContextManager[None] when mypy is fixed.
"""Inject arguments."""
with self.mssql_cls.configure(self, parser):
with super().inject_arguments(parser):
15 changes: 13 additions & 2 deletions src/dsdk/persistor.py
Original file line number Diff line number Diff line change
@@ -106,9 +106,10 @@ def check(self, cur, exceptions):
@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.connect() as con:
try:
with con.cursor() as cur:
with self.cursor(con) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
@@ -120,6 +121,13 @@ def commit(self) -> Generator[Any, None, None]:
@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()

@contextmanager
def cursor(self, con):
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()

def extant(self, table: str) -> str:
@@ -131,9 +139,10 @@ def extant(self, table: str) -> str:
@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with self.connect() as con:
try:
with con.cursor() as cur:
with self.cursor(con) as cur:
yield cur
finally:
con.rollback()
@@ -149,6 +158,7 @@ def configure(
cls, service: Service, parser
) -> Generator[None, None, None]:
"""Configure."""
# Replace return type with ContextManager[None] when mypy is fixed.
kwargs: Dict[str, Any] = {}

for key, help_, inject in (
@@ -197,4 +207,5 @@ def __init__( # pylint: disable=too-many-arguments
@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
raise NotImplementedError()
37 changes: 12 additions & 25 deletions src/dsdk/postgres.py
Original file line number Diff line number Diff line change
@@ -71,23 +71,10 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
"""Check."""
super().check(cur, exceptions)

@contextmanager
def commit(self) -> Generator[Any, None, None]:
"""Commit."""
with self.connect() as con:
try:
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur
con.commit()
logger.info(self.COMMIT)
except BaseException:
con.rollback()
logger.info(self.ROLLBACK)
raise

@contextmanager
def connect(self) -> Generator[Any, None, None]:
"""Connect."""
# Replace return type with ContextManager[Any] when mypy is fixed.
# The `with ... as con:` formulation does not close the connection:
# https://www.psycopg.org/docs/usage.html#with-statement
con = self.retry_connect()
@@ -98,6 +85,13 @@ def connect(self) -> Generator[Any, None, None]:
con.close()
logger.info(self.CLOSE)

@contextmanager
def cursor(self, con) -> Generator[Any, None, None]:
"""Yield a cursor that provides dicts."""
# Replace return type with ContextManager[Any] when mypy is fixed.
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur

@retry((OperationalError,))
def retry_connect(self):
"""Retry connect."""
@@ -109,17 +103,6 @@ def retry_connect(self):
dbname=self.database,
)

@contextmanager
def rollback(self) -> Generator[Any, None, None]:
"""Rollback."""
with self.connect() as con:
try:
with con.cursor(cursor_factory=DictCursor) as cur:
yield cur
finally:
con.rollback()
logger.info(self.ROLLBACK)


class Mixin(BaseMixin):
"""Mixin."""
@@ -137,6 +120,7 @@ def inject_arguments(
self, parser: ArgumentParser
) -> Generator[None, None, None]:
"""Inject arguments."""
# Replace return type with ContextManager[None] when mypy is fixed.
with self.postgres_cls.configure(self, parser):
with super().inject_arguments(parser):
yield
@@ -162,6 +146,7 @@ def open_run(
self, microservice_version: str, model_version: str
) -> Generator[Run, None, None]:
"""Open run."""
# Replace return type with ContextManager[Run] when mypy is fixed.
sql = self.sql
with self.commit() as cur:
cur.execute(sql.schema)
@@ -180,7 +165,9 @@ def open_run(
row["duration"],
)
break

yield run

with self.commit() as cur:
cur.execute(sql.schema)
if run.predictions is not None:

0 comments on commit d9b9de8

Please sign in to comment.