Skip to content

Commit

Permalink
method transformed to class property
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Feb 13, 2025
1 parent fdb3db9 commit 2213fe8
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
14 changes: 8 additions & 6 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Union

from chatsky.core.ctx_utils import ContextMainInfo
from chatsky.utils.decorations import classproperty
from chatsky.utils.logging import collapse_num_list
from .protocol import PROTOCOLS

Expand Down Expand Up @@ -46,13 +47,14 @@ class NameConfig:
_requests_field: Literal["requests"] = "requests"
_responses_field: Literal["responses"] = "responses"

def get_context_main_fields(self) -> List[str]:
@classproperty
def get_context_main_fields(cls) -> List[str]:
return [
NameConfig._current_turn_id_column,
NameConfig._created_at_column,
NameConfig._updated_at_column,
NameConfig._misc_column,
NameConfig._framework_data_column,
cls._current_turn_id_column,
cls._created_at_column,
cls._updated_at_column,
cls._misc_column,
cls._framework_data_column,
]


Expand Down
6 changes: 3 additions & 3 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ async def _connect(self):
async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
result = await self.main_table.find_one(
{NameConfig._id_column: ctx_id},
NameConfig.get_context_main_fields(),
NameConfig.get_context_main_fields,
)
return (
ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields()})
ContextMainInfo.model_validate({f: result[f] for f in NameConfig.get_context_main_fields})
if result is not None
else None
)
Expand All @@ -101,7 +101,7 @@ async def _inner_update_context(
"$set": {
NameConfig._id_column: ctx_id,
} | {
f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields()
f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields
}
},
upsert=True,
Expand Down
6 changes: 3 additions & 3 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def _bytes_to_keys(keys: List[bytes]) -> List[int]:
async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
retrieved_fields = await gather(
*[self.database.hget(f"{self._main_key}:{ctx_id}", f) for f in NameConfig.get_context_main_fields()]
*[self.database.hget(f"{self._main_key}:{ctx_id}", f) for f in NameConfig.get_context_main_fields]
)
return ContextMainInfo.model_validate(
{f: v for f, v in zip(NameConfig.get_context_main_fields(), retrieved_fields)}
{f: v for f, v in zip(NameConfig.get_context_main_fields, retrieved_fields)}
)
else:
return None
Expand All @@ -96,7 +96,7 @@ async def _update_context(
update_main, update_values, delete_keys = list(), list(), list()
if ctx_info is not None:
ctx_info_dump = ctx_info.model_dump(mode="python")
update_main = [(f, ctx_info_dump[f] if isinstance(ctx_info_dump[f], bytes) else str(ctx_info_dump[f])) for f in NameConfig.get_context_main_fields()]
update_main = [(f, ctx_info_dump[f] if isinstance(ctx_info_dump[f], bytes) else str(ctx_info_dump[f])) for f in NameConfig.get_context_main_fields]
for field_name, items in field_info:
new_delete_keys = list()
for k, v in items:
Expand Down
4 changes: 2 additions & 2 deletions chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextMainInfo]:
None
if result is None
else ContextMainInfo.model_validate(
{f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields())}
{f: result[i + 1] for i, f in enumerate(NameConfig.get_context_main_fields)}
)
)

Expand All @@ -228,7 +228,7 @@ async def _update_context(
{
NameConfig._id_column: ctx_id,
} | {
f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields()
f: ctx_info_dump[f] for f in NameConfig.get_context_main_fields
}
)
main_update_stmt = _get_upsert_stmt(
Expand Down
4 changes: 2 additions & 2 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def callee(session: Session) -> Optional[ContextMainInfo]:
)
return (
ContextMainInfo.model_validate(
{f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields()}
{f: result_sets[0].rows[0][f] for f in NameConfig.get_context_main_fields}
)
if len(result_sets[0].rows) > 0
else None
Expand Down Expand Up @@ -186,7 +186,7 @@ async def callee(session: Session) -> None:
{
f"${NameConfig._id_column}": ctx_id,
} | {
f"${f}": ctx_info_dump[f] for f in NameConfig.get_context_main_fields()
f"${f}": ctx_info_dump[f] for f in NameConfig.get_context_main_fields
},
commit_tx=True,
) as _:
Expand Down
6 changes: 6 additions & 0 deletions chatsky/utils/decorations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class classproperty:
def __init__(self, f):
self.f = f

def __get__(self, _, owner):
return self.f(owner)

0 comments on commit 2213fe8

Please sign in to comment.