Commit

Update module

Jason Lubken authored and jlubken committed Aug 24, 2020
1 parent 5640fb8 commit e966aa2
Showing 5 changed files with 111 additions and 61 deletions.
4 changes: 0 additions & 4 deletions src/dsdk/__init__.py
@@ -21,8 +21,6 @@
 from .utils import (
     chunks,
     configure_logger,
-    df_from_query,
-    df_from_query_by_ids,
     dump_json_file,
     dump_pickle_file,
     load_json_file,
@@ -51,8 +49,6 @@
     "Task",
     "chunks",
     "configure_logger",
-    "df_from_query_by_ids",
-    "df_from_query",
    "dump_json_file",
     "dump_pickle_file",
     "load_json_file",
29 changes: 27 additions & 2 deletions src/dsdk/mssql.py
@@ -85,7 +85,7 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
                 logger.info(self.EXTANT, statement)
                 cur.execute(statement)
                 for row in cur:
-                    (n,) = row
+                    n = row["n"]
                     assert n == 1
                     continue
             # pylint: disable=catching-non-exception
@@ -101,6 +101,20 @@ def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
             raise RuntimeError(self.ERRORS, errors)
         logger.info(self.END)

+    @contextmanager
+    def commit(self) -> Generator[Any, None, None]:
+        """Commit."""
+        with self.connect() as con:
+            try:
+                with con.cursor(as_dict=True) as cur:
+                    yield cur
+                con.commit()
+                logger.info(self.COMMIT)
+            except BaseException:
+                con.rollback()
+                logger.info(self.ROLLBACK)
+                raise
+
     @contextmanager
     def connect(self) -> Generator[Any, None, None]:
         """Connect."""
@@ -119,6 +133,17 @@ def connect(self) -> Generator[Any, None, None]:
             con.close()
             logger.info(self.CLOSE)

+    @contextmanager
+    def rollback(self) -> Generator[Any, None, None]:
+        """Rollback."""
+        with self.connect() as con:
+            try:
+                with con.cursor(as_dict=True) as cur:
+                    yield cur
+            finally:
+                con.rollback()
+                logger.info(self.ROLLBACK)
+

 class AlchemyPersistor(Messages, BaseAbstractPersistor):
     """AlchemyPersistor."""
@@ -188,7 +213,7 @@ def check(
                 logger.info(self.EXTANT, statement)
                 cur.execute(statement)
                 for row in cur:
-                    (n,) = row
+                    n = row["n"]
                     assert n == 1
                     continue
             except exceptions as error:
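The commit() and rollback() context managers added here (and their psycopg2 counterparts further down) yield a dict-style cursor — as_dict=True for pymssql, DictCursor for psycopg2 — which is why check() now reads row["n"] instead of unpacking a tuple. A minimal, self-contained sketch of that pattern follows; sqlite3 and the throwaway table stand in for the real drivers, so none of the names below are dsdk's API:

import sqlite3
from contextlib import contextmanager


@contextmanager
def commit(con):
    """Yield a cursor; commit on clean exit, roll back and re-raise on error."""
    cur = con.cursor()
    try:
        yield cur
        con.commit()
    except BaseException:
        con.rollback()
        raise


@contextmanager
def rollback(con):
    """Yield a cursor; always roll back, suited to read-only work."""
    cur = con.cursor()
    try:
        yield cur
    finally:
        con.rollback()


con = sqlite3.connect(":memory:")
con.row_factory = sqlite3.Row  # rows answer row["n"], like as_dict/DictCursor
con.execute("create table t (n integer)")
con.commit()

with commit(con) as cur:
    cur.execute("insert into t values (1)")

with rollback(con) as cur:
    cur.execute("select count(*) as n from t")
    assert cur.fetchone()["n"] == 1

commit() persists the block's work only if it exits without raising; rollback() discards it unconditionally, which is the safe default for checks and reads.
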
47 changes: 40 additions & 7 deletions src/dsdk/persistor.py
@@ -8,7 +8,9 @@
 from json import dumps
 from logging import getLogger
 from re import compile as re_compile
-from typing import Any, Dict, Generator, Tuple
+from typing import Any, Dict, Generator, Optional, Sequence, Tuple

+from pandas import DataFrame, concat
+
 from .dependency import (
     inject_int,
@@ -17,6 +19,7 @@
     inject_str_tuple,
 )
 from .service import Service
+from .utils import chunks

 logger = getLogger(__name__)

@@ -45,6 +48,36 @@ def configure(cls, service: Service, parser):
         """Configure."""
         raise NotImplementedError()

+    @classmethod
+    def df_from_query(
+        cls, cur, query: str, parameters: Optional[Dict[str, Any]],
+    ) -> DataFrame:
+        """Return DataFrame from query."""
+        if parameters is None:
+            parameters = {}
+        cur.execute(query, parameters)
+        rows = cur.fetchall()
+        return DataFrame(rows)
+
+    @classmethod
+    def df_from_query_by_ids(  # pylint: disable=too-many-arguments
+        cls,
+        cur,
+        query: str,
+        ids: Sequence[Any],
+        parameters: Optional[Dict[str, Any]] = None,
+        size: int = 10000,
+    ) -> DataFrame:
+        """Return DataFrame from query by ids."""
+        if parameters is None:
+            parameters = {}
+        dfs = []
+        for chunk in chunks(ids, size):
+            cur.execute(query, {"ids": chunk, **parameters})
+            rows = cur.fetchall()
+            dfs.append(DataFrame(rows))
+        return concat(dfs, ignore_index=True)
+
     def __init__(self, sql: Namespace, tables: Tuple[str, ...]):
         """__init__."""
         self.sql = sql
@@ -60,7 +93,7 @@ def check(self, cur, exceptions):
                 logger.info(self.EXTANT, statement)
                 cur.execute(statement)
                 for row in cur:
-                    (n,) = row
+                    n = row["n"]
                     assert n == 1
                     continue
             except exceptions:
@@ -70,11 +103,6 @@
             raise RuntimeError(self.ERRORS, errors)
         logger.info(self.END)

-    @contextmanager
-    def connect(self) -> Generator[Any, None, None]:
-        """Connect."""
-        raise NotImplementedError()
-
     @contextmanager
     def commit(self) -> Generator[Any, None, None]:
         """Commit."""
@@ -89,6 +117,11 @@ def commit(self) -> Generator[Any, None, None]:
                 logger.info(self.ROLLBACK)
                 raise

+    @contextmanager
+    def connect(self) -> Generator[Any, None, None]:
+        """Connect."""
+        raise NotImplementedError()
+
     def extant(self, table: str) -> str:
         """Return extant table sql."""
         if not ALPHA_NUMERIC_DOT.match(table):
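df_from_query and df_from_query_by_ids move from utils.py onto the base persistor as classmethods (the utils copies are deleted below). The by-ids variant batches the ids sequence through chunks() so a long id list never exceeds driver parameter limits, then concatenates the partial frames with pandas. A runnable sketch of that chunking pattern, using sqlite3 and a made-up scores table purely for illustration — dsdk itself feeds dict-style rows straight to DataFrame(rows):

import sqlite3

from pandas import DataFrame, concat


def chunks(sequence, size):
    """Yield successive slices of sequence no longer than size."""
    for i in range(0, len(sequence), size):
        yield sequence[i:i + size]


con = sqlite3.connect(":memory:")
con.execute("create table scores (id integer primary key, score real)")
con.executemany(
    "insert into scores values (?, ?)", [(i, i / 10) for i in range(25)]
)
con.commit()

ids = list(range(25))
dfs = []
for chunk in chunks(ids, 10):  # 10 ids per round trip
    placeholders = ", ".join("?" for _ in chunk)
    cur = con.execute(
        f"select id, score from scores where id in ({placeholders})", chunk
    )
    columns = [each[0] for each in cur.description]
    dfs.append(DataFrame.from_records(cur.fetchall(), columns=columns))

df = concat(dfs, ignore_index=True)
assert len(df) == 25
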
57 changes: 43 additions & 14 deletions src/dsdk/postgres.py
@@ -19,6 +19,7 @@

 logger = getLogger(__name__)

+
 try:
     # Not everyone will be using postgres
     from psycopg2 import (
@@ -27,7 +28,10 @@
         OperationalError,
         connect,
     )
-    from psycopg2.extras import execute_batch
+    from psycopg2.extras import (
+        DictCursor,
+        execute_batch,
+    )
 except ImportError as import_error:
     logger.warning(import_error)

@@ -63,16 +67,23 @@ class Messages: # pylint: disable=too-few-public-methods
 class Persistor(Messages, BasePersistor):
     """Persistor."""

-    @retry((OperationalError,))
-    def retry_connect(self):
-        """Retry connect."""
-        return connect(
-            user=self.username,
-            password=self.password,
-            host=self.host,
-            port=self.port,
-            dbname=self.database,
-        )
+    def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
+        """Check."""
+        super().check(cur, exceptions)
+
+    @contextmanager
+    def commit(self) -> Generator[Any, None, None]:
+        """Commit."""
+        with self.connect() as con:
+            try:
+                with con.cursor(cursor_factory=DictCursor) as cur:
+                    yield cur
+                con.commit()
+                logger.info(self.COMMIT)
+            except BaseException:
+                con.rollback()
+                logger.info(self.ROLLBACK)
+                raise

     @contextmanager
     def connect(self) -> Generator[Any, None, None]:
@@ -87,9 +98,27 @@ def connect(self) -> Generator[Any, None, None]:
             con.close()
             logger.info(self.CLOSE)

-    def check(self, cur, exceptions=(DatabaseError, InterfaceError)):
-        """Check."""
-        super().check(cur, exceptions)
+    @retry((OperationalError,))
+    def retry_connect(self):
+        """Retry connect."""
+        return connect(
+            user=self.username,
+            password=self.password,
+            host=self.host,
+            port=self.port,
+            dbname=self.database,
+        )
+
+    @contextmanager
+    def rollback(self) -> Generator[Any, None, None]:
+        """Rollback."""
+        with self.connect() as con:
+            try:
+                with con.cursor(cursor_factory=DictCursor) as cur:
+                    yield cur
+            finally:
+                con.rollback()
+                logger.info(self.ROLLBACK)


 class Mixin(BaseMixin):
35 changes: 1 addition & 34 deletions src/dsdk/utils.py
@@ -11,9 +11,7 @@
 from pickle import load as pickle_load
 from sys import stdout
 from time import sleep as default_sleep
-from typing import Any, Callable, Dict, Optional, Sequence
-
-from pandas import DataFrame, concat
+from typing import Any, Callable, Sequence

 logger = getLogger(__name__)

@@ -71,37 +69,6 @@ def load_pickle_file(path: str) -> object:
         return pickle_load(fin)


-def df_from_query_by_ids(
-    cur,
-    query: str,
-    ids: Sequence[Any],
-    parameters: Optional[Dict[str, Any]] = None,
-    size: int = 10000,
-) -> DataFrame:
-    """Return DataFrame from query by ids."""
-    if parameters is None:
-        parameters = {}
-    dfs = []
-    for chunk in chunks(ids, size):
-        cur.execute(query, {"ids": chunk, **parameters})
-        columns = [i[0] for i in cur.description]
-        rows = cur.fetchall()
-        dfs.append(DataFrame.from_records(rows, columns=columns))
-    return concat(dfs, ignore_index=True)
-
-
-def df_from_query(
-    cur, query: str, parameters: Optional[Dict[str, Any]],
-) -> DataFrame:
-    """Return DataFrame from query."""
-    if parameters is None:
-        parameters = {}
-    cur.execute(query, parameters)
-    columns = [i[0] for i in cur.description]
-    rows = cur.fetchall()
-    return DataFrame.from_records(rows, columns=columns)
-
-
 def retry(
     exceptions: Sequence[Exception],
     retries: int = 60,
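utils.py keeps the retry decorator (its signature is truncated above); postgres.py's Persistor.retry_connect uses it to ride out transient OperationalError while opening a connection. The decorator below is a simplified stand-in that only shows the shape of the pattern — dsdk's actual delay handling and retries=60 default are not reproduced here:

from functools import wraps
from time import sleep


def retry(exceptions, retries=3, delay=0.0):
    """Simplified stand-in: retry the wrapped callable on the given exceptions."""
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            for attempt in range(retries):
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    if attempt == retries - 1:
                        raise
                    sleep(delay)
        return wrapped
    return decorator


attempts = []


@retry((ConnectionError,), retries=3)
def flaky_connect():
    """Fail twice, then succeed, to exercise the decorator."""
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("transient failure")
    return "connected"


assert flaky_connect() == "connected"
assert len(attempts) == 3
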
