From ccc53532615a6ba03c91e183f170fa5923fbe217 Mon Sep 17 00:00:00 2001 From: Dmitriy Lunev Date: Mon, 8 Jul 2024 14:46:36 +0300 Subject: [PATCH] v3.0.0: add RepositoryConfig - dataclass to separate configuration and main login of repository classes. --- pyproject.toml | 2 +- sqlrepo/__init__.py | 1 + sqlrepo/config.py | 153 ++++++++++++++++++++++ sqlrepo/queries.py | 43 +++++-- sqlrepo/repositories.py | 212 ++++++------------------------- sqlrepo/wrappers.py | 4 +- tests/test_async_repositories.py | 44 ++++++- tests/test_base_repositories.py | 13 +- tests/test_sync_repositories.py | 43 ++++++- 9 files changed, 306 insertions(+), 209 deletions(-) create mode 100644 sqlrepo/config.py diff --git a/pyproject.toml b/pyproject.toml index 558ecfa..3598fda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -159,7 +159,7 @@ dev = [ [project] name = "sqlrepo" -version = "2.0.0" +version = "3.0.0" description = "sqlalchemy repositories with crud operations and other utils for it." authors = [{ name = "Dmitriy Lunev", email = "dima.lunev14@gmail.com" }] requires-python = ">=3.11" diff --git a/sqlrepo/__init__.py b/sqlrepo/__init__.py index a1ccc1d..941c48a 100644 --- a/sqlrepo/__init__.py +++ b/sqlrepo/__init__.py @@ -1,3 +1,4 @@ +from .config import RepositoryConfig as RepositoryConfig from .queries import BaseAsyncQuery as BaseAsyncQuery from .queries import BaseQuery as BaseQuery from .queries import BaseSyncQuery as BaseSyncQuery diff --git a/sqlrepo/config.py b/sqlrepo/config.py new file mode 100644 index 0000000..03ae1f2 --- /dev/null +++ b/sqlrepo/config.py @@ -0,0 +1,153 @@ +import datetime +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Final, Literal, TypeAlias + +from dev_utils.sqlalchemy.filters.converters import ( + AdvancedOperatorFilterConverter, + BaseFilterConverter, + DjangoLikeFilterConverter, + SimpleFilterConverter, +) +from dev_utils.sqlalchemy.filters.types import FilterConverterStrategiesLiteral +from sqlalchemy.orm import selectinload + +StrField: TypeAlias = str + + +if TYPE_CHECKING: + from sqlalchemy.orm.attributes import InstrumentedAttribute + from sqlalchemy.orm.strategy_options import _AbstractLoad # type: ignore + + +filter_convert_classes: Final[dict[FilterConverterStrategiesLiteral, type[BaseFilterConverter]]] = { + "simple": SimpleFilterConverter, + "advanced": AdvancedOperatorFilterConverter, + "django": DjangoLikeFilterConverter, +} +"""Final convert class filters mapping.""" + + +@dataclass(slots=True) +class RepositoryConfig: + """Repository config as dataclass.""" + + # TODO: add specific_column_mapping to filters, joins and loads. + specific_column_mapping: "dict[str, InstrumentedAttribute[Any]]" = field(default_factory=dict) + """ + Warning! Current version of sqlrepo doesn't support this mapping for filters, joins and loads. + + Uses as mapping for some attributes, that you need to alias or need to specify column + from other models. + + Warning: if you specify column from other model, it may cause errors. For example, update + doesn't use it for filters, because joins are not presents in update. + """ + use_flush: bool = field(default=True) + """ + Uses as flag of flush method in SQLAlchemy session. + + By default, True, because repository has (mostly) multiple methods evaluate use. For example, + generally, you want to create some model instances, create some other (for example, log table) + and then receive other model instance in one use (for example, in Unit of work pattern). + + If you will work with repositories as single methods uses, switch to use_flush=False. It will + make queries commit any changes. + """ + update_set_none: bool = field(default=False) + """ + Uses as flag of set None option in ``update_instance`` method. + + If True, allow to force ``update_instance`` instance columns with None value. Works together + with ``update_allowed_none_fields``. + + By default False, because it's not safe to set column to None - current version if sqlrepo + not able to check optional type. Will be added in next versions, and ``then update_set_none`` + will be not necessary. + """ + update_allowed_none_fields: 'Literal["*"] | set[StrField]' = field(default="*") + """ + Set of strings, which represents columns of model. + + Uses as include or exclude for given data in ``update_instance`` method. + + By default allow any fields. Not dangerous, because ``update_set_none`` by default set to False, + and there will be no affect on ``update_instance`` method + """ + allow_disable_filter_by_value: bool = field(default=True) + """ + Uses as flag of filtering in disable method. + + If True, make additional filter, which will exclude items, which already disabled. + Logic of disable depends on type of disable column. See ``disable_field`` docstring for more + information. + + By default True, because it will make more efficient query to not override disable column. In + some cases (like datetime disable field) it may be better to turn off this flag to save disable + with new context (repeat disable, if your domain supports repeat disable and it make sense). + """ + disable_field_type: type[datetime.datetime] | type[bool] | None = field(default=None) + """ + Uses as choice of type of disable field. + + By default, None. Needs to be set manually, because this option depends on user custom + implementation of disable_field. If None and ``disable`` method was evaluated, there will be + RepositoryAttributeError exception raised by Repository class. + """ + disable_field: "InstrumentedAttribute[Any] | StrField | None" = field(default=None) + """ + Uses as choice of used defined disable field. + + By default, None. Needs to be set manually, because this option depends on user custom + implementation of disable_field. If None and ``disable`` method was evaluated, there will be + RepositoryAttributeError exception raised by Repository class. + """ + disable_id_field: "InstrumentedAttribute[Any] |StrField | None" = field(default=None) + """ + Uses as choice of used defined id field in model, which supports disable. + + By default, None. Needs to be set manually, because this option depends on user custom + implementation of disable_field. If None and ``disable`` method was evaluated, there will be + RepositoryAttributeError exception raised by Repository class. + """ + unique_list_items: bool = field(default=True) + """ + Warning! Ambiguous option! + ========================== + + Current version of ``sqlrepo`` works with load strategies with user configured option + ``load_strategy``. In order to make ``list`` method works stable, this option is used. + If you don't work with relationships in your model or you don't need unique (for example, + if you use selectinload), set this option to False. Otherwise keep it in True state. + """ + filter_convert_strategy: "FilterConverterStrategiesLiteral" = field(default="simple") + """ + Uses as choice of filter convert. + + By default "simple", so you able to pass filters with ``key-value`` structure. You still can + pass raw filters (just list of SQLAlchemy filters), but if you pass dict, it will be converted + to SQLAlchemy filters with passed strategy. + + Currently, supported converters: + + ``simple`` - ``key-value`` dict. + + ``advanced`` - dict with ``field``, ``value`` and ``operator`` keys. + List of operators: + + ``=, >, <, >=, <=, is, is_not, between, contains`` + + ``django-like`` - ``key-value`` dict with django-like lookups system. See django docs for + more info. + """ + # FIXME: remove it. Will cause many errors. Just pass _AbstractLoad instances itself. Not str + default_load_strategy: Callable[..., "_AbstractLoad"] = field(default=selectinload) + """ + Uses as choice of SQLAlchemy load strategies. + + By default selectinload, because it makes less errors. + """ + + def get_filter_convert_class(self) -> type[BaseFilterConverter]: + """Get filter convert class from passed strategy.""" + return filter_convert_classes[self.filter_convert_strategy] diff --git a/sqlrepo/queries.py b/sqlrepo/queries.py index 3c7781a..958e036 100644 --- a/sqlrepo/queries.py +++ b/sqlrepo/queries.py @@ -8,9 +8,8 @@ from dev_utils.core.utils import get_utc_now from dev_utils.sqlalchemy.filters.converters import BaseFilterConverter from dev_utils.sqlalchemy.utils import apply_joins, apply_loads, get_sqlalchemy_attribute -from sqlalchemy import CursorResult, and_, delete +from sqlalchemy import CursorResult, and_, delete, func, insert, or_, select, text, update from sqlalchemy import exc as sqlalchemy_exc -from sqlalchemy import func, insert, or_, select, text, update from sqlalchemy.orm import joinedload from sqlrepo.exc import QueryError @@ -604,10 +603,10 @@ def delete_item( # pragma: no coverage self.session.commit() except sqlalchemy_exc.SQLAlchemyError as exc: self.session.rollback() - msg = f"Delete from database error: {exc}" # noqa: S608 + msg = f"Error delete db_item: {exc}" # noqa: S608 self.logger.warning(msg) return False - msg = f"Delete from database success. Item: {item_repr}" # noqa: S608 + msg = f"Success delete db_item. Item: {item_repr}" # noqa: S608 self.logger.debug(msg) return True @@ -616,8 +615,8 @@ def disable_items( *, model: type["BaseSQLAlchemyModel"], ids_to_disable: set[Any], - id_field: "StrField", - disable_field: "StrField", + id_field: "InstrumentedAttribute[Any] | StrField", + disable_field: "InstrumentedAttribute[Any] | StrField", field_type: type[datetime.datetime] | type[bool] = datetime.datetime, allow_filter_by_value: bool = True, extra_filters: "Filter | None" = None, @@ -627,8 +626,16 @@ def disable_items( stmt = self._disable_items_stmt( model=model, ids_to_disable=ids_to_disable, - id_field=get_sqlalchemy_attribute(model, id_field, only_columns=True), - disable_field=get_sqlalchemy_attribute(model, disable_field, only_columns=True), + id_field=( + get_sqlalchemy_attribute(model, id_field, only_columns=True) + if isinstance(id_field, str) + else id_field + ), + disable_field=( + get_sqlalchemy_attribute(model, disable_field, only_columns=True) + if isinstance(disable_field, str) + else disable_field + ), field_type=field_type, allow_filter_by_value=allow_filter_by_value, extra_filters=extra_filters, @@ -905,10 +912,10 @@ async def delete_item( # pragma: no coverage await self.session.commit() except sqlalchemy_exc.SQLAlchemyError as exc: await self.session.rollback() - msg = f"Delete from database error: {exc}" # noqa: S608 + msg = f"Error delete db_item: {exc}" # noqa: S608 self.logger.warning(msg) return False - msg = f"Delete from database success. Item: {item_repr}" # noqa: S608 + msg = f"Success delete db_item. Item: {item_repr}" # noqa: S608 self.logger.debug(msg) return True @@ -917,8 +924,8 @@ async def disable_items( *, model: type["BaseSQLAlchemyModel"], ids_to_disable: set[Any], - id_field: "StrField", - disable_field: "StrField", + id_field: "InstrumentedAttribute[Any] | StrField", + disable_field: "InstrumentedAttribute[Any] | StrField", field_type: type[datetime.datetime] | type[bool] = datetime.datetime, allow_filter_by_value: bool = True, extra_filters: "Filter | None" = None, @@ -928,8 +935,16 @@ async def disable_items( stmt = self._disable_items_stmt( model=model, ids_to_disable=ids_to_disable, - id_field=get_sqlalchemy_attribute(model, id_field, only_columns=True), - disable_field=get_sqlalchemy_attribute(model, disable_field, only_columns=True), + id_field=( + get_sqlalchemy_attribute(model, id_field, only_columns=True) + if isinstance(id_field, str) + else id_field + ), + disable_field=( + get_sqlalchemy_attribute(model, disable_field, only_columns=True) + if isinstance(disable_field, str) + else disable_field + ), field_type=field_type, allow_filter_by_value=allow_filter_by_value, extra_filters=extra_filters, diff --git a/sqlrepo/repositories.py b/sqlrepo/repositories.py index f96bb51..a23a36f 100644 --- a/sqlrepo/repositories.py +++ b/sqlrepo/repositories.py @@ -1,17 +1,12 @@ """Main implementations for sqlrepo project.""" -import datetime import importlib import warnings -from collections.abc import Callable from typing import ( TYPE_CHECKING, Any, - ClassVar, - Final, ForwardRef, Generic, - Literal, NotRequired, TypeAlias, TypedDict, @@ -20,17 +15,10 @@ overload, ) -from dev_utils.sqlalchemy.filters.converters import ( - AdvancedOperatorFilterConverter, - BaseFilterConverter, - DjangoLikeFilterConverter, - SimpleFilterConverter, -) -from dev_utils.sqlalchemy.filters.types import FilterConverterStrategiesLiteral from sqlalchemy.orm import DeclarativeBase as Base -from sqlalchemy.orm import selectinload from sqlrepo import exc as sqlrepo_exc +from sqlrepo.config import RepositoryConfig from sqlrepo.logging import logger as default_logger from sqlrepo.queries import BaseAsyncQuery, BaseSyncQuery from sqlrepo.wrappers import wrap_any_exception_manager @@ -40,7 +28,7 @@ from logging import Logger from sqlalchemy.ext.asyncio import AsyncSession - from sqlalchemy.orm.attributes import InstrumentedAttribute, QueryableAttribute + from sqlalchemy.orm.attributes import QueryableAttribute from sqlalchemy.orm.session import Session from sqlalchemy.orm.strategy_options import _AbstractLoad # type: ignore from sqlalchemy.sql._typing import _ColumnExpressionOrStrLabelArgument # type: ignore @@ -109,140 +97,19 @@ class AdminRepository(BaseSyncRepository[Admin]): ``` """ - # TODO: add specific_column_mapping to filters, joins and loads. - specific_column_mapping: ClassVar["dict[str, InstrumentedAttribute[Any]]"] = {} - """ - Warning! Current version of sqlrepo doesn't support this mapping for filters, joins and loads. - - Uses as mapping for some attributes, that you need to alias or need to specify column - from other models. - - Warning: if you specify column from other model, it may cause errors. For example, update - doesn't use it for filters, because joins are not presents in update. - """ - use_flush: ClassVar[bool] = True - """ - Uses as flag of flush method in SQLAlchemy session. - - By default, True, because repository has (mostly) multiple methods evaluate use. For example, - generally, you want to create some model instances, create some other (for example, log table) - and then receive other model instance in one use (for example, in Unit of work pattern). - - If you will work with repositories as single methods uses, switch to use_flush=False. It will - make queries commit any changes. - """ - update_set_none: ClassVar[bool] = False - """ - Uses as flag of set None option in ``update_instance`` method. - - If True, allow to force ``update_instance`` instance columns with None value. Works together - with ``update_allowed_none_fields``. - - By default False, because it's not safe to set column to None - current version if sqlrepo - not able to check optional type. Will be added in next versions, and ``then update_set_none`` - will be not necessary. - """ - update_allowed_none_fields: ClassVar['Literal["*"] | set[StrField]'] = "*" - """ - Set of strings, which represents columns of model. - - Uses as include or exclude for given data in ``update_instance`` method. - - By default allow any fields. Not dangerous, because ``update_set_none`` by default set to False, - and there will be no affect on ``update_instance`` method - """ - allow_disable_filter_by_value: ClassVar[bool] = True - """ - Uses as flag of filtering in disable method. - - If True, make additional filter, which will exclude items, which already disabled. - Logic of disable depends on type of disable column. See ``disable_field`` docstring for more - information. - - By default True, because it will make more efficient query to not override disable column. In - some cases (like datetime disable field) it may be better to turn off this flag to save disable - with new context (repeat disable, if your domain supports repeat disable and it make sense). - """ - disable_field_type: ClassVar[type[datetime.datetime] | type[bool] | None] = None - """ - Uses as choice of type of disable field. - - By default, None. Needs to be set manually, because this option depends on user custom - implementation of disable_field. If None and ``disable`` method was evaluated, there will be - RepositoryAttributeError exception raised by Repository class. - """ - disable_field: ClassVar["StrField | None"] = None - """ - Uses as choice of used defined disable field. - - By default, None. Needs to be set manually, because this option depends on user custom - implementation of disable_field. If None and ``disable`` method was evaluated, there will be - RepositoryAttributeError exception raised by Repository class. - """ - disable_id_field: ClassVar["StrField | None"] = None - """ - Uses as choice of used defined id field in model, which supports disable. - - By default, None. Needs to be set manually, because this option depends on user custom - implementation of disable_field. If None and ``disable`` method was evaluated, there will be - RepositoryAttributeError exception raised by Repository class. - """ - unique_list_items: ClassVar[bool] = True - """ - Warning! Ambiguous option! - ========================== - - Current version of ``sqlrepo`` works with load strategies with user configured option - ``load_strategy``. In order to make ``list`` method works stable, this option is used. - If you don't work with relationships in your model or you don't need unique (for example, - if you use selectinload), set this option to False. Otherwise keep it in True state. - """ - filter_convert_strategy: ClassVar[FilterConverterStrategiesLiteral] = "simple" - """ - Uses as choice of filter convert. - - By default "simple", so you able to pass filters with ``key-value`` structure. You still can - pass raw filters (just list of SQLAlchemy filters), but if you pass dict, it will be converted - to SQLAlchemy filters with passed strategy. - - Currently, supported converters: - - ``simple`` - ``key-value`` dict. - - ``advanced`` - dict with ``field``, ``value`` and ``operator`` keys. - List of operators: - - ``=, >, <, >=, <=, is, is_not, between, contains`` - - ``django-like`` - ``key-value`` dict with django-like lookups system. See django docs for - more info. - """ - load_strategy: ClassVar[Callable[..., "_AbstractLoad"]] = selectinload - """ - Uses as choice of SQLAlchemy load strategies. - - By default selectinload, because it makes less errors. - """ - - _filter_convert_classes: Final[ - dict[FilterConverterStrategiesLiteral, type[BaseFilterConverter]] - ] = { - "simple": SimpleFilterConverter, - "advanced": AdvancedOperatorFilterConverter, - "django": DjangoLikeFilterConverter, - } - """ - Final convert class filters mapping. + config = RepositoryConfig() + """Repository config, that contains all settings. - Don't override it, because it can makes unexpected errors. + To add your own settings, inherit RepositoryConfig and add your own fields, then init it in + your Repository class as class variable. """ @classmethod def _validate_disable_attributes(cls) -> None: if ( - cls.disable_id_field is None - or cls.disable_field is None - or cls.disable_field_type is None + cls.config.disable_id_field is None + or cls.config.disable_field is None + or cls.config.disable_field_type is None ): msg = ( 'Attribute "disable_id_field" or "disable_field" or "disable_field_type" not ' @@ -304,11 +171,6 @@ def __init_subclass__(cls) -> None: # noqa: D105 return cls.model_class = model # type: ignore - @classmethod - def get_filter_convert_class(cls) -> type[BaseFilterConverter]: - """Get filter convert class from passed strategy.""" - return cls._filter_convert_classes[cls.filter_convert_strategy] - class BaseAsyncRepository(BaseRepository[BaseSQLAlchemyModel]): """Base repository class with async interface. @@ -329,9 +191,9 @@ def __init__( self.logger = logger self.queries = self.query_class( session, - self.get_filter_convert_class(), - self.specific_column_mapping, - self.load_strategy, + self.config.get_filter_convert_class(), + self.config.specific_column_mapping, + self.config.default_load_strategy, logger, ) @@ -391,7 +253,7 @@ async def list( # noqa: A003 order_by=order_by, limit=limit, offset=offset, - unique_items=self.unique_list_items, + unique_items=self.config.unique_list_items, ) return result @@ -434,7 +296,7 @@ async def update( model=self.model_class, data=data, filters=filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result @@ -452,9 +314,9 @@ async def update_instance( result = await self.queries.change_item( data=data, item=instance, - set_none=self.update_set_none, - allowed_none_fields=self.update_allowed_none_fields, - use_flush=self.use_flush, + set_none=self.config.update_set_none, + allowed_none_fields=self.config.update_allowed_none_fields, + use_flush=self.config.use_flush, ) return result @@ -468,7 +330,7 @@ async def delete( result = await self.queries.db_delete( model=self.model_class, filters=filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result @@ -484,12 +346,12 @@ async def disable( result = await self.queries.disable_items( model=self.model_class, ids_to_disable=ids_to_disable, - id_field=self.disable_id_field, # type: ignore - disable_field=self.disable_field, # type: ignore - field_type=self.disable_field_type, # type: ignore - allow_filter_by_value=self.allow_disable_filter_by_value, + id_field=self.config.disable_id_field, # type: ignore + disable_field=self.config.disable_field, # type: ignore + field_type=self.config.disable_field_type, # type: ignore + allow_filter_by_value=self.config.allow_disable_filter_by_value, extra_filters=extra_filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result @@ -512,9 +374,9 @@ def __init__( self.session = session self.queries = self.query_class( session, - self.get_filter_convert_class(), - self.specific_column_mapping, - self.load_strategy, + self.config.get_filter_convert_class(), + self.config.specific_column_mapping, + self.config.default_load_strategy, logger, ) @@ -574,7 +436,7 @@ def list( # noqa: A003 order_by=order_by, limit=limit, offset=offset, - unique_items=self.unique_list_items, + unique_items=self.config.unique_list_items, ) return result @@ -617,7 +479,7 @@ def update( model=self.model_class, data=data, filters=filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result @@ -635,9 +497,9 @@ def update_instance( result = self.queries.change_item( data=data, item=instance, - set_none=self.update_set_none, - allowed_none_fields=self.update_allowed_none_fields, - use_flush=self.use_flush, + set_none=self.config.update_set_none, + allowed_none_fields=self.config.update_allowed_none_fields, + use_flush=self.config.use_flush, ) return result @@ -651,7 +513,7 @@ def delete( result = self.queries.db_delete( model=self.model_class, filters=filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result @@ -667,11 +529,11 @@ def disable( result = self.queries.disable_items( model=self.model_class, ids_to_disable=ids_to_disable, - id_field=self.disable_id_field, # type: ignore - disable_field=self.disable_field, # type: ignore - field_type=self.disable_field_type, # type: ignore - allow_filter_by_value=self.allow_disable_filter_by_value, + id_field=self.config.disable_id_field, # type: ignore + disable_field=self.config.disable_field, # type: ignore + field_type=self.config.disable_field_type, # type: ignore + allow_filter_by_value=self.config.allow_disable_filter_by_value, extra_filters=extra_filters, - use_flush=self.use_flush, + use_flush=self.config.use_flush, ) return result diff --git a/sqlrepo/wrappers.py b/sqlrepo/wrappers.py index 9e2f6a8..e8ab111 100644 --- a/sqlrepo/wrappers.py +++ b/sqlrepo/wrappers.py @@ -29,7 +29,7 @@ def wrap_any_exception_manager() -> "Generator[None, None, Any]": raise QueryError(msg) from exc except BaseDevError as exc: msg = "error on python-dev-utils package level." - raise RepositoryError from exc + raise RepositoryError(msg) from exc except (AttributeError, TypeError, ValueError) as exc: msg = "error on python level." - raise BaseSQLRepoError from exc + raise BaseSQLRepoError(msg) from exc diff --git a/tests/test_async_repositories.py b/tests/test_async_repositories.py index b133c67..740a68c 100644 --- a/tests/test_async_repositories.py +++ b/tests/test_async_repositories.py @@ -4,6 +4,7 @@ from mimesis import Datetime, Locale, Text from sqlalchemy import func, select +from sqlrepo.config import RepositoryConfig from sqlrepo.exc import RepositoryAttributeError from sqlrepo.repositories import BaseAsyncRepository from tests.utils import ( @@ -31,10 +32,21 @@ class EmptyMyModelRepo(BaseAsyncRepository[MyModel]): # noqa: D101 class MyModelRepo(BaseAsyncRepository[MyModel]): # noqa: D101 - specific_column_mapping = {"some_specific_column": MyModel.name} - disable_id_field = "id" - disable_field = "bl" - disable_field_type = bool + config = RepositoryConfig( + specific_column_mapping={"some_specific_column": MyModel.name}, + disable_id_field="id", + disable_field="bl", + disable_field_type=bool, + ) + + +class MyModelRepoWithInstrumentedAttributes(BaseAsyncRepository[MyModel]): # noqa: D101 + config = RepositoryConfig( + specific_column_mapping={"some_specific_column": MyModel.name}, + disable_id_field=MyModel.id, + disable_field=MyModel.bl, + disable_field_type=bool, + ) @pytest.mark.asyncio() @@ -288,8 +300,8 @@ async def test_change_item_none_check( ) -> None: item = await mymodel_async_factory(db_async_session) repo = MyModelRepo(db_async_session) - repo.update_set_none = set_none # type: ignore - repo.update_allowed_none_fields = allowed_none_fields # type: ignore + repo.config.update_set_none = set_none # type: ignore + repo.config.update_allowed_none_fields = allowed_none_fields # type: ignore updated, db_item = await repo.update_instance(instance=item, data=update_data) if expected_updated_flag is not updated: pytest.skip("update flag check failed. Test needs to be changed.") @@ -354,3 +366,23 @@ async def test_disable_items_direct_value( ) == 0 ) + + +@pytest.mark.asyncio() +async def test_disable_items_direct_value_with_instrumented_attributes( + db_async_session: "AsyncSession", + mymodel_async_factory: "AsyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = await mymodel_async_factory(db_async_session, bl=False) + repo = MyModelRepoWithInstrumentedAttributes(db_async_session) + disable_count = await repo.disable( + ids_to_disable={item.id}, + extra_filters={"id": item.id}, + ) + assert disable_count == 1 + assert ( + await db_async_session.scalar( + select(func.count()).select_from(MyModel).where(MyModel.bl.is_(False)), + ) + == 0 + ) diff --git a/tests/test_base_repositories.py b/tests/test_base_repositories.py index 8cbe851..e378c44 100644 --- a/tests/test_base_repositories.py +++ b/tests/test_base_repositories.py @@ -7,6 +7,7 @@ SimpleFilterConverter, ) +from sqlrepo.config import RepositoryConfig from sqlrepo.exc import RepositoryAttributeError from sqlrepo.repositories import BaseRepository, RepositoryModelClassIncorrectUseWarning from tests.utils import MyModel @@ -73,9 +74,11 @@ class CorrectRepo(BaseRepository[MyModel]): ... def test_validate_disable_attributes() -> None: class CorrectRepo(BaseRepository[MyModel]): - disable_id_field = "id" - disable_field = "bl" - disable_field_type = bool + config = RepositoryConfig( + disable_id_field="id", + disable_field="bl", + disable_field_type=bool, + ) CorrectRepo._validate_disable_attributes() # type: ignore @@ -97,6 +100,6 @@ class CorrectRepo(BaseRepository[MyModel]): ... ) def test_get_filter_convert_class(strategy: str, expected_class: Any) -> None: # noqa: ANN401 class CorrectRepo(BaseRepository[MyModel]): - filter_convert_strategy = strategy # type: ignore + config = RepositoryConfig(filter_convert_strategy=strategy) # type: ignore - assert CorrectRepo.get_filter_convert_class() == expected_class + assert CorrectRepo.config.get_filter_convert_class() == expected_class diff --git a/tests/test_sync_repositories.py b/tests/test_sync_repositories.py index 6cc6cc1..12ed06a 100644 --- a/tests/test_sync_repositories.py +++ b/tests/test_sync_repositories.py @@ -4,6 +4,7 @@ from mimesis import Datetime, Locale, Text from sqlalchemy import func, select +from sqlrepo.config import RepositoryConfig from sqlrepo.exc import RepositoryAttributeError from sqlrepo.repositories import BaseSyncRepository from tests.utils import ( @@ -31,10 +32,21 @@ class EmptyMyModelRepo(BaseSyncRepository[MyModel]): # noqa: D101 class MyModelRepo(BaseSyncRepository[MyModel]): # noqa: D101 - specific_column_mapping = {"some_specific_column": MyModel.name} - disable_id_field = "id" - disable_field = "bl" - disable_field_type = bool + config = RepositoryConfig( + specific_column_mapping={"some_specific_column": MyModel.name}, + disable_id_field="id", + disable_field="bl", + disable_field_type=bool, + ) + + +class MyModelRepoWithInstrumentedAttributes(BaseSyncRepository[MyModel]): # noqa: D101 + config = RepositoryConfig( + specific_column_mapping={"some_specific_column": MyModel.name}, + disable_id_field=MyModel.id, + disable_field=MyModel.bl, + disable_field_type=bool, + ) def test_get_item( # noqa: D103 @@ -278,8 +290,8 @@ def test_change_item_none_check( ) -> None: item = mymodel_sync_factory(db_sync_session) repo = MyModelRepo(db_sync_session) - repo.update_set_none = set_none # type: ignore - repo.update_allowed_none_fields = allowed_none_fields # type: ignore + repo.config.update_set_none = set_none # type: ignore + repo.config.update_allowed_none_fields = allowed_none_fields # type: ignore updated, db_item = repo.update_instance(instance=item, data=update_data) if expected_updated_flag is not updated: pytest.skip("update flag check failed. Test needs to be changed.") @@ -340,3 +352,22 @@ def test_disable_items_direct_value( ) == 0 ) + + +def test_disable_items_direct_value_with_instrumented_attributes( + db_sync_session: "Session", + mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]", +) -> None: + item = mymodel_sync_factory(db_sync_session, bl=False) + repo = MyModelRepoWithInstrumentedAttributes(db_sync_session) + disable_count = repo.disable( + ids_to_disable={item.id}, + extra_filters={"id": item.id}, + ) + assert disable_count == 1 + assert ( + db_sync_session.scalar( + select(func.count()).select_from(MyModel).where(MyModel.bl.is_(False)), + ) + == 0 + )