diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index fecd0f668..6a223a08f 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -18,6 +18,7 @@ from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union from chatsky.core.ctx_utils import ContextMainInfo +from chatsky.utils.decorations import classproperty from chatsky.utils.logging import collapse_num_list from .protocol import PROTOCOLS @@ -46,13 +47,14 @@ class NameConfig: _requests_field: Literal["requests"] = "requests" _responses_field: Literal["responses"] = "responses" - def get_context_main_fields(self) -> List[str]: + @classproperty + def get_context_main_fields(cls) -> List[str]: return [ - NameConfig._current_turn_id_column, - NameConfig._created_at_column, - NameConfig._updated_at_column, - NameConfig._misc_column, - NameConfig._framework_data_column, + cls._current_turn_id_column, + cls._created_at_column, + cls._updated_at_column, + cls._misc_column, + cls._framework_data_column, ] diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index 9ecd799c9..82824b0a6 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -79,10 +79,10 @@ async def _connect(self): async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]: result = await self.main_table.find_one( {NameConfig._id_column: ctx_id}, - NameConfig.get_context_main_fields(), + NameConfig.get_context_main_fields, ) return ( - ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields()}) + ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields}) if result is not None else None ) @@ -101,7 +101,7 @@ async def _inner_update_context( "$set": { NameConfig._id_column: ctx_id, } | { - f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields() + f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields } }, upsert=True, diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index f32b2e818..f49ccc529 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -82,10 +82,10 @@ def _bytes_to_keys(keys: List[bytes]) -> List[int]: async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): retrieved_fields = await gather( - *[self.database.hget(f"{self._main_key}:{ctx_id}", f) for f in NameConfig.get_context_main_fields()] + *[self.database.hget(f"{self._main_key}:{ctx_id}", f) for f in NameConfig.get_context_main_fields] ) return ContextMainInfo.model_validate( - {f: v for f, v in zip(NameConfig.get_context_main_fields(), retrieved_fields)} + {f: v for f, v in zip(NameConfig.get_context_main_fields, retrieved_fields)} ) else: return None @@ -96,7 +96,7 @@ async def _update_context( update_main, update_values, delete_keys = list(), list(), list() if ctx_info is not None: ctx_info_dump = ctx_info.model_dump(mode="python") - update_main = [(f, ctx_info_dump[f] if isinstance(ctx_info_dump[f], bytes) else str(ctx_info_dump[f])) for f in NameConfig.get_context_main_fields()] + update_main = [(f, ctx_info_dump[f] if isinstance(ctx_info_dump[f], bytes) else str(ctx_info_dump[f])) for f in NameConfig.get_context_main_fields] for field_name, items in field_info: new_delete_keys = list() for k, v in items: diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index b060d1b17..b5cb6f2f8 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -214,7 +214,7 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]: None if result is None else ContextMainInfo.model_validate( - {f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields())} + {f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields)} ) ) @@ -228,7 +228,7 @@ async def _update_context( { NameConfig._id_column: ctx_id, } | { - f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields() + f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields } ) main_update_stmt = _get_upsert_stmt( diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 30b1b5f49..18880721e 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -155,7 +155,7 @@ async def callee(session: Session) -> Optional[ContextMainInfo]: ) return ( ContextMainInfo.model_validate( - {f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields()} + {f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields} ) if len(result_sets[0].rows) > 0 else None @@ -186,7 +186,7 @@ async def callee(session: Session) -> None: { f"${NameConfig._id_column}": ctx_id, } | { - f"${f}": ctx_info_dump[f] for f in NameConfig.get_context_main_fields() + f"${f}": ctx_info_dump[f] for f in NameConfig.get_context_main_fields }, commit_tx=True, ) as _: diff --git a/chatsky/utils/decorations.py b/chatsky/utils/decorations.py new file mode 100644 index 000000000..caa7a025e --- /dev/null +++ b/chatsky/utils/decorations.py @@ -0,0 +1,6 @@ +class classproperty: + def __init__(self, f): + self.f = f + + def __get__(self, _, owner): + return self.f(owner)