diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index aba5c36..fc6b61f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,16 +14,4 @@ jobs: uses: ./.github/workflows/test.yaml with: coverage: true - python-version: ${{ matrix.python-version }} - quality: - name: "Code quality check (${{ matrix.python-version }}" - needs: - - test - strategy: - fail-fast: true - matrix: - python-version: ["3.12"] - poetry-version: ["1.8"] - uses: ./.github/workflows/lint.yaml - with: python-version: ${{ matrix.python-version }} \ No newline at end of file diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2909836..fcfa081 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -64,4 +64,24 @@ jobs: run: make install - name: Run the automated tests with coverage if: ${{ inputs.coverage }} - run: make test \ No newline at end of file + run: make test + - name: Coverage Badge + uses: tj-actions/coverage-badge-py@v2 + - name: Verify Changed files + uses: tj-actions/verify-changed-files@v16 + id: verify-changed-files + with: + files: coverage.svg + - name: Commit files + if: steps.verify-changed-files.outputs.files_changed == 'true' + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add coverage.svg + git commit -m "Updated coverage.svg" + - name: Push changes + if: steps.verify-changed-files.outputs.files_changed == 'true' + uses: ad-m/github-push-action@master + with: + github_token: ${{ secrets.github_token }} + branch: ${{ github.ref }} \ No newline at end of file diff --git a/Makefile b/Makefile index 0a7c3c7..6d41c34 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,7 @@ format: .PHONY: test test: @if [ -z $(PDM) ]; then echo "Poetry could not be found. 
See https://python-poetry.org/docs/"; exit 2; fi - $(PDM) run pytest ./tests --cov-report xml --cov-fail-under 60 --cov ./$(NAME) -v + $(PDM) run pytest ./tests --cov-report xml --cov-fail-under 60 --cov ./$(NAME) -vv .PHONY: test_docker diff --git a/README.md b/README.md index e0692a1..c26bfcf 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # sqlrepo +![coverage](./coverage.svg) + >SQLAlchemy repository pattern. ## Install @@ -86,6 +88,23 @@ from other models. Warning: if you specify column from other model, it may cause errors. For example, update doesn't use it for filters, because joins are not presents in update. +Current implementation use these option in search_by and order_by params, if you pass them as +strings. + +```python +from my_package.models import Admin + +class AdminRepository(BaseSyncRepository[Admin]): + specific_column_mapping = {"custom_field": Admin.id, "other_field": Admin.name} + + +admins = AdminRepository(session).list( + search='abc', + search_by="other_field", + order_by='custom_field', +) +``` + ### `use_flush` Uses as flag of `flush` method in SQLAlchemy session. @@ -217,7 +236,7 @@ class YourUnitOfWork(BaseAsyncUnitOfWork): # Your custom method, that works with your repositories and do business-logic. 
async def work_with_repo_together(self, model_id: int): - your_model_instance = await self.your_model_repo.get({'id': model_id}) + your_model_instance = await self.your_model_repo.get(filters={'id': model_id}) your_other_model_instance = await self.your_model_repo.list( filters={'your_model_id': model_id}, ) diff --git a/pdm.lock b/pdm.lock index 63b41ed..cc22225 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:ad17db767f9e2d85181ad44384eb756f5acb428a92203c942abcdde54d895710" +content_hash = "sha256:4e924ec2336aa857207238251b2574073bad596578816c7aebb946e68461a4c5" [[package]] name = "anyio" @@ -722,29 +722,29 @@ files = [ [[package]] name = "python-dev-utils" -version = "1.2.3" +version = "1.3.0" requires_python = ">=3.11" summary = "My project utils package, that I use in my projects." groups = ["default"] files = [ - {file = "python_dev_utils-1.2.3-py3-none-any.whl", hash = "sha256:7ae9d15ef89606f9cacf676935394cd59a32474b1bc202ca916314586a89233d"}, - {file = "python_dev_utils-1.2.3.tar.gz", hash = "sha256:e93dd230bc79af16485d6ca1c121b648bb6b2c772ece9a2a425b375942fb42cc"}, + {file = "python_dev_utils-1.3.0-py3-none-any.whl", hash = "sha256:1d2163b503f60ae44cd8da234680f72f8dd1390668b07ec73c46f9bbf28d72e5"}, + {file = "python_dev_utils-1.3.0.tar.gz", hash = "sha256:5fc87d8d4e2fe9b91be48e80850b3582747a5f1e393892b037b1c37a1350bee2"}, ] [[package]] name = "python-dev-utils" -version = "1.2.3" +version = "1.3.0" extras = ["sqlalchemy_filters"] requires_python = ">=3.11" summary = "My project utils package, that I use in my projects." 
groups = ["default"] dependencies = [ - "python-dev-utils==1.2.3", + "python-dev-utils==1.3.0", "sqlalchemy>=2.0.28", ] files = [ - {file = "python_dev_utils-1.2.3-py3-none-any.whl", hash = "sha256:7ae9d15ef89606f9cacf676935394cd59a32474b1bc202ca916314586a89233d"}, - {file = "python_dev_utils-1.2.3.tar.gz", hash = "sha256:e93dd230bc79af16485d6ca1c121b648bb6b2c772ece9a2a425b375942fb42cc"}, + {file = "python_dev_utils-1.3.0-py3-none-any.whl", hash = "sha256:1d2163b503f60ae44cd8da234680f72f8dd1390668b07ec73c46f9bbf28d72e5"}, + {file = "python_dev_utils-1.3.0.tar.gz", hash = "sha256:5fc87d8d4e2fe9b91be48e80850b3582747a5f1e393892b037b1c37a1350bee2"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index e542ef3..11bccbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ lint.ignore = [ "PT016", "ANN101", "ANN102", + "PLR0913", ] [tool.ruff.lint.pydocstyle] @@ -158,7 +159,7 @@ dev = [ [project] name = "sqlrepo" -version = "1.0.2" +version = "1.2.0" description = "sqlalchemy repositories with crud operations and other utils for it." 
authors = [{ name = "Dmitriy Lunev", email = "dima.lunev14@gmail.com" }] requires-python = ">=3.11" @@ -166,7 +167,7 @@ readme = "README.md" license = { text = "MIT" } dependencies = [ "sqlalchemy>=2.0.29", - "python-dev-utils[sqlalchemy_filters]>=1.2.3", + "python-dev-utils[sqlalchemy_filters]>=1.3.0", ] diff --git a/sqlrepo/queries.py b/sqlrepo/queries.py index 75811fa..e5aac39 100644 --- a/sqlrepo/queries.py +++ b/sqlrepo/queries.py @@ -3,19 +3,17 @@ import datetime import re from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, TypeVar, overload +from dev_utils.core.utils import get_utc_now from dev_utils.sqlalchemy.filters.converters import BaseFilterConverter -from dev_utils.sqlalchemy.utils import ( - apply_joins, - apply_loads, - get_sqlalchemy_attribute, - get_utc_now, -) -from sqlalchemy import CursorResult, and_, delete, func, or_, select, text, update +from dev_utils.sqlalchemy.utils import apply_joins, apply_loads, get_sqlalchemy_attribute +from sqlalchemy import CursorResult, and_, delete from sqlalchemy import exc as sqlalchemy_exc +from sqlalchemy import func, insert, or_, select, text, update from sqlalchemy.orm import joinedload +from sqlrepo.exc import QueryError from sqlrepo.logging import logger as default_logger @@ -37,7 +35,7 @@ class JoinKwargs(TypedDict): from sqlalchemy.orm.session import Session from sqlalchemy.orm.strategy_options import _AbstractLoad # type: ignore from sqlalchemy.sql._typing import _ColumnExpressionOrStrLabelArgument # type: ignore - from sqlalchemy.sql.dml import Delete, ReturningUpdate, Update + from sqlalchemy.sql.dml import Delete, ReturningInsert, ReturningUpdate, Update from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.selectable import Select @@ -100,6 +98,12 @@ def _resolve_and_apply_joins( joins: "Sequence[Join]", ) -> "Select[tuple[T]]": """Resolve joins 
from strings.""" + # FIXME: may cause situation, when user passed Join as tuple may cause error. + # (Model, Model.id == OtherModel.model_id) # noqa: ERA001 + # or + # (Model, Model.id == OtherModel.model_id, {"isouter": True}) # noqa: ERA001 + if isinstance(joins, str): + joins = [joins] for join in joins: if isinstance(join, tuple | list): target, clause, *kw_list = join @@ -117,6 +121,8 @@ def _resolve_and_apply_loads( stmt: "Select[tuple[T]]", loads: "Sequence[Load]", ) -> "Select[tuple[T]]": + if isinstance(loads, str): + loads = [loads] for load in loads: stmt = ( apply_loads(stmt, load, load_strategy=self.load_strategy) @@ -240,6 +246,35 @@ def _get_item_list_stmt( stmt = stmt.offset(offset) return stmt + def _db_insert_stmt( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict] | None" = None, + ) -> "ReturningInsert[tuple[BaseSQLAlchemyModel]]": + """Generate SQLAlchemy stmt to insert data.""" + stmt = insert(model) + stmt = stmt.values() if data is None else stmt.values(data) + stmt = stmt.returning(model) + return stmt + + def _prepare_create_items( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict | None] | None" = None, + ) -> "Sequence[BaseSQLAlchemyModel]": + """Prepare items to create. + + Initialize model instances by given data. + """ + if isinstance(data, dict) or data is None: + data = [data] + items: list["BaseSQLAlchemyModel"] = [] + for data_ele in data: + items.append(model() if data_ele is None else model(**data_ele)) + return items + def _db_update_stmt( self, *, @@ -390,28 +425,88 @@ def get_item_list( return result.unique().all() return result.all() + @overload + def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | None", + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel": ... 
+ + @overload + def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "Sequence[DataDict]", + use_flush: bool = False, + ) -> "Sequence[BaseSQLAlchemyModel]": ... + + def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict] | None" = None, + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": + """Insert data to given model by given data.""" + stmt = self._db_insert_stmt(model=model, data=data) + if isinstance(data, dict) or data is None: + result = self.session.scalar(stmt) + else: + result = self.session.scalars(stmt) + result = result.unique().all() + if use_flush: + self.session.flush() + else: + self.session.commit() + if not result: # pragma: no coverage + msg = f'No data was insert for model "{model}" and data {data}.' + raise QueryError(msg) + return result + + @overload def create_item( self, *, model: type["BaseSQLAlchemyModel"], - # TODO: add sequence of data to make more than one object at the same time. - data: "DataDict | None" = None, + data: "DataDict | None", use_flush: bool = False, - ) -> "BaseSQLAlchemyModel": + ) -> "BaseSQLAlchemyModel": ... + + @overload + def create_item( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "Sequence[DataDict | None]", + use_flush: bool = False, + ) -> "Sequence[BaseSQLAlchemyModel]": ... + + def create_item( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict | None] | None" = None, + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": """Create model instance from given data.""" - item = model() if data is None else model(**data) - self.session.add(item) + items = self._prepare_create_items(model=model, data=data) + self.session.add_all(items) if use_flush: self.session.flush() else: self.session.commit() msg = ( - f"Create row in database. Item: {item}. " + f"Create row in database. Item: {items}. " f"{'Flush used.' 
if use_flush else 'Commit used.'}." ) self.logger.debug(msg) - return item + if len(items) == 1: + return items[0] + return items def db_update( self, @@ -438,7 +533,6 @@ def change_item( self, *, data: "DataDict", - # TODO: add sequence items to make more than one object update at the same time. item: "BaseSQLAlchemyModel", set_none: bool = False, allowed_none_fields: 'Literal["*"] | set[str]' = "*", @@ -489,11 +583,11 @@ def db_delete( self.session.flush() else: self.session.commit() - if isinstance(result, CursorResult): # type: ignore + if isinstance(result, CursorResult): # type: ignore # pragma: no coverage return result.rowcount - return 0 + return 0 # pragma: no coverage - def delete_item( + def delete_item( # pragma: no coverage self, *, item: "Base", @@ -538,16 +632,16 @@ def disable_items( allow_filter_by_value=allow_filter_by_value, extra_filters=extra_filters, ) - if isinstance(stmt, int): + if isinstance(stmt, int): # pragma: no coverage return stmt result = self.session.execute(stmt) if use_flush: self.session.flush() else: self.session.commit() - if isinstance(result, CursorResult): # type: ignore + if isinstance(result, CursorResult): # type: ignore # pragma: no coverage return result.rowcount - return 0 + return 0 # pragma: no coverage class BaseAsyncQuery(BaseQuery): @@ -632,27 +726,88 @@ async def get_item_list( return result.unique().all() return result.all() + @overload + async def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | None", + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel": ... + + @overload + async def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "Sequence[DataDict]", + use_flush: bool = False, + ) -> "Sequence[BaseSQLAlchemyModel]": ... 
+ + async def db_create( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict] | None" = None, + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": + """Insert data to given model by given data.""" + stmt = self._db_insert_stmt(model=model, data=data) + if isinstance(data, dict) or data is None: + result = await self.session.scalar(stmt) + else: + result = await self.session.scalars(stmt) + result = result.unique().all() + if use_flush: + await self.session.flush() + else: + await self.session.commit() + if not result: # pragma: no coverage + msg = f'No data was insert for model "{model}" and data {data}.' + raise QueryError(msg) + return result + + @overload async def create_item( self, *, model: type["BaseSQLAlchemyModel"], - data: "DataDict | None" = None, + data: "DataDict | None", use_flush: bool = False, - ) -> "BaseSQLAlchemyModel": + ) -> "BaseSQLAlchemyModel": ... + + @overload + async def create_item( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "Sequence[DataDict | None]", + use_flush: bool = False, + ) -> "Sequence[BaseSQLAlchemyModel]": ... + + async def create_item( + self, + *, + model: type["BaseSQLAlchemyModel"], + data: "DataDict | Sequence[DataDict | None] | None" = None, + use_flush: bool = False, + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": """Create model instance from given data.""" - item = model() if data is None else model(**data) - self.session.add(item) + items = self._prepare_create_items(model=model, data=data) + self.session.add_all(items) if use_flush: await self.session.flush() else: await self.session.commit() msg = ( - f"Create row in database. Item: {item}. " + f"Create row in database. Items: {items}. " f"{'Flush used.' if use_flush else 'Commit used.'}." 
) self.logger.debug(msg) - return item + if len(items) == 1: + return items[0] + return items async def db_update( self, @@ -661,7 +816,7 @@ async def db_update( data: "DataDict", filters: "Filter | None" = None, use_flush: bool = False, - ) -> "Sequence[BaseSQLAlchemyModel] | None": + ) -> "Sequence[BaseSQLAlchemyModel]": """Update model from given data.""" stmt = self._db_update_stmt( model=model, @@ -729,11 +884,11 @@ async def db_delete( await self.session.flush() else: await self.session.commit() - if isinstance(result, CursorResult): # type: ignore + if isinstance(result, CursorResult): # type: ignore # pragma: no coverage return result.rowcount - return 0 + return 0 # pragma: no coverage - async def delete_item( + async def delete_item( # pragma: no coverage self, *, item: "Base", @@ -778,13 +933,13 @@ async def disable_items( allow_filter_by_value=allow_filter_by_value, extra_filters=extra_filters, ) - if isinstance(stmt, int): + if isinstance(stmt, int): # pragma: no coverage return stmt result = await self.session.execute(stmt) if use_flush: await self.session.flush() else: await self.session.commit() - if isinstance(result, CursorResult): # type: ignore + if isinstance(result, CursorResult): # type: ignore # pragma: no coverage return result.rowcount - return 0 + return 0 # pragma: no coverage diff --git a/sqlrepo/repositories.py b/sqlrepo/repositories.py index fc87604..513b0e0 100644 --- a/sqlrepo/repositories.py +++ b/sqlrepo/repositories.py @@ -1,6 +1,7 @@ """Main implementations for sqlrepo project.""" import datetime +import importlib import warnings from collections.abc import Callable from typing import ( @@ -15,6 +16,7 @@ TypedDict, TypeVar, get_args, + overload, ) from dev_utils.sqlalchemy.filters.converters import ( @@ -66,6 +68,10 @@ class JoinKwargs(TypedDict): BaseSQLAlchemyModel = TypeVar("BaseSQLAlchemyModel", bound=Base) +class RepositoryModelClassIncorrectUseWarning(Warning): + """Warning about Repository model_class attribute incorrect 
usage.""" + + class BaseRepository(Generic[BaseSQLAlchemyModel]): """Base repository class. @@ -242,13 +248,14 @@ def _validate_disable_attributes(cls) -> None: raise sqlrepo_exc.RepositoryAttributeError(msg) def __init_subclass__(cls) -> None: # noqa: D105 + super().__init_subclass__() if hasattr(cls, "model_class"): msg = ( "Don't change model_class attribute to class. Use generic syntax instead. " "See PEP 646 (https://peps.python.org/pep-0646/). Repository will automatically " "add model_class attribute by extracting it from Generic type." ) - warnings.warn(msg, stacklevel=2) + warnings.warn(msg, RepositoryModelClassIncorrectUseWarning, stacklevel=2) return if cls.__inheritance_check_model_class__ is False: cls.__inheritance_check_model_class__ = True @@ -258,33 +265,46 @@ def __init_subclass__(cls) -> None: # noqa: D105 # NOTE: this code is needed for getting type from generic: Generic[int] -> int type # get_args get params from __orig_bases__, that contains Generic passed types. model, *_ = get_args(cls.__orig_bases__[0]) # type: ignore - except Exception as exc: - msg = f"Error during getting information about Generic types for {cls.__name__}." - raise sqlrepo_exc.RepositoryAttributeError(msg) from exc + except Exception as exc: # pragma: no coverage + msg = ( + f"Error during getting information about Generic types for {cls.__name__}. " + f"Original exception: {str(exc)}" + ) + warnings.warn(msg, RepositoryModelClassIncorrectUseWarning, stacklevel=2) + return if isinstance(model, ForwardRef): try: - model = eval(model.__forward_arg__) # noqa: S307 + repo_module = vars(cls).get("__module__") + if not repo_module: + msg = ( + f"No attribute __module__ in {cls}. Can't import global context for " + "ForwardRef resolving." 
+ ) + raise TypeError(msg) # noqa: TRY301 + model_globals = vars(importlib.import_module(repo_module)) + model = eval(model.__forward_arg__, model_globals) # noqa: S307 except Exception as exc: msg = ( "Can't evaluate ForwardRef of generic type. " "Don't use type in generic with quotes. " f"Original exception: {str(exc)}" ) - warnings.warn(msg, stacklevel=2) + warnings.warn(msg, RepositoryModelClassIncorrectUseWarning, stacklevel=2) return if isinstance(model, TypeVar): msg = "GenericType was not passed for SQLAlchemy model declarative class." - warnings.warn(msg, stacklevel=2) + warnings.warn(msg, RepositoryModelClassIncorrectUseWarning, stacklevel=2) return if not issubclass(model, Base): msg = "Passed GenericType is not SQLAlchemy model declarative class." - warnings.warn(msg, stacklevel=2) + warnings.warn(msg, RepositoryModelClassIncorrectUseWarning, stacklevel=2) return cls.model_class = model # type: ignore - def get_filter_convert_class(self) -> type[BaseFilterConverter]: + @classmethod + def get_filter_convert_class(cls) -> type[BaseFilterConverter]: """Get filter convert class from passed strategy.""" - return self._filter_convert_classes[self.filter_convert_strategy] + return cls._filter_convert_classes[cls.filter_convert_strategy] class BaseAsyncRepository(BaseRepository[BaseSQLAlchemyModel]): @@ -344,7 +364,6 @@ async def count( async def list( # noqa: A003 self, *, - # TODO: улучшить интерфейс, чтобы можно было принимать как 1 элемент, так и несколько filters: "Filter | None" = None, joins: "Sequence[Join] | None" = None, loads: "Sequence[Load] | None" = None, @@ -369,17 +388,29 @@ async def list( # noqa: A003 ) return result - # TODO: def create - insert stmt execute + @overload + async def create( + self, + *, + data: "DataDict | None", + ) -> "BaseSQLAlchemyModel": ... + + @overload + async def create( + self, + *, + data: "Sequence[DataDict]", + ) -> "Sequence[BaseSQLAlchemyModel]": ... 
- async def create_instance( + async def create( self, - data: "DataDict | None" = None, - ) -> "BaseSQLAlchemyModel": + *, + data: "DataDict | Sequence[DataDict] | None", + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": """Create model_class instance from given data.""" - result = await self.queries.create_item( + result = await self.queries.db_create( model=self.model_class, data=data, - use_flush=self.use_flush, ) return result @@ -419,6 +450,7 @@ async def update_instance( async def delete( self, + *, filters: "Filter | None" = None, ) -> "Count": """Delete model_class in db by given filters.""" @@ -429,18 +461,6 @@ async def delete( ) return result - async def delete_item( - self, - *, - instance: "BaseSQLAlchemyModel", - ) -> "Deleted": - """Delete model_class instance.""" - result = await self.queries.delete_item( - item=instance, - use_flush=self.use_flush, - ) - return result - async def disable( self, *, @@ -486,8 +506,8 @@ def __init__( def get( self, - filters: "Filter", *, + filters: "Filter", joins: "Sequence[Join] | None" = None, loads: "Sequence[Load] | None" = None, ) -> "BaseSQLAlchemyModel | None": @@ -514,10 +534,9 @@ def count( ) return result - async def list( # noqa: A003 + def list( # noqa: A003 self, *, - # TODO: улучшить интерфейс, чтобы можно было принимать как 1 элемент, так и несколько joins: "Sequence[Join] | None" = None, loads: "Sequence[Load] | None" = None, filters: "Filter | None" = None, @@ -542,22 +561,33 @@ async def list( # noqa: A003 ) return result - # TODO: async def create - insert stmt execute + @overload + def create( + self, + *, + data: "DataDict | None", + ) -> "BaseSQLAlchemyModel": ... + + @overload + def create( + self, + *, + data: "Sequence[DataDict]", + ) -> "Sequence[BaseSQLAlchemyModel]": ... 
- async def create_instance( + def create( self, *, - data: "DataDict | None" = None, - ) -> "BaseSQLAlchemyModel": + data: "DataDict | Sequence[DataDict] | None", + ) -> "BaseSQLAlchemyModel | Sequence[BaseSQLAlchemyModel]": """Create model_class instance from given data.""" - result = self.queries.create_item( + result = self.queries.db_create( model=self.model_class, data=data, - use_flush=self.use_flush, ) return result - async def update( + def update( self, *, data: "DataDict", @@ -593,6 +623,7 @@ def update_instance( def delete( self, + *, filters: "Filter | None" = None, ) -> "Count": """Delete model_class in db by given filters.""" @@ -603,18 +634,6 @@ def delete( ) return result - def delete_item( - self, - *, - instance: "BaseSQLAlchemyModel", - ) -> "Deleted": - """Delete model_class instance.""" - result = self.queries.delete_item( - item=instance, - use_flush=self.use_flush, - ) - return result - def disable( self, *, diff --git a/tests/conftest.py b/tests/conftest.py index bccbe70..40cd01d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,12 @@ import asyncio import os -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Generator +from typing import TYPE_CHECKING, Any import pytest +import pytest_asyncio from mimesis import Datetime, Locale, Text from sqlalchemy import create_engine -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine from sqlalchemy.orm import scoped_session, sessionmaker from tests.utils import ( @@ -21,8 +21,10 @@ ) if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator + from sqlalchemy import Engine - from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import Session from tests.types import AsyncFactoryFunctionProtocol, SyncFactoryFunctionProtocol @@ -90,21 +92,39 @@ def db_async_url(db_domain: str) -> str: 
return f"postgresql+asyncpg://{db_domain}" -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") +def postgresql_url_with_template_db(db_domain_with_template_db: str) -> str: + """URL for test db (will be created in db_engine): async driver.""" + return f"postgresql://{db_domain_with_template_db}" + + +@pytest.fixture(scope="session") def db_sync_engine(db_sync_url: str) -> "Generator[Engine, None, None]": """SQLAlchemy engine session-based fixture.""" - with suppress(SQLAlchemyError): - create_db(db_sync_url) + create_db(db_sync_url) engine = create_engine(db_sync_url, echo=False, pool_pre_ping=True) try: yield engine finally: engine.dispose() - with suppress(SQLAlchemyError): - destroy_db(db_sync_url) + destroy_db(db_sync_url) + + +@pytest_asyncio.fixture(scope="session") # type: ignore +async def db_async_engine( + db_async_url: str, + db_name: str, + db_sync_engine: "Engine", +) -> "AsyncGenerator[AsyncEngine, None]": # type: ignore + """SQLAlchemy engine session-based fixture.""" + engine = create_async_engine(db_async_url, echo=True, pool_pre_ping=True) + try: + yield engine + finally: + await engine.dispose() -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def db_sync_session_factory(db_sync_engine: "Engine") -> "scoped_session[Session]": """SQLAlchemy session factory session-based fixture.""" return scoped_session( @@ -116,6 +136,21 @@ def db_sync_session_factory(db_sync_engine: "Engine") -> "scoped_session[Session ) +@pytest.fixture(scope="session") # type: ignore +def db_async_session_factory( + db_async_engine: "AsyncEngine", +) -> "async_scoped_session[AsyncSession]": + """SQLAlchemy session factory session-based fixture.""" + return async_scoped_session( + async_sessionmaker( + bind=db_async_engine, + autoflush=False, + expire_on_commit=False, + ), + asyncio.current_task, + ) + + @pytest.fixture() def db_sync_session( db_sync_engine: "Engine", @@ -128,6 +163,19 @@ def db_sync_session( 
Base.metadata.drop_all(db_sync_engine) +@pytest_asyncio.fixture() # type: ignore +async def db_async_session( + db_async_engine: "AsyncEngine", + db_async_session_factory: "async_scoped_session[AsyncSession]", +) -> "AsyncGenerator[AsyncSession, None]": + """SQLAlchemy session fixture.""" + async with db_async_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + async with db_async_session_factory() as session: + yield session + + @pytest.fixture() def mymodel_sync_factory( dt_faker: Datetime, diff --git a/tests/test_async_queries.py b/tests/test_async_queries.py new file mode 100644 index 0000000..f9f5d32 --- /dev/null +++ b/tests/test_async_queries.py @@ -0,0 +1,448 @@ +from typing import TYPE_CHECKING, Any + +import pytest +from dev_utils.sqlalchemy.filters.converters import SimpleFilterConverter # type: ignore +from mimesis import Datetime, Locale, Text +from sqlalchemy import func, select + +from sqlrepo.queries import BaseAsyncQuery +from tests.utils import ( + MyModel, + OtherModel, + assert_compare_db_item_list, + assert_compare_db_item_list_with_dict, + assert_compare_db_item_none_fields, + assert_compare_db_item_with_dict, + assert_compare_db_items, + coin_flip, +) + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from tests.types import AsyncFactoryFunctionProtocol + + +text_faker = Text(locale=Locale.EN) +dt_faker = Datetime(locale=Locale.EN) + + +@pytest.mark.asyncio() +async def test_get_item( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session, commit=True) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_item = await query_obj.get_item(model=MyModel, filters=dict(id=item.id)) + assert db_item is not None, f"MyModel with id {item.id} not found in db." 
+ assert_compare_db_items(item, db_item) + + +@pytest.mark.asyncio() +async def test_get_item_not_found( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session, commit=True) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_item = await query_obj.get_item(model=MyModel, filters=dict(id=item.id + 1)) + assert db_item is None, f"MyModel with id {item.id + 1} was found in db (but it shouldn't)." + + +@pytest.mark.asyncio() +async def test_get_items_count( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + create_count = 3 + for _ in range(create_count): + await mymodel_async_factory(db_async_session, commit=False) + await db_async_session.commit() + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + count = await query_obj.get_items_count(model=MyModel) + assert count == create_count + + +@pytest.mark.asyncio() +async def test_get_items_count_with_filter( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session, commit=False) + for _ in range(2): + await mymodel_async_factory(db_async_session, commit=False) + await db_async_session.commit() + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + count = await query_obj.get_items_count(model=MyModel, filters=dict(id=item.id)) + assert count == 1 + + +@pytest.mark.asyncio() +async def test_get_items_list( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + items = [await mymodel_async_factory(db_async_session, commit=False) for _ in range(3)] + await db_async_session.commit() + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_items = 
list(await query_obj.get_item_list(model=MyModel)) + assert_compare_db_item_list(items, db_items) + + +# TODO: fix test. Now it just clone of test_get_items_list (previous test). Needs to check unique +@pytest.mark.asyncio() +async def test_get_items_list_with_unique( # noqa: D103 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + items = [await mymodel_async_factory(db_async_session, commit=False) for _ in range(3)] + await db_async_session.commit() + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_items = list(await query_obj.get_item_list(model=MyModel, unique_items=True)) + assert_compare_db_item_list(items, db_items) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("create_data", "use_flush"), + [ + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + True, + ), + ( + [ + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + ], + True, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + ), + ], +) +async def test_db_create( + db_async_session: "AsyncSession", + create_data: dict[str, Any], + use_flush: bool, # noqa: FBT001 +) -> None: + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_item = await query_obj.db_create(model=MyModel, data=create_data, use_flush=use_flush) + if not isinstance(db_item, MyModel): # type: ignore + pytest.skip("No compare functions") + assert_compare_db_item_with_dict(db_item, create_data, skip_keys_check=True) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("create_data", "use_flush"), + [ + ( + { + "name": 
text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + True, + ), + ( + [ + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + ], + True, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + "other_models": [ + OtherModel( + name=text_faker.sentence(), + other_name=text_faker.sentence(), + ), + OtherModel( + name=text_faker.sentence(), + other_name=text_faker.sentence(), + ), + ], + }, + False, + ), + ], +) +async def test_create_item( + db_async_session: "AsyncSession", + create_data: dict[str, Any], + use_flush: bool, # noqa: FBT001 +) -> None: + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_item = await query_obj.create_item(model=MyModel, data=create_data, use_flush=use_flush) + if not isinstance(db_item, MyModel): # type: ignore + pytest.skip("No compare functions") + assert_compare_db_item_with_dict(db_item, create_data, skip_keys_check=True) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("update_data", "use_flush", "items_count"), + [ + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + True, + 1, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + 1, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + 3, + ), + ], +) +async def test_db_update( + 
db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", + update_data: Any, # noqa: ANN401 + use_flush: bool, # noqa: FBT001 + items_count: int, +) -> None: + for _ in range(items_count): + await mymodel_async_factory(db_async_session, commit=True) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + db_item = await query_obj.db_update(model=MyModel, data=update_data, use_flush=use_flush) + assert len(db_item) == items_count + assert_compare_db_item_list_with_dict(db_item, update_data, skip_keys_check=True) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("update_data", "use_flush", "expected_updated_flag"), + [ + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + True, + True, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + True, + ), + ( + {}, + False, + False, + ), + ], +) +async def test_change_item( + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", + update_data: dict[str, Any], + use_flush: bool, # noqa: FBT001 + expected_updated_flag: bool, # noqa: FBT001 +) -> None: + item = await mymodel_async_factory(db_async_session) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + updated, db_item = await query_obj.change_item(data=update_data, item=item, use_flush=use_flush) + assert expected_updated_flag is updated + assert_compare_db_item_with_dict(db_item, update_data, skip_keys_check=True) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + ("update_data", "expected_updated_flag", "set_none", "allowed_none_fields", "none_set_fields"), + [ + ( + {}, + False, + False, + {}, + {}, + ), + ( + {"name": text_faker.sentence()}, + True, + True, + "*", + {}, + ), + ( + {"name": text_faker.sentence(), "other_name": None, "dt": None, "bl": 
None}, + True, + True, + "*", + {"other_name", "dt", "bl"}, + ), + ( + {"name": text_faker.sentence(), "other_name": None, "dt": None, "bl": None}, + True, + True, + {"other_name"}, + {"other_name"}, + ), + ], +) +async def test_change_item_none_check( + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", + update_data: dict[str, Any], + expected_updated_flag: bool, # noqa: FBT001 + set_none: bool, # noqa: FBT001 + allowed_none_fields: Any, # noqa: FBT001, ANN401 + none_set_fields: set[str], +) -> None: + item = await mymodel_async_factory(db_async_session) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + updated, db_item = await query_obj.change_item( + data=update_data, + item=item, + set_none=set_none, + allowed_none_fields=allowed_none_fields, + ) + if expected_updated_flag is not updated: + pytest.skip("update flag check failed. Test needs to be changed.") + assert_compare_db_item_none_fields(db_item, none_set_fields) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_flush", [True, False]) +async def test_db_delete_direct_value( + use_flush: bool, # noqa: FBT001 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + delete_count = await query_obj.db_delete( + model=MyModel, + filters={"id": item.id}, + use_flush=use_flush, + ) + assert delete_count == 1 + assert await db_async_session.scalar(select(func.count()).select_from(MyModel)) == 0 + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("use_flush", [True, False]) +async def test_db_delete_multiple_values( + use_flush: bool, # noqa: FBT001 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + to_delete_count = 3 + for _ in range(to_delete_count): + await 
mymodel_async_factory(db_async_session) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + delete_count = await query_obj.db_delete( + model=MyModel, + use_flush=use_flush, + ) + assert delete_count == to_delete_count + assert await db_async_session.scalar(select(func.count()).select_from(MyModel)) == 0 + + +@pytest.mark.asyncio() +@pytest.mark.parametrize( + "use_flush", + [True, False], +) +async def test_disable_items_direct_value( + use_flush: bool, # noqa: FBT001 + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session, bl=False) + query_obj = BaseAsyncQuery(db_async_session, SimpleFilterConverter) + disable_count = await query_obj.disable_items( + model=MyModel, + ids_to_disable={item.id}, + id_field=MyModel.id, + disable_field=MyModel.bl, + field_type=bool, + allow_filter_by_value=False, + extra_filters={"id": item.id}, + use_flush=use_flush, + ) + assert disable_count == 1 + assert ( + await db_async_session.scalar( + select(func.count()).select_from(MyModel).where(MyModel.bl.is_(False)), + ) + == 0 + ) diff --git a/tests/test_base_queries.py b/tests/test_base_queries.py index 48c69c5..111f5d7 100644 --- a/tests/test_base_queries.py +++ b/tests/test_base_queries.py @@ -6,7 +6,7 @@ from dev_utils.core.exc import NoModelRelationshipError # type: ignore from dev_utils.sqlalchemy.filters.converters import SimpleFilterConverter # type: ignore from freezegun import freeze_time -from sqlalchemy import ColumnElement, and_, delete, func, or_, select, text, update +from sqlalchemy import ColumnElement, and_, delete, func, insert, or_, select, text, update from sqlalchemy.orm import joinedload, selectinload from sqlrepo.queries import BaseQuery @@ -73,6 +73,11 @@ def test_resolve_specific_columns( # noqa [(OtherModel, MyModel.id == OtherModel.my_model_id, {"full": True})], select(MyModel).join(OtherModel, full=True), ), + ( + 
select(MyModel), + "other_models", + select(MyModel).join(OtherModel), + ), ], ) def test_resolve_and_apply_joins( # noqa @@ -100,6 +105,12 @@ def test_resolve_and_apply_loads( # noqa ["other_models"], select(MyModel).options(selectinload(MyModel.other_models)), ), + ( + select(MyModel), + selectinload, + "other_models", + select(MyModel).options(selectinload(MyModel.other_models)), + ), ], ) def test_resolve_and_apply_loads( # noqa @@ -434,6 +445,37 @@ def test_get_item_list_stmt( # noqa assert str(get_item_list_stmt) == str(expected_result) +@pytest.mark.parametrize( + ("data", "expected_result"), + [ + (None, insert(MyModel).values().returning(MyModel)), + ({"id": 1}, insert(MyModel).values({"id": 1}).returning(MyModel)), + ([{"id": 1}], insert(MyModel).values([{"id": 1}]).returning(MyModel)), + ], +) +def test_db_insert_stmt(data: Any, expected_result: Any) -> None: # noqa: ANN401 + query = BaseQuery(SimpleFilterConverter) + db_insert_stmt = query._db_insert_stmt(model=MyModel, data=data) # type: ignore + assert str(db_insert_stmt) == str(expected_result) + + +@pytest.mark.parametrize( + ("data", "expected_result"), + [ + (None, [MyModel()]), + ({"id": 1}, [MyModel(id=1)]), + ([{"id": 1}, None], [MyModel(id=1), MyModel()]), + ([{"id": 1}, {"id": 2}], [MyModel(id=1), MyModel(id=2)]), + ], +) +def test_prepare_create_items(data: Any, expected_result: Any) -> None: # noqa: ANN401 + query = BaseQuery(SimpleFilterConverter) + prepared_values = query._prepare_create_items(model=MyModel, data=data) # type: ignore + assert len(prepared_values) == len(expected_result) + for prepared, expected in zip(prepared_values, expected_result, strict=True): + assert prepared.__dict__ == expected.__dict__ + + @pytest.mark.parametrize( ( "data", diff --git a/tests/test_base_repositories.py b/tests/test_base_repositories.py new file mode 100644 index 0000000..9d0d745 --- /dev/null +++ b/tests/test_base_repositories.py @@ -0,0 +1,95 @@ +from typing import TYPE_CHECKING, Any + +import
pytest +from dev_utils.sqlalchemy.filters.converters import ( + AdvancedOperatorFilterConverter, + DjangoLikeFilterConverter, + SimpleFilterConverter, +) + +from sqlrepo.exc import RepositoryAttributeError +from sqlrepo.repositories import BaseRepository, RepositoryModelClassIncorrectUseWarning +from tests.utils import MyModel + +if TYPE_CHECKING: + from tests.utils import OtherModel # type: ignore # noqa: F401 + + +def test_inherit_skip() -> None: + assert BaseRepository.__inheritance_check_model_class__ is True + + class MyRepo(BaseRepository): # type: ignore + __inheritance_check_model_class__ = False + + assert MyRepo.__inheritance_check_model_class__ is True + + +def test_already_set_model_class_warn() -> None: + with pytest.warns(RepositoryModelClassIncorrectUseWarning): + + class MyRepo(BaseRepository[MyModel]): # type: ignore + model_class = MyModel + + +def test_cant_eval_forward_ref() -> None: + with pytest.warns(RepositoryModelClassIncorrectUseWarning): + + class MyRepo(BaseRepository["OtherModel"]): # type: ignore + ... + + +def test_generic_incorrect_type() -> None: + with pytest.warns( + RepositoryModelClassIncorrectUseWarning, + match="Passed GenericType is not SQLAlchemy model declarative class.", + ): + + class MyRepo(BaseRepository[int]): # type: ignore + ... + + +def test_no_generic() -> None: + with pytest.warns( + RepositoryModelClassIncorrectUseWarning, + match="GenericType was not passed for SQLAlchemy model declarative class.", + ): + + class MyRepo(BaseRepository): # type: ignore + ... + + +def test_correct_use() -> None: + class CorrectRepo(BaseRepository[MyModel]): ... 
+ + assert CorrectRepo.model_class == MyModel # type: ignore + + +def test_validate_disable_attributes() -> None: + class CorrectRepo(BaseRepository[MyModel]): + disable_id_field = MyModel.id + disable_field = MyModel.bl + disable_field_type = bool + + CorrectRepo._validate_disable_attributes() # type: ignore + + +def test_validate_disable_attributes_raise_error() -> None: + class CorrectRepo(BaseRepository[MyModel]): ... + + with pytest.raises(RepositoryAttributeError): + CorrectRepo._validate_disable_attributes() # type: ignore + + +@pytest.mark.parametrize( + ("strategy", "expected_class"), + [ + ("simple", SimpleFilterConverter), + ("advanced", AdvancedOperatorFilterConverter), + ("django", DjangoLikeFilterConverter), + ], +) +def test_get_filter_convert_class(strategy: str, expected_class: Any) -> None: # noqa: ANN401 + class CorrectRepo(BaseRepository[MyModel]): + filter_convert_strategy = strategy # type: ignore + + assert CorrectRepo.get_filter_convert_class() == expected_class diff --git a/tests/test_sync_queries.py b/tests/test_sync_queries.py index 92e1efc..d21a615 100644 --- a/tests/test_sync_queries.py +++ b/tests/test_sync_queries.py @@ -3,6 +3,7 @@ import pytest from dev_utils.sqlalchemy.filters.converters import SimpleFilterConverter # type: ignore from mimesis import Datetime, Locale, Text +from sqlalchemy import func, select from sqlrepo.queries import BaseSyncQuery from tests.utils import ( @@ -51,12 +52,13 @@ def test_get_items_count( # noqa: D103 db_sync_session: "Session", mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", ) -> None: - for _ in range(3): + create_count = 3 + for _ in range(create_count): mymodel_sync_factory(db_sync_session, commit=False) db_sync_session.commit() query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) count = query_obj.get_items_count(model=MyModel) - assert count == 3 + assert count == create_count def test_get_items_count_with_filter( # noqa: D103 @@ -107,6 +109,75 @@ def 
test_get_items_list_with_unique( # noqa: D103 }, True, ), + ( + [ + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + ], + True, + ), + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + False, + ), + ], +) +def test_db_create( + db_sync_session: "Session", + create_data: dict[str, Any], + use_flush: bool, # noqa: FBT001 +) -> None: + query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) + db_item = query_obj.db_create(model=MyModel, data=create_data, use_flush=use_flush) + if not isinstance(db_item, MyModel): # type: ignore + pytest.skip("No compare functions") + assert_compare_db_item_with_dict(db_item, create_data, skip_keys_check=True) + + +@pytest.mark.parametrize( + ("create_data", "use_flush"), + [ + ( + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + True, + ), + ( + [ + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + { + "name": text_faker.sentence(), + "other_name": text_faker.sentence(), + "dt": dt_faker.datetime(), + "bl": coin_flip(), + }, + ], + True, + ), ( { "name": text_faker.sentence(), @@ -144,6 +215,8 @@ def test_create_item( ) -> None: query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) db_item = query_obj.create_item(model=MyModel, data=create_data, use_flush=use_flush) + if not isinstance(db_item, MyModel): # type: ignore + pytest.skip("No compare functions") assert_compare_db_item_with_dict(db_item, create_data, skip_keys_check=True) @@ -291,5 +364,71 @@ def test_change_item_none_check( set_none=set_none, allowed_none_fields=allowed_none_fields, ) - assert 
expected_updated_flag is updated + if expected_updated_flag is not updated: + pytest.skip("update flag check failed. Test needs to be changed.") assert_compare_db_item_none_fields(db_item, none_set_fields) + + +@pytest.mark.parametrize("use_flush", [True, False]) +def test_db_delete_direct_value( + use_flush: bool, # noqa: FBT001 + db_sync_session: "Session", + mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = mymodel_sync_factory(db_sync_session) + query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) + delete_count = query_obj.db_delete( + model=MyModel, + filters={"id": item.id}, + use_flush=use_flush, + ) + assert delete_count == 1 + assert db_sync_session.scalar(select(func.count()).select_from(MyModel)) == 0 + + +@pytest.mark.parametrize("use_flush", [True, False]) +def test_db_delete_multiple_values( + use_flush: bool, # noqa: FBT001 + db_sync_session: "Session", + mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", +) -> None: + to_delete_count = 3 + for _ in range(to_delete_count): + mymodel_sync_factory(db_sync_session) + query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) + delete_count = query_obj.db_delete( + model=MyModel, + use_flush=use_flush, + ) + assert delete_count == to_delete_count + assert db_sync_session.scalar(select(func.count()).select_from(MyModel)) == 0 + + +@pytest.mark.parametrize( + "use_flush", + [True, False], +) +def test_disable_items_direct_value( + use_flush: bool, # noqa: FBT001 + db_sync_session: "Session", + mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = mymodel_sync_factory(db_sync_session, bl=False) + query_obj = BaseSyncQuery(db_sync_session, SimpleFilterConverter) + disable_count = query_obj.disable_items( + model=MyModel, + ids_to_disable={item.id}, + id_field=MyModel.id, + disable_field=MyModel.bl, + field_type=bool, + allow_filter_by_value=False, + extra_filters={"id": item.id}, + use_flush=use_flush, + ) + assert 
disable_count == 1 + assert ( + db_sync_session.scalar( + select(func.count()).select_from(MyModel).where(MyModel.bl.is_(False)), + ) + == 0 + )