From 63abcad995ac87278608c6fe2f05cee8af7f1376 Mon Sep 17 00:00:00 2001 From: Dmitry Lunev Date: Mon, 5 Aug 2024 16:41:17 +0300 Subject: [PATCH] v1.0.0: separate package from python-dev-utils (refactor) - test coverage is not enough (78%). --- .github/workflows/ci.yaml | 17 ++ .github/workflows/lint.yaml | 41 ++++ .github/workflows/test.yaml | 87 ++++++++ .gitignore | 167 +++++++++++++++ Makefile | 65 ++++++ README.md | 37 +++- pyproject.toml | 140 +++++++++++++ sqlalchemy_profiler/__init__.py | 7 + sqlalchemy_profiler/containers.py | 52 +++++ sqlalchemy_profiler/ext/__init__.py | 0 sqlalchemy_profiler/ext/fastapi.py | 83 ++++++++ sqlalchemy_profiler/profilers.py | 304 ++++++++++++++++++++++++++++ sqlalchemy_profiler/types.py | 86 ++++++++ sqlalchemy_profiler/utils.py | 31 +++ tests/__init__.py | 0 tests/conftest.py | 266 ++++++++++++++++++++++++ tests/test_fastapi.py | 14 ++ tests/test_profilers.py | 103 ++++++++++ tests/types.py | 32 +++ tests/utils.py | 217 ++++++++++++++++++++ 20 files changed, 1748 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci.yaml create mode 100644 .github/workflows/lint.yaml create mode 100644 .github/workflows/test.yaml create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 pyproject.toml create mode 100644 sqlalchemy_profiler/__init__.py create mode 100644 sqlalchemy_profiler/containers.py create mode 100644 sqlalchemy_profiler/ext/__init__.py create mode 100644 sqlalchemy_profiler/ext/fastapi.py create mode 100644 sqlalchemy_profiler/profilers.py create mode 100644 sqlalchemy_profiler/types.py create mode 100644 sqlalchemy_profiler/utils.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_fastapi.py create mode 100644 tests/test_profilers.py create mode 100644 tests/types.py create mode 100644 tests/utils.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..31e860e --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,17 @@ +name: Tests And Linting + +on: + pull_request: + push: + +jobs: + test: + name: "Test (${{ matrix.python-version }}" + strategy: + fail-fast: true + matrix: + python-version: ["3.11"] + uses: ./.github/workflows/test.yaml + with: + coverage: true + python-version: ${{ matrix.python-version }} \ No newline at end of file diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 0000000..4ae604a --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,41 @@ +name: Python package + +on: + workflow_call: + inputs: + python-version: + required: true + type: string + os: + required: false + type: string + default: "ubuntu-latest" + timeout: + required: false + type: number + default: 60 + + +jobs: + lint: + timeout-minutes: ${{ inputs.timeout }} + runs-on: ${{ inputs.os }} + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ inputs.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ inputs.python-version }} + - name: Set up PDM + uses: pdm-project/setup-pdm@v3 + with: + python-version: ${{ inputs.python-version }} + - uses: actions/cache@v3 + name: Define a cache for the virtual environment based on the dependencies lock file + with: + path: ./.venv + key: venv-${{ hashFiles('pdm.lock') }} + - name: Install the project dependencies + run: make install + - name: Run linting + run: make lint \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..fcfa081 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,87 @@ +name: Python package + +on: + workflow_call: + inputs: + python-version: + required: true + type: string + poetry-version: + required: false + type: string + default: "1.8" + coverage: + required: false + type: boolean + default: false + os: + required: false + type: string + default: "ubuntu-latest" + timeout: + required: false + type: number + default: 60 + + +jobs: + test: + timeout-minutes: ${{ inputs.timeout }} + runs-on: ${{ inputs.os }} + defaults: + run: + shell: bash + services: + postgres: + image: postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: test_db + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ inputs.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ inputs.python-version }} + - name: Set up PDM + uses: pdm-project/setup-pdm@v3 + with: + python-version: ${{ inputs.python-version }} + - uses: actions/cache@v3 + name: Define a cache for the virtual environment based on the dependencies lock file + with: + path: ./.venv + key: venv-${{ hashFiles('pdm.lock') }} + - name: Install the project dependencies + run: make install + - name: Run the automated tests with coverage + if: ${{ inputs.coverage }} + run: make test + - name: Coverage Badge + uses: tj-actions/coverage-badge-py@v2 + - name: Verify Changed files + uses: tj-actions/verify-changed-files@v16 + id: verify-changed-files + with: + files: coverage.svg + - name: Commit files + if: steps.verify-changed-files.outputs.files_changed == 'true' + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add coverage.svg + git commit -m "Updated coverage.svg" + - name: Push changes + if: steps.verify-changed-files.outputs.files_changed == 'true' + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.github_token }} + branch: ${{ github.ref }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3cb1e13 --- /dev/null +++ b/.gitignore @@ -0,0 +1,167 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm-python +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + + +# VSCode +.vscode +test.db +report.txt \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..21e3e52 --- /dev/null +++ b/Makefile @@ -0,0 +1,65 @@ +NAME := sqlalchemy_profiler +PDM := $(shell command -v pdm 2> /dev/null) + +.DEFAULT_GOAL := help + +.PHONY: help +help: + @echo -e "Please, use \033[0;33m'make '\033[0m where is one the following commands:" + @echo "" + @echo -e " \033[0;33minstall\033[0m run installation for all dependencies" + @echo -e " \033[0;33mshell\033[0m run ipython shell" + @echo -e " \033[0;33mclean\033[0m run delete all not needed files" + @echo -e " \033[0;33mlint\033[0m run project code checking without formatting" + @echo -e " \033[0;33mformat\033[0m run project code formatting" + @echo -e " \033[0;33mtest\033[0m run all tests" + @echo -e " \033[0;33mtest_docker\033[0m run all tests in docker" + + @echo "" + @echo -e "Check \033[0;33mMakefile\033[0m to get full context of commands." + + +.PHONY: install +install: + @if [ -z $(PDM) ]; then echo "PDM could not be found."; exit 2; fi + $(PDM) install -G:all --no-self + + +.PHONY: shell +shell: + @if [ -z $(PDM) ]; then echo "Poetry could not be found. See https://python-poetry.org/docs/"; exit 2; fi + $(ENV_VARS_PREFIX) $(PDM) run ipython --no-confirm-exit --no-banner --quick \ + --InteractiveShellApp.extensions="autoreload" \ + --InteractiveShellApp.exec_lines="%autoreload 2" + +.PHONY: clean +clean: + find . -type d -name "__pycache__" | xargs rm -rf {}; + +.PHONY: lint +lint: + @if [ -z $(PDM) ]; then echo "Poetry could not be found. See https://python-poetry.org/docs/"; exit 2; fi + $(PDM) run pyright $(NAME) + $(PDM) run isort --settings-path ./pyproject.toml --check-only $(NAME) + $(PDM) run black --config ./pyproject.toml --check $(NAME) --diff + $(PDM) run ruff check $(NAME) + $(PDM) run vulture $(NAME) --min-confidence 100 --exclude "migration_numbering.py" + $(PDM) run bandit --configfile ./pyproject.toml -r ./$(NAME) + +.PHONY: format +format: + @if [ -z $(PDM) ]; then echo "Poetry could not be found. See https://python-poetry.org/docs/"; exit 2; fi + $(PDM) run isort --settings-path ./pyproject.toml $(NAME) + $(PDM) run black --config ./pyproject.toml $(NAME) + +.PHONY: test +test: + @if [ -z $(PDM) ]; then echo "Poetry could not be found. See https://python-poetry.org/docs/"; exit 2; fi + $(PDM) run pytest ./tests --cov-report xml --cov-fail-under 95 --cov ./$(NAME) -vv + + +.PHONY: test_docker +test_docker: + @if [ -z $(PDM) ]; then echo "Poetry could not be found. See https://python-poetry.org/docs/"; exit 2; fi + $(ENV_VARS_PREFIX) docker-compose -f docker/docker-compose-test.yaml up --build + $(ENV_VARS_PREFIX) docker-compose -f docker/docker-compose-test.yaml down \ No newline at end of file diff --git a/README.md b/README.md index 7710f1a..16d59a5 100644 --- a/README.md +++ b/README.md @@ -1 +1,36 @@ -# sqlalchemy_profiler \ No newline at end of file + +# Dev utils + +![coverage](./coverage.svg) + +## For what? + +I made this project to avoid copy-pasting with utils in my projects. I was aiming to simplify +working with sqlalchemy, FastAPI and other libraries. + +## Install + +With pip: + +```bash +pip install sqlalchemy_profiler +``` + +With pdm: + +```bash +pdm install sqlalchemy_profiler +``` + +With poetry: + +```bash +poetry add sqlalchemy_profiler +``` + +## Profiling + +Profiling utils. Now available 2 profilers and 2 middlewares (FastAPI) for such profilers: + +1. SQLAlchemyQueryProfiler - profile entire sqlalchemy query - it text, params, duration. +2. SQLAlchemyQueryCounter - count sqlalchemy queries. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9652ac1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,140 @@ +[tool.ruff] +output-format = "full" +lint.select = ["ALL"] +line-length = 100 +exclude = [ + ".git", + "__pycache__", + ".venv", + ".eggs", + "*.egg", + "dist", + "tests/fixtures/**", + "tests/**/snapshots/**", + "alembic", + "airich", +] +lint.ignore = [ + "D100", + "B008", + "D104", + "Q000", + "S101", + "PT016", + "ANN101", + "ANN102", + "N805", + "UP037", + "PLC0414", +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" +ignore-decorators = ["typing.overload"] + +[tool.ruff.lint.mccabe] +max-complexity = 11 + +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"] + +[tool.ruff.lint.extend-per-file-ignores] +"typings/*" = ["ANN401"] +"__init__.py" = ["F401"] +"manage.py" = ["E402"] +"tests/*" = ["D103"] + + +[tool.black] +line-length = 100 +skip-string-normalization = true +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' + + +[tool.isort] +profile = "black" +line_length = 100 +multi_line_output = 3 + + +[tool.coverage] +[tool.coverage.run] +source = ["sqlrepo"] +branch = true + +[tool.coverage.report] +exclude_lines = [ + "enum.Enum", + "(Protocol):", + "(typing.Protocol):", + "pragma: no cover", + "pragma: no coverage", + "raise NotImplementedError", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", + "@overload", +] + +[tool.pytest] +testpath = "tests" + + +[tool.bandit] +exclude = ['tests'] + +[tool.bandit.assert_used] +skips = ['*_test.py', '*/test_*.py'] + +[tool.pdm.dev-dependencies] +dev = [ + "ruff>=0.5.5", + "vulture>=2.11", + "pytest>=8.1.1", + "black>=24.3.0", + "bandit>=1.7.8", + "coverage>=7.4.4", + "pytest-cov>=4.1.0", + "isort>=5.13.2", + "pyright>=1.1.355", + "freezegun>=1.4.0", + "mimesis>=15.1.0", + "ipython>=8.22.2", + "sqlalchemy-utils>=0.41.2", + "psycopg2-binary>=2.9.9", + "asyncpg>=0.29.0", + "pytest-asyncio>=0.23.6", + "httpx>=0.27.0", + "pyment>=0.3.3", + "ipython>=8.19.0", +] + + +[project] +name = "sqlalchemy_profiler" +version = "1.0.0" +description = "SQLAlchemy profiler classes and extentions for other packages." +authors = [{ name = "Dmitriy Lunev", email = "dima.lunev14@gmail.com" }] +requires-python = ">=3.11" +readme = "README.md" +license = { text = "MIT" } +dependencies = ["sqlalchemy>=2.0", "python-dev-utils==7.0.1"] + +[project.optional-dependencies] +fastapi = ["fastapi"] + +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" diff --git a/sqlalchemy_profiler/__init__.py b/sqlalchemy_profiler/__init__.py new file mode 100644 index 0000000..75ed6db --- /dev/null +++ b/sqlalchemy_profiler/__init__.py @@ -0,0 +1,7 @@ +"""Profiling utils. Now available 2 profilers and 2 middlewares (FastAPI) for such profilers.""" + +from sqlalchemy_profiler.containers import QueryInfo as QueryInfo +from sqlalchemy_profiler.profilers import BaseSQLAlchemyProfiler as BaseSQLAlchemyProfiler +from sqlalchemy_profiler.profilers import SQLAlchemyQueryCounter as SQLAlchemyQueryCounter +from sqlalchemy_profiler.profilers import SQLAlchemyQueryProfiler as SQLAlchemyQueryProfiler +from sqlalchemy_profiler.utils import pretty_query_info as pretty_query_info diff --git a/sqlalchemy_profiler/containers.py b/sqlalchemy_profiler/containers.py new file mode 100644 index 0000000..b651e4b --- /dev/null +++ b/sqlalchemy_profiler/containers.py @@ -0,0 +1,52 @@ +import traceback +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar + +from dev_utils.common import get_object_class_absolute_name, trim_and_plain_text + +if TYPE_CHECKING: + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql import ClauseElement + from sqlalchemy.sql.compiler import Compiled + +T = TypeVar("T") + + +class QueryInfo: + """Data class (not actually has @dataclass decorator) for profiling results. + + Contains full info about query itself, but not any additional context. + """ + + repr_full_query_text: ClassVar[bool] = False + repr_template: ClassVar[str] = ( + "" + ) + + def __init__( # noqa: PLR0913 + self, + *, + text: "ClauseElement | Compiled", + stack: list[traceback.FrameSummary], + start_time: float, + end_time: float, + params_dict: dict[Any, Any], + results: "CursorResult[Any]", + ) -> None: + self.text = trim_and_plain_text(str(text)) + self.params = params_dict + self.stack = self.stack_text = stack + self.start_time = start_time + self.end_time = end_time + self.duration = end_time - start_time + # BUG: results.rowcount is always -1. Remove or fix it. + self.rowcount = results.rowcount + + def __repr__(self) -> str: # noqa: D105 + return self.repr_template.format( + cls_path=get_object_class_absolute_name(self.__class__), + text=str(self.text)[:40] if self.repr_full_query_text else str(self.text), + params=self.params, + duration=self.duration, + rowcount=self.rowcount, + ) diff --git a/sqlalchemy_profiler/ext/__init__.py b/sqlalchemy_profiler/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sqlalchemy_profiler/ext/fastapi.py b/sqlalchemy_profiler/ext/fastapi.py new file mode 100644 index 0000000..85e3568 --- /dev/null +++ b/sqlalchemy_profiler/ext/fastapi.py @@ -0,0 +1,83 @@ +import uuid + +from fastapi import FastAPI, Request +from sqlalchemy import Engine +from starlette.middleware.base import RequestResponseEndpoint +from starlette.responses import Response + +from sqlalchemy_profiler.profilers import SQLAlchemyQueryCounter, SQLAlchemyQueryProfiler +from sqlalchemy_profiler.types import LogFunctionProtocol, ReportPath + + +def add_query_profiling_middleware( # noqa: PLR0913 + app: FastAPI, + engine: Engine | type[Engine] = Engine, + *, + request_id: str | uuid.UUID | None = None, + log_function: LogFunctionProtocol = print, + report_to: "ReportPath | None" = None, + log_query_stats: bool = False, +) -> FastAPI: + """Add query profiling middleware to FastAPI. + + Note: this function also can be used with starlette instance, but only before version 1.0.0, + because `middleware` decorator is deprecated and will be (or, maybe, already removed) in this + version. + """ + + async def _profiling_middleware( + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + with SQLAlchemyQueryProfiler( + engine=engine, + request_id=request_id, + log_function=log_function, + log_query_stats=log_query_stats, + ) as profiler: + log_function(f"Profiler {request.url}: start profiling. {request_id=}") + response = await call_next(request) + profiler.report(report_to) + log_function(f"Profiler {profiler.request_id} finished.") + return response + + app.middleware("http")(_profiling_middleware) + return app + + +def add_query_counter_middleware( + app: FastAPI, + engine: Engine | type[Engine] = Engine, + *, + request_id: str | uuid.UUID | None = None, + log_function: LogFunctionProtocol = print, + log_query_stats: bool = False, +) -> FastAPI: + """Add query counting middleware to FastAPI. + + Note: this function also can be used with starlette instance, but only before version 1.0.0, + because `middleware` decorator is deprecated and will be (or, maybe, already removed) in this + version. + """ + + async def _counter_middleware( + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + with SQLAlchemyQueryCounter( + engine=engine, + request_id=request_id, + log_function=log_function, + log_query_stats=log_query_stats, + ) as profiler: + log_function( + f"Counter {request.url}: start counting queries.", + ) + response = await call_next(request) + log_function( + f"Counter {request.url}: finish with count {profiler.collect()}. {request_id=}", + ) + return response + + app.middleware("http")(_counter_middleware) + return app diff --git a/sqlalchemy_profiler/profilers.py b/sqlalchemy_profiler/profilers.py new file mode 100644 index 0000000..f2139c9 --- /dev/null +++ b/sqlalchemy_profiler/profilers.py @@ -0,0 +1,304 @@ +import os +import queue +import time +import traceback +import uuid +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from contextlib import suppress +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar, final + +from dev_utils.common import trim_and_plain_text +from dev_utils.guards import all_dict_keys_are_str +from sqlalchemy import event +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import AsyncEngine + +from sqlalchemy_profiler.containers import QueryInfo +from sqlalchemy_profiler.types import NoLog, NoLogStub, ReportPath +from sqlalchemy_profiler.utils import pretty_query_info + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection + from sqlalchemy.engine.cursor import CursorResult + from sqlalchemy.sql import ClauseElement + from sqlalchemy.sql.compiler import SQLCompiler + from sqlalchemy.util import immutabledict + + from sqlalchemy_profiler.types import LogFunctionProtocol + + +T = TypeVar("T") + + +class BaseSQLAlchemyProfiler(ABC, Generic[T]): + """Abstract base sqlalchemy profiling class. + + It is a generic class, so use it with typing. Generic uses in: + + * ``__init__``: self.collector - queue of ```` objects. For example, it could be + queue of QueryInfo objects. + * ``collect``: return value - Sequence of ````. + """ + + def __init__( + self, + engine: "type[Engine] | Engine | AsyncEngine" = Engine, + *, + request_id: str | uuid.UUID | None = None, + log_function: "LogFunctionProtocol | NoLogStub" = NoLog, + log_query_stats: bool = False, + ) -> None: + self.started = False + if isinstance(engine, AsyncEngine): + self.engine = engine.sync_engine + else: + self.engine = engine + self.log_function = log_function + self.log_query_stats = log_query_stats + + self.request_id = str(request_id) if request_id is not None else str(uuid.uuid4()) + + self._result: T | None = None + self.collector: queue.Queue[T] = queue.Queue() + + @abstractmethod + def _before_exec( + self, + conn: "Connection", + clause: "SQLCompiler", + multiparams: "Sequence[Mapping[str, Any]]", + params: "Mapping[str, Any]", + execution_options: "immutabledict[str, Any]", + ) -> None: + """Method, which will be bounded to `before_execute` handler in SQLAlchemy.""" # noqa: D401 + raise NotImplementedError + + @abstractmethod + def _after_exec( # noqa: PLR0913 + self, + conn: "Connection", + clause: "ClauseElement", + multiparams: "Sequence[Mapping[str, Any]]", + params: "Mapping[str, Any]", + execution_options: "immutabledict[str, Any]", + results: "CursorResult[Any]", + ) -> None: + """Method, which will be bounded to `after_execute` handler in SQLAlchemy.""" # noqa: D401 + raise NotImplementedError + + @final + def _extract_parameters_from_results( + self, + query_results: "CursorResult[Any]", + ) -> dict[str, Any]: + """Get parameters from query results object.""" + params_dict: dict[str, Any] = {} + compiled_parameters = getattr(query_results.context, "compiled_parameters", []) + if not compiled_parameters or not isinstance( # pragma: no cover + compiled_parameters, + Sequence, + ): + return {} + for compiled_param_dict in compiled_parameters: + if not isinstance(compiled_param_dict, dict): # pragma: no cover + continue + if not all_dict_keys_are_str(compiled_param_dict): # type: ignore pragma: no cover + continue + params_dict.update(compiled_param_dict) + return params_dict + + def start(self) -> None: + """Start the profiling process. + + Add engine-level handlers from events, which will fill collector with data. + """ + if self.started is False and not isinstance( + self.log_function, + NoLogStub, + ): + msg = f"Profiling session is already started! {self.request_id=}" + self.log_function(msg) + + self.started = True + if not event.contains(self.engine, "before_execute", self._before_exec): + event.listen(self.engine, "before_execute", self._before_exec) + if not event.contains(self.engine, "after_execute", self._after_exec): + event.listen(self.engine, "after_execute", self._after_exec) + + def stop(self) -> None: + """Stop the profiling process. + + Remove engine-level handlers from events - no other data will be put in collector. + """ + if self.started is False and not isinstance( + self.log_function, + NoLogStub, + ): + msg = f"Profiling session is already stopped. {self.request_id=}" + self.log_function(msg) + + self.started = False + if event.contains(self.engine, "before_execute", self._before_exec): + event.remove(self.engine, "before_execute", self._before_exec) + if event.contains(self.engine, "after_execute", self._after_exec): + event.remove(self.engine, "after_execute", self._after_exec) + + def __enter__(self) -> Self: + """Enter of context manager. + + Start the profiler by executing ``self.start()`` method. + """ + self.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, # noqa: F401, F841, RUF100 + exc: BaseException | None, # noqa: F401, F841, RUF100 + traceback: TracebackType | None, # noqa: F401, F841, RUF100 + ) -> None: + """Exit of context manager. + + Stop the profiler by executing ``self.stop()`` method. + """ + self.stop() + + def collect(self) -> Sequence[T]: + """Collect all information from queue. + + Collect means "transform to list". You can override this method, if you want to return + other type. Use self._result as return value and assign this attribute in profiling methods. + """ + queries: list[T] = [] + with suppress(queue.Empty): + while True: + queries.append(self.collector.get(block=False)) + + return queries + + def report( + self, + stdout: "ReportPath | None" = None, # noqa: ARG002, F401, RUF100 + ) -> None: + """Make report about profiling.""" + return # pragma: no coverage + + +class SQLAlchemyQueryProfiler(BaseSQLAlchemyProfiler[QueryInfo]): + """SQLAlchemy query profiler.""" + + def _before_exec( + self, + conn: "Connection", + clause: "SQLCompiler", + multiparams: "Sequence[Mapping[str, Any]]", # noqa: ARG002, F401, F841, RUF100 + params: "Mapping[str, Any]", + execution_options: "immutabledict[str, Any]", # noqa: ARG002, F401, F841, RUF100 + ) -> None: + conn.info.setdefault("query_start_time", []).append(time.time()) + if self.log_query_stats and not isinstance( + self.log_function, + NoLogStub, + ): + msg = ( + f"Query started: {trim_and_plain_text(str(clause))}. Params: {params}. " + f"{self.request_id=}" + ) + self.log_function(msg) + + def _after_exec( # noqa: PLR0913 + self, + conn: "Connection", + clause: "ClauseElement", + multiparams: "Sequence[Mapping[str, Any]]", # noqa: ARG002, F401, F841, RUF100 + params: "Mapping[str, Any]", + execution_options: "immutabledict[str, Any]", # noqa: ARG002, F401, F841, RUF100 + results: "CursorResult[Any]", + ) -> None: + end_time = time.time() + start_time = conn.info["query_start_time"].pop(-1) + if self.log_query_stats and not isinstance( + self.log_function, + NoLogStub, + ): + msg = ( + f'Query "{trim_and_plain_text(str(clause))}" (params: {params}) ' + f"finished in {(end_time - start_time) * 1000} milliseconds. " + f"{self.request_id=}" + ) + self.log_function(msg) + + text = clause + with suppress(AttributeError): + text = clause.compile(dialect=conn.engine.dialect) + + params_dict = self._extract_parameters_from_results(results) + + stack = traceback.extract_stack()[:-1] + query_info = QueryInfo( + text=text, + stack=stack, + start_time=start_time, + end_time=end_time, + params_dict=params_dict, + results=results, + ) + + self.collector.put(query_info) + + def report( # type: ignore[reportIncompatibleMethodOverride] # noqa: D102 + self, + stdout: "ReportPath | None" = None, + ) -> None: + data = pretty_query_info(self.collect()) + if stdout is None: + if isinstance(self.log_function, NoLogStub): + return + self.log_function(data) + return + if isinstance(stdout, str): + stdout = Path(stdout) + if isinstance(stdout, Path | os.PathLike): + Path(stdout).write_text(data) + else: + stdout.write(data) + + +class SQLAlchemyQueryCounter(BaseSQLAlchemyProfiler[int]): + """SQLAlchemy query counter.""" + + def collect(self) -> int: # type: ignore[reportIncompatibleMethodOverride] # noqa: D102 + if self._result is None: # pragma: no cover + return 0 + return self._result + + def _before_exec( + self, + conn: "Connection", # noqa: ARG002, F401, F841, RUF100 + clause: "SQLCompiler", # noqa: ARG002, F401, F841, RUF100 + multiparams: "Sequence[Mapping[str, Any]]", # noqa: ARG002, F401, F841, RUF100 + params: "Mapping[str, Any]", # noqa: ARG002, F401, F841, RUF100 + execution_options: "immutabledict[str, Any]", # noqa: ARG002, F401, F841, RUF100 + ) -> None: + self._result = 0 + + def _after_exec( # noqa: PLR0913 + self, + conn: "Connection", # noqa: ARG002, F401, F841, RUF100 + clause: "ClauseElement", # noqa: ARG002, F401, F841, RUF100 + multiparams: "Sequence[Mapping[str, Any]]", # noqa: ARG002, F401, F841, RUF100 + params: "Mapping[str, Any]", # noqa: ARG002, F401, F841, RUF100 + execution_options: "immutabledict[str, Any]", # noqa: ARG002, F401, F841, RUF100 + results: "CursorResult[Any]", # noqa: ARG002, F401, F841, RUF100 + ) -> None: + if self._result is None: # pragma: no cover + self._result = 0 + self._result += 1 + + def start(self) -> None: # noqa: D102 + if not self.started: # pragma: no cover + self._result = 0 + return super().start() diff --git a/sqlalchemy_profiler/types.py b/sqlalchemy_profiler/types.py new file mode 100644 index 0000000..b18e1c3 --- /dev/null +++ b/sqlalchemy_profiler/types.py @@ -0,0 +1,86 @@ +import os +from collections.abc import Mapping +from pathlib import Path +from types import TracebackType +from typing import Any, Literal, Protocol, TypeAlias, TypeVar, overload + +_T_contra = TypeVar("_T_contra", contravariant=True) +_SysExcInfoType: TypeAlias = ( + tuple[type[BaseException], BaseException, TracebackType | None] | tuple[None, None, None] +) +_ExcInfoType: TypeAlias = bool | _SysExcInfoType | BaseException + + +class SupportsWrite(Protocol[_T_contra]): # noqa: D101 + def write(self, s: _T_contra, /) -> object: ... # noqa: D102 pragma: no cover + + +class SupportsFlush(Protocol): # noqa: D101 + def flush(self) -> object: ... # noqa: D102 pragma: no cover + + +class _SupportsWriteAndFlush(SupportsWrite[_T_contra], SupportsFlush, Protocol[_T_contra]): ... + + +class _PrintProtocol(Protocol): + @overload + @staticmethod + def __call__( # NOTE: print 1 + *values: object, + sep: str | None = " ", + end: str | None = "\n", + file: "SupportsWrite[str] | None" = None, + flush: Literal[False] = False, + ) -> None: ... + + @overload + @staticmethod + def __call__( # NOTE: print 2 + *values: object, + sep: str | None = " ", + end: str | None = "\n", + file: "_SupportsWriteAndFlush[str] | None" = None, + flush: bool, + ) -> None: ... + + @staticmethod + def __call__(*args: Any, **kwargs: Any) -> None: ... # pragma: no cover + + +class _PythonLogProtocol(Protocol): + @staticmethod + def __call__( + msg: object, + *args: object, + exc_info: _ExcInfoType | None = None, + stack_info: bool = False, + stacklevel: int = 1, + extra: Mapping[str, object] | None = None, + ) -> None: ... + + +class _StructLogProtocol(Protocol): + @staticmethod + def __call__( + event: str | None = None, + *args: Any, # noqa: ANN401 + **kw: Any, # noqa: ANN401 + ) -> Any: ... # noqa: ANN401 + + +class NoLogStub: + """Stub class, which instance will be passed to profilers. + + Make profilers not write logs. + """ + + +NoLog = NoLogStub() +"""Stub class instance. + +Use it to prevent profilers or middlewares of making logging. +""" +ReportPath: TypeAlias = str | Path | os.PathLike[str] | SupportsWrite[str] +"""Type alias for path to report path variables.""" +LogFunctionProtocol: TypeAlias = _PrintProtocol | _PythonLogProtocol | _StructLogProtocol +"""Type alias for log function protocols.""" diff --git a/sqlalchemy_profiler/utils.py b/sqlalchemy_profiler/utils.py new file mode 100644 index 0000000..0b9601f --- /dev/null +++ b/sqlalchemy_profiler/utils.py @@ -0,0 +1,31 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sqlalchemy_profiler.containers import QueryInfo + + +def pretty_query_info(info: "QueryInfo | Sequence[QueryInfo]") -> str: + """Pretty text from QueryInfo. + + Make string from information to log it. + """ + query_template = ( + "index: {query_index}\n" + "query text: {query_text}\n" + "query params: {query_params}\n" + "query duration: {query_duration}\n" + "query rowcount (may be incorrect): {query_rowcount}\n" + ) + if not isinstance(info, Sequence): + info = [info] + return "\n".join( + query_template.format( + query_index=idx, + query_text=query.text, + query_params=query.params, + query_duration=query.duration, + query_rowcount=query.rowcount, + ) + for idx, query in enumerate(info) + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1427614 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,266 @@ +import asyncio +import os +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +import pytest +import pytest_asyncio +from fastapi import FastAPI +from fastapi.testclient import TestClient +from mimesis import Datetime, Locale, Text +from sqlalchemy import create_engine, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from sqlalchemy_profiler.ext.fastapi import ( + add_query_counter_middleware, + add_query_profiling_middleware, +) +from tests.utils import ( + Base, + MyModel, + coin_flip, + create_db, + create_db_item_async, + create_db_item_sync, + destroy_db, +) + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator + + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + from sqlalchemy.orm import Session + + from tests.types import AsyncFactoryFunctionProtocol, SyncFactoryFunctionProtocol + + +true_stmt = {"y", "Y", "yes", "Yes", "t", "true", "True", "1"} +IS_DOCKER_TEST = os.environ.get("IS_DOCKER_TEST", "false") in true_stmt + + +@pytest.fixture(scope="session") +def event_loop() -> "Generator[asyncio.AbstractEventLoop, None, None]": + """Event loop fixture.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def db_name() -> str: + """Db name as fixture.""" + return "sqlalchemy_profiler_test_db" + + +@pytest.fixture(scope="session") +def db_user() -> str: + """DB user as fixture.""" + return "postgres" + + +@pytest.fixture(scope="session") +def db_password() -> str: + """DB password as fixture.""" + return "postgres" + + +@pytest.fixture(scope="session") +def db_host() -> str: + """DB host as fixture.""" + return "db" if IS_DOCKER_TEST else "localhost" + + +@pytest.fixture(scope="session") +def db_port() -> int: + """DB port as fixture.""" + return 5432 + + +@pytest.fixture(scope="session") +def db_domain(db_name: str, db_user: str, db_password: str, db_host: str, db_port: int) -> str: + """Domain for test db without specified driver.""" + return f"{db_user}:{db_password}@{db_host}:{db_port}/{db_name}" + + +@pytest.fixture(scope="session") +def db_sync_url(db_domain: str) -> str: + """URL for test db (will be created in db_engine): sync driver.""" + return f"postgresql://{db_domain}" + + +@pytest.fixture(scope="session") +def db_async_url(db_domain: str) -> str: + """URL for test db (will be created in db_engine): async driver.""" + return f"postgresql+asyncpg://{db_domain}" + + +@pytest.fixture(scope="session") +def db_sync_engine(db_sync_url: str) -> "Generator[Engine, None, None]": + """SQLAlchemy engine session-based fixture.""" + with suppress(SQLAlchemyError): + create_db(db_sync_url) + engine = create_engine( + db_sync_url, + echo=False, + pool_pre_ping=True, + ) + try: + yield engine + finally: + engine.dispose() + with suppress(SQLAlchemyError): + destroy_db(db_sync_url) + + +@pytest_asyncio.fixture(scope="session") # type: ignore[reportUntypedFunctionDecorator] +async def db_async_engine(db_async_url: str) -> "AsyncGenerator[AsyncEngine, None]": + """SQLAlchemy engine session-based fixture.""" + engine = create_async_engine( + db_async_url, + echo=True, + pool_pre_ping=True, + ) + try: + yield engine + finally: + await engine.dispose() + + +@pytest.fixture(scope="session") +def db_sync_session_factory(db_sync_engine: "Engine") -> "scoped_session[Session]": + """SQLAlchemy session factory session-based fixture.""" + return scoped_session( + sessionmaker( + bind=db_sync_engine, + autoflush=False, + expire_on_commit=False, + ), + ) + + +@pytest.fixture(scope="session") +def db_async_session_factory( + db_async_engine: "AsyncEngine", +) -> "async_scoped_session[AsyncSession]": + """SQLAlchemy session factory session-based fixture.""" + return async_scoped_session( + async_sessionmaker( + bind=db_async_engine, + autoflush=False, + expire_on_commit=False, + ), + asyncio.current_task, + ) + + +@pytest.fixture() +def db_sync_session( + db_sync_engine: "Engine", + db_sync_session_factory: "scoped_session[Session]", +) -> "Generator[Session, None, None]": + """SQLAlchemy session fixture.""" + Base.metadata.drop_all(db_sync_engine) + Base.metadata.create_all(db_sync_engine) + with db_sync_session_factory() as session: + yield session + + +@pytest_asyncio.fixture() # type: ignore[reportUntypedFunctionDecorator] +async def db_async_session( + db_async_engine: "AsyncEngine", + db_async_session_factory: "async_scoped_session[AsyncSession]", +) -> "AsyncGenerator[AsyncSession, None]": + """SQLAlchemy session fixture.""" + async with db_async_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + async with db_async_session_factory() as session: + yield session + + +@pytest.fixture() +def mymodel_sync_factory( + dt_faker: Datetime, + text_faker: Text, +) -> "SyncFactoryFunctionProtocol[MyModel]": + """Function-factory, that create MyModel instances.""" + + def _create( + session: "Session", + *, + commit: bool = False, + **kwargs: Any, # noqa: ANN401 + ) -> MyModel: + params: dict[str, Any] = { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + } + params.update(kwargs) + return create_db_item_sync(session, MyModel, params, commit=commit) + + return _create + + +@pytest.fixture() +def mymodel_async_factory( + text_faker: Text, + dt_faker: Datetime, +) -> "AsyncFactoryFunctionProtocol[MyModel]": + """Function-factory, that create MyModel instances.""" + + async def _create( + session: "AsyncSession", + *, + commit: bool = False, + **kwargs: Any, # noqa: ANN401 + ) -> MyModel: + params: dict[str, Any] = { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + } + params.update(kwargs) + return await create_db_item_async(session, MyModel, params, commit=commit) + + return _create + + +@pytest.fixture() +def text_faker() -> Text: + return Text(locale=Locale.EN) + + +@pytest.fixture() +def dt_faker() -> Datetime: + return Datetime(locale=Locale.EN) + + +@pytest.fixture() +def test_sync_app( + db_sync_session: "Session", + db_sync_engine: "Engine", + mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", +) -> "Generator[TestClient, None, None]": + app = FastAPI() + add_query_profiling_middleware(app, engine=db_sync_engine) + add_query_counter_middleware(app, engine=db_sync_engine) + for _ in range(10): + mymodel_sync_factory(db_sync_session) + + @app.get("/") + def index(): # type: ignore[reportUnusedFunction] # noqa: ANN202 + stmt = select(MyModel) + items = db_sync_session.execute(stmt).scalars().all() + return [{"id": item.id} for item in items] + + with TestClient( + app=app, + base_url="http://test/", + ) as c: + yield c diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py new file mode 100644 index 0000000..7ecb675 --- /dev/null +++ b/tests/test_fastapi.py @@ -0,0 +1,14 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastapi.testclient import TestClient + + +# TODO: tests for logs. # noqa: FIX002, TD002, TD003 + + +# def test_middlewares_correct_install(test_sync_app: "TestClient") -> None: +# result = test_sync_app.get("/") +# jsn = result.json() +# assert isinstance(jsn, list) +# assert len(jsn) == 10 # type: ignore[reportUnknownArgumentType] # noqa: PLR2004 diff --git a/tests/test_profilers.py b/tests/test_profilers.py new file mode 100644 index 0000000..644412b --- /dev/null +++ b/tests/test_profilers.py @@ -0,0 +1,103 @@ +from tempfile import TemporaryFile +from typing import TYPE_CHECKING + +import pytest +from dev_utils.common import trim_and_plain_text +from sqlalchemy import select + +from sqlalchemy_profiler import profilers +from tests.utils import MyModel + +if TYPE_CHECKING: + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + from sqlalchemy.orm import Session + + +def test_sync_sql_alchemy_query_profiler( + db_sync_engine: "Engine", + db_sync_session: "Session", +) -> None: + profiler = profilers.SQLAlchemyQueryProfiler(db_sync_engine) + profiler.start() + stmt = select(MyModel) + db_sync_session.execute(stmt) + profiler.stop() + report = profiler.collect() + assert len(report) == 1 + assert isinstance(report[0], profilers.QueryInfo) + assert report[0].text == trim_and_plain_text(str(stmt)) + + +def test_sync_sql_alchemy_query_profiler_double_start( + db_sync_engine: "Engine", +) -> None: + # BUG: refactor rest. it doesn't check anything. + profiler = profilers.SQLAlchemyQueryProfiler(db_sync_engine) + profiler.start() + profiler.start() + profiler.stop() + + +def test_sync_sql_alchemy_query_profiler_report( + db_sync_engine: "Engine", +) -> None: + # BUG: refactor rest. it doesn't check anything. + profiler = profilers.SQLAlchemyQueryProfiler(db_sync_engine) + profiler.start() + profiler.stop() + file = TemporaryFile('a') + profiler.report(file) + + +def test_sync_sql_alchemy_query_profiler_double_stop( + db_sync_engine: "Engine", +) -> None: + # BUG: refactor rest. it doesn't check anything. + profiler = profilers.SQLAlchemyQueryProfiler(db_sync_engine) + profiler.start() + profiler.stop() + profiler.stop() + + +def test_sync_sql_alchemy_query_profiler_context_manager( + db_sync_engine: "Engine", + db_sync_session: "Session", +) -> None: + with profilers.SQLAlchemyQueryProfiler(db_sync_engine) as profiler: + stmt = select(MyModel) + db_sync_session.execute(stmt) + report = profiler.collect() + assert len(report) == 1 + assert isinstance(report[0], profilers.QueryInfo) + assert report[0].text == trim_and_plain_text(str(stmt)) + + +@pytest.mark.asyncio() +async def test_async_sql_alchemy_query_profiler( + db_async_engine: "AsyncEngine", + db_async_session: "AsyncSession", +) -> None: + profiler = profilers.SQLAlchemyQueryProfiler(db_async_engine) + profiler.start() + stmt = select(MyModel) + await db_async_session.execute(stmt) + profiler.stop() + report = profiler.collect() + assert len(report) == 1 + assert isinstance(report[0], profilers.QueryInfo) + assert report[0].text == trim_and_plain_text(str(stmt)) + + +@pytest.mark.asyncio() +async def test_async_sql_alchemy_query_profiler_context_manager( + db_async_engine: "AsyncEngine", + db_async_session: "AsyncSession", +) -> None: + with profilers.SQLAlchemyQueryProfiler(db_async_engine) as profiler: + stmt = select(MyModel) + await db_async_session.execute(stmt) + report = profiler.collect() + assert len(report) == 1 + assert isinstance(report[0], profilers.QueryInfo) + assert report[0].text == trim_and_plain_text(str(stmt)) diff --git a/tests/types.py b/tests/types.py new file mode 100644 index 0000000..5b5663f --- /dev/null +++ b/tests/types.py @@ -0,0 +1,32 @@ +from typing import TYPE_CHECKING, Any, Protocol, TypeVar + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + + +T_co = TypeVar("T_co", covariant=True) + + +class SyncFactoryFunctionProtocol(Protocol[T_co]): + """Protocol for Sync functions-factories that create db items.""" + + @staticmethod + def __call__( # noqa: D102 + session: "Session", + *, + commit: bool = False, + **kwargs: Any, # noqa: ANN401 + ) -> T_co: ... + + +class AsyncFactoryFunctionProtocol(Protocol[T_co]): + """Protocol for Sync functions-factories that create db items.""" + + @staticmethod + async def __call__( # noqa: D102 + session: "AsyncSession", + *, + commit: bool = False, + **kwargs: Any, # noqa: ANN401 + ) -> T_co: ... diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..0d4739b --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,217 @@ +import datetime +import random +from typing import TYPE_CHECKING, Any, TypeVar + +from sqlalchemy import inspect +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.hybrid import hybrid_method, hybrid_property +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy_utils import ( # type: ignore[reportUnknownVariableType] + create_database, # type: ignore[reportUnknownVariableType] + database_exists, # type: ignore[reportUnknownVariableType] + drop_database, # type: ignore[reportUnknownVariableType] +) + +if TYPE_CHECKING: + from collections.abc import Sequence + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.orm import Session + + +T = TypeVar("T") + + +def coin_flip() -> bool: + """Coin flip: True or False.""" + return bool(random.getrandbits(1)) + + +def create_db(uri: str) -> None: + """Drop the database at ``uri`` and create a brand new one.""" + destroy_db(uri) + create_database(uri) + + +def destroy_db(uri: str) -> None: + """Destroy the database at ``uri``, if it exists.""" + if database_exists(uri): + drop_database(uri) + + +def generate_datetime_list( + *, + n: int = 10, + tz: Any = None, # noqa: ANN401 +) -> list[datetime.datetime]: + """Generate list of datetimes of given length with or without timezone.""" + now = datetime.datetime.now(tz=tz) + res = [now] + for i in range(1, n): + delta = datetime.timedelta(days=i) + res.append(now + delta) + return res + + +def assert_compare_db_items( + item1: "DeclarativeBase", + item2: "DeclarativeBase", +) -> None: + """Assert if 2 models not compare to each other.""" + if item1 is item2: + return + assert ( + item1.__class__ == item2.__class__ + ), "item1 and item2 has different classes. Cant compare." + item1_fields = set(inspect(item1.__class__).columns.keys()) + item2_fields = set(inspect(item2.__class__).columns.keys()) + assert item1_fields == item2_fields, "" + for field in item1_fields: + assert getattr( + item1, + field, + float("nan"), + ) == getattr( + item2, + field, + float("nan"), + ), f"field {field} is not compared. Different values." + + +def assert_compare_db_item_list( + items1: "Sequence[DeclarativeBase]", + items2: "Sequence[DeclarativeBase]", +) -> None: + """Assert if 2 model lists not compare to each other.""" + assert len(items1) == len(items2), f"Different lists count: {len(items1)} != {len(items2)}" + for item1, item2 in zip( + sorted(items1, key=lambda x: x.id), # type: ignore[reportAttributeAccessIssue] + sorted(items2, key=lambda x: x.id), # type: ignore[reportAttributeAccessIssue] + strict=True, + ): + assert_compare_db_items(item1, item2) + + +def assert_compare_db_item_with_dict( + item: "DeclarativeBase", + data: dict[str, Any], + *, + skip_keys_check: bool = False, +) -> None: + """Assert if model not compare to dict.""" + data_fields = set(data.keys()) + item_fields = set(inspect(item.__class__).columns.keys()) + msg = f"data fields ({data_fields}) are not compare to item fields ({item_fields})." + if not skip_keys_check: + assert set(data_fields).issubset(item_fields), msg + for field, value in data.items(): + item_field_value = getattr(item, field, float("nan")) + msg = ( + f'data ({field=} {value=}) not compare ' + f'to item ({field=} value={getattr(item, field, "")})' + ) + assert item_field_value == value, msg + + +def assert_compare_db_item_list_with_dict( + items: "Sequence[DeclarativeBase]", + data: dict[str, Any], + *, + skip_keys_check: bool = False, +) -> None: + """Assert if list of models not compare to dict.""" + data_fields = set(data.keys()) + for item in items: + item_class = item.__class__ + item_fields = set(inspect(item_class).columns.keys()) + msg = ( + f"data fields ({data_fields}) are not compare to item " + f"({item_class}) fields ({item_fields})." + ) + if not skip_keys_check: + assert set(data_fields).issubset(item_fields), msg + for field, value in data.items(): + item_field_value = getattr(item, field, float("nan")) + msg = ( + f'data ({field=} {value=}) not compare ' + f'to item ({field=} value={getattr(item, field, "")})' + ) + assert item_field_value == value, msg + + +def assert_compare_db_item_none_fields(item: "DeclarativeBase", none_fields: set[str]) -> None: + """Assert compare model instance fields for none value.""" + for field in none_fields: + item_value = getattr(item, field, float("nan")) + msg = f'Field "{field}" is not None.' + assert item_value is None, msg + + +def assert_compare_db_item_list_none_fields( + items: "Sequence[DeclarativeBase]", + none_fields: set[str], +) -> None: + """Assert compare list of model instances fields for none value.""" + for item in items: + for field in none_fields: + item_value = getattr(item, field, float("nan")) + msg = f'Field "{field}" of item {item} is not None.' + assert item_value is None, msg + + +def create_db_item_sync( + session: "Session", + model: type[T], + params: dict[str, Any], + *, + commit: bool = False, +) -> T: + """Create SQLAlchemy model item and add it to DB.""" + item = model(**params) + session.add(item) + try: + session.commit() if commit else session.flush() + except SQLAlchemyError: + session.rollback() + raise + return item + + +async def create_db_item_async( + session: "AsyncSession", + model: type[T], + params: dict[str, Any], + *, + commit: bool = False, +) -> T: + """Create SQLAlchemy model item and add it to DB.""" + item = model(**params) + session.add(item) + try: + await session.commit() if commit else await session.flush() + except SQLAlchemyError: + await session.rollback() + raise + return item + + +class Base(DeclarativeBase): # noqa: D101 + pass + + +class MyModel(Base): # noqa: D101 + __tablename__ = "my_model" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str | None] + other_name: Mapped[str | None] + dt: Mapped[datetime.datetime | None] + bl: Mapped[bool | None] + + @hybrid_property + def full_name(self): # type: ignore[reportUnknownParameterType] # noqa: ANN201, D102 + return self.name + "" + self.other_name # type: ignore[reportUnknownVariableType] + + @hybrid_method + def get_full_name(self): # type: ignore[reportUnknownParameterType] # noqa: ANN201, D102 + return self.name + "" + self.other_name # type: ignore[reportUnknownVariableType]