Skip to content

Commit

Permalink
refactor, fix scoped session
Browse files Browse the repository at this point in the history
  • Loading branch information
Kareem Zidane committed Apr 9, 2021
1 parent 599d968 commit 1b671e2
Show file tree
Hide file tree
Showing 8 changed files with 684 additions and 638 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="6.0.4"
version="7.0.0"
)
20 changes: 3 additions & 17 deletions src/cs50/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
import logging
import os
import sys
from ._logger import _setup_logger
_setup_logger()


# Disable cs50 logger by default
logging.getLogger("cs50").disabled = True

# Import cs50_*
from .cs50 import get_char, get_float, get_int, get_string
try:
from .cs50 import get_long
except ImportError:
pass

# Hook into flask importing
from .cs50 import get_float, get_int, get_string
from . import flask

# Wrap SQLAlchemy
from .sql import SQL
48 changes: 48 additions & 0 deletions src/cs50/_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
import os.path
import re
import sys
import traceback

import termcolor


def _setup_logger():
_logger = logging.getLogger("cs50")
_logger.disabled = True
_logger.setLevel(logging.DEBUG)

# Log messages once
_logger.propagate = False

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)

formatter = logging.Formatter("%(levelname)s: %(message)s")
formatter.formatException = lambda exc_info: _formatException(*exc_info)
handler.setFormatter(formatter)
_logger.addHandler(handler)


def _formatException(type, value, tb):
"""
Format traceback, darkening entries from global site-packages directories
and user-specific site-packages directory.
https://stackoverflow.com/a/46071447/5156190
"""

# Absolute paths to site-packages
packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:])

# Highlight lines not referring to files in site-packages
lines = []
for line in traceback.format_exception(type, value, tb):
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
if matches and matches.group(1).startswith(packages):
lines += line
else:
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3))
return "".join(lines).rstrip()


80 changes: 80 additions & 0 deletions src/cs50/_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os

import sqlalchemy
import sqlalchemy.orm
import sqlite3

class Session:
def __init__(self, url, **engine_kwargs):
self._url = url
if _is_sqlite_url(self._url):
_assert_sqlite_file_exists(self._url)

self._engine = _create_engine(self._url, **engine_kwargs)
self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"}
_setup_on_connect(self._engine)
self._session = _create_scoped_session(self._engine)


def is_postgres(self):
return self._is_postgres


def execute(self, statement):
return self._session.execute(sqlalchemy.text(str(statement)))


def __getattr__(self, attr):
return getattr(self._session, attr)


def _is_sqlite_url(url):
return url.startswith("sqlite:///")


def _assert_sqlite_file_exists(url):
path = url[len("sqlite:///"):]
if not os.path.exists(path):
raise RuntimeError(f"does not exist: {path}")
if not os.path.isfile(path):
raise RuntimeError(f"not a file: {path}")


def _create_engine(url, **kwargs):
try:
engine = sqlalchemy.create_engine(url, **kwargs)
except sqlalchemy.exc.ArgumentError:
raise RuntimeError(f"invalid URL: {url}") from None

engine.execution_options(autocommit=False)
return engine


def _setup_on_connect(engine):
def connect(dbapi_connection, _):
_disable_auto_begin_commit(dbapi_connection)
if _is_sqlite_connection(dbapi_connection):
_enable_sqlite_foreign_key_constraints(dbapi_connection)

sqlalchemy.event.listen(engine, "connect", connect)


def _create_scoped_session(engine):
session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
return sqlalchemy.orm.scoping.scoped_session(session_factory)


def _disable_auto_begin_commit(dbapi_connection):
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
dbapi_connection.isolation_level = None


def _is_sqlite_connection(dbapi_connection):
return isinstance(dbapi_connection, sqlite3.Connection)


def _enable_sqlite_foreign_key_constraints(dbapi_connection):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Loading

0 comments on commit 1b671e2

Please sign in to comment.