From 868f5eff457c23865e229f9698bac22817a53e53 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 22 Nov 2024 20:34:46 +0300 Subject: [PATCH 01/38] added SlotValueEquals, drafted RegexpGroupSlot, formatting a GroupSlot with an f-string, moved derivative slots into a different file --- chatsky/__rebuild_pydantic_models__.py | 5 +- chatsky/conditions/slots.py | 28 +- chatsky/core/context.py | 2 +- chatsky/core/pipeline.py | 2 +- chatsky/processing/slots.py | 2 +- chatsky/slots/__init__.py | 2 +- chatsky/slots/base_slots.py | 254 +++++++++++++++++ chatsky/slots/slots.py | 299 ++++---------------- tests/slots/conftest.py | 2 +- tests/slots/test_slot_functions.py | 4 +- tests/slots/test_slot_manager.py | 2 +- tests/slots/test_slot_partial_extraction.py | 2 +- tests/slots/test_slot_types.py | 2 +- 13 files changed, 348 insertions(+), 258 deletions(-) create mode 100644 chatsky/slots/base_slots.py diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 1da4126a9..2bed62871 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -1,10 +1,10 @@ # flake8: noqa: F401 from chatsky.core.service.types import ExtraHandlerRuntimeInfo, ComponentExecutionState -from chatsky.core import Context, Script +from chatsky.core import Context, Message, Script from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline -from chatsky.slots.slots import SlotManager +from chatsky.slots.base_slots import SlotManager, FunctionSlot from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent @@ -15,3 +15,4 @@ ExtraHandlerRuntimeInfo.model_rebuild() FrameworkData.model_rebuild() ServiceState.model_rebuild() +FunctionSlot.model_rebuild() diff --git a/chatsky/conditions/slots.py b/chatsky/conditions/slots.py index eaddd3140..ee21111a5 100644 --- a/chatsky/conditions/slots.py +++ b/chatsky/conditions/slots.py @@ -5,10 +5,10 @@ """ from __future__ 
import annotations -from typing import Literal, List +from typing import Literal, List, Any from chatsky.core import Context, BaseCondition -from chatsky.slots.slots import SlotName +from chatsky.slots.base_slots import SlotName class SlotsExtracted(BaseCondition): @@ -36,3 +36,27 @@ async def call(self, ctx: Context) -> bool: return all(manager.is_slot_extracted(slot) for slot in self.slots) elif self.mode == "any": return any(manager.is_slot_extracted(slot) for slot in self.slots) + + +class SlotValueEquals(BaseCondition): + """ + Check if :py:attr:`.slot_name`'s extracted value is equal to a given value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + + slot_name: SlotName + """ + Name of the slot that needs to be checked. + """ + value: Any + """ + The value which the slot's extracted value is supposed to be checked against. + """ + + def __init__(self, slot_name: SlotName, value: Any): + super().__init__(slot_name=slot_name, value=value) + + async def call(self, ctx: Context) -> bool: + manager = ctx.framework_data.slot_manager + return manager.get_extracted_slot(self.slot_name).self.value == value diff --git a/chatsky/core/context.py b/chatsky/core/context.py index 59a9899a2..52c70ff5a 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -25,7 +25,7 @@ from pydantic import BaseModel, Field from chatsky.core.message import Message, MessageInitTypes -from chatsky.slots.slots import SlotManager +from chatsky.slots.base_slots import SlotManager from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes if TYPE_CHECKING: diff --git a/chatsky/core/pipeline.py b/chatsky/core/pipeline.py index 20caa74ea..90480d2aa 100644 --- a/chatsky/core/pipeline.py +++ b/chatsky/core/pipeline.py @@ -21,7 +21,7 @@ from chatsky.messengers.console import CLIMessengerInterface from chatsky.messengers.common import MessengerInterface -from chatsky.slots.slots import GroupSlot +from chatsky.slots.base_slots import 
GroupSlot from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler from .service import Service diff --git a/chatsky/processing/slots.py b/chatsky/processing/slots.py index 898195e94..655f3d609 100644 --- a/chatsky/processing/slots.py +++ b/chatsky/processing/slots.py @@ -9,7 +9,7 @@ import logging from typing import List -from chatsky.slots.slots import SlotName +from chatsky.slots.base_slots import SlotName from chatsky.core import Context, BaseProcessing from chatsky.responses.slots import FilledTemplate diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index 6c929b9af..cee35a178 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1 +1 @@ -from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot +from chatsky.slots.base_slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py new file mode 100644 index 000000000..2224f3c62 --- /dev/null +++ b/chatsky/slots/base_slots.py @@ -0,0 +1,254 @@ +""" +Base Slots +----- +This module defines base classes for slots. +""" + +from __future__ import annotations + +import asyncio +import re +from abc import ABC, abstractmethod +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict +from typing_extensions import TypeAlias, Annotated +import logging +from functools import reduce +from string import Formatter + +from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator + +from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async +from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator + +if TYPE_CHECKING: + from chatsky.core import Context, Message + + +logger = logging.getLogger(__name__) + + +SlotName: TypeAlias = str +""" +A string to identify slots. 
+ +Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. + +E.g. + +.. code:: python + + GroupSlot( + user=RegexpSlot(), + password=FunctionSlot, + ) + +Has two slots with names "user" and "password". + +For nested group slots use dots to separate names: + +.. code:: python + + GroupSlot( + user=GroupSlot( + name=FunctionSlot, + password=FunctionSlot, + ) + ) + +Has two slots with names "user.name" and "user.password". +""" + + +def recursive_getattr(obj, slot_name: SlotName): + def two_arg_getattr(__o, name): + # pydantic handles exception when accessing a non-existing extra-field on its own + # return None by default to avoid that + return getattr(__o, name, None) + + return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) + + +def recursive_setattr(obj, slot_name: SlotName, value): + parent_slot, _, slot = slot_name.rpartition(".") + + if parent_slot: + setattr(recursive_getattr(obj, parent_slot), slot, value) + else: + setattr(obj, slot, value) + + +class KwargOnlyFormatter(Formatter): + def get_value(self, key, args, kwargs): + return super().get_value(str(key), args, kwargs) + + +class SlotNotExtracted(Exception): + """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" + + pass + + +class ExtractedSlot(BaseModel, ABC): + """ + Represents value of an extracted slot. + + Instances of this class are managed by framework and + are stored in :py:attr:`~chatsky.core.context.FrameworkData.slot_manager`. + They can be accessed via the ``ctx.framework_data.slot_manager.get_extracted_slot`` method. 
+ """ + + @property + @abstractmethod + def __slot_extracted__(self) -> bool: + """Whether the slot is extracted.""" + raise NotImplementedError + + def __unset__(self): + """Mark slot as not extracted and clear extracted data (except for default value).""" + raise NotImplementedError + + @abstractmethod + def __str__(self): + """String representation is used to fill templates.""" + raise NotImplementedError + + +class BaseSlot(BaseModel, frozen=True): + """ + BaseSlot is a base class for all slots. + """ + + @abstractmethod + async def get_value(self, ctx: Context) -> ExtractedSlot: + """ + Extract slot value from :py:class:`~.Context` and return an instance of :py:class:`~.ExtractedSlot`. + """ + raise NotImplementedError + + @abstractmethod + def init_value(self) -> ExtractedSlot: + """ + Provide an initial value to fill slot storage with. + """ + raise NotImplementedError + + +class SlotManager(BaseModel): + """ + Provides API for managing slots. + + An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. + """ + + slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) + """Slot storage. Stored inside ctx.framework_data.""" + root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) + """Slot configuration passed during pipeline initialization.""" + + def set_root_slot(self, root_slot: GroupSlot): + """ + Set root_slot configuration from pipeline. + Update extracted slots with the new configuration: + + New slots are added with their :py:meth:`~.BaseSlot.init_value`. + Old extracted slot values are preserved only if their configuration did not change. + That is if they are still present in the config and if their fundamental type did not change + (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). + + This method is called by pipeline and is not supposed to be used otherwise. 
+ """ + self.root_slot = root_slot + new_slot_storage = root_slot.init_value() + new_slot_storage.update(self.slot_storage) + self.slot_storage = new_slot_storage + + def get_slot(self, slot_name: SlotName) -> BaseSlot: + """ + Get slot configuration from the slot name. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.root_slot, slot_name) + if isinstance(slot, BaseSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: + """ + Extract slot `slot_name` and store extracted value in `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + + :param slot_name: Name of the slot to extract. + :param ctx: Context. + :param success_only: Whether to store the value only if it is successfully extracted. + """ + slot = self.get_slot(slot_name) + value = await slot.get_value(ctx) + + if value.__slot_extracted__ or success_only is False: + recursive_setattr(self.slot_storage, slot_name, value) + + async def extract_all(self, ctx: Context): + """ + Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. + """ + self.slot_storage = await self.root_slot.get_value(ctx) + + def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: + """ + Retrieve extracted value from `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.slot_storage, slot_name) + if isinstance(slot, ExtractedSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + def is_slot_extracted(self, slot_name: str) -> bool: + """ + Return if the specified slot is extracted. + + :raises KeyError: If the slot with the specified name does not exist. 
+ """ + return self.get_extracted_slot(slot_name).__slot_extracted__ + + def all_slots_extracted(self) -> bool: + """ + Return if all slots are extracted. + """ + return self.slot_storage.__slot_extracted__ + + def unset_slot(self, slot_name: SlotName) -> None: + """ + Mark specified slot as not extracted and clear extracted value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + self.get_extracted_slot(slot_name).__unset__() + + def unset_all_slots(self) -> None: + """ + Mark all slots as not extracted and clear all extracted values. + """ + self.slot_storage.__unset__() + + def fill_template(self, template: str) -> Optional[str]: + """ + Fill `template` string with extracted slot values and return a formatted string + or None if an exception has occurred while trying to fill template. + + `template` should be a format-string: + + E.g. "Your username is {profile.username}". + + For the example above, if ``profile.username`` slot has value "admin", + it would return the following text: + "Your username is admin". + """ + try: + return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) + except Exception as exc: + logger.exception("An exception occurred during template filling.", exc_info=exc) + return None diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 3cadd9205..73f9f7737 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -1,7 +1,7 @@ """ Slots ----- -This module defines base classes for slots and some concrete implementations of them. +This module defines some concrete implementations of slots. 
""" from __future__ import annotations @@ -19,6 +19,7 @@ from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator +from chatsky.slots.base_slots import ExtractedSlot, BaseSlot, SlotNotExtracted, KwargOnlyFormatter if TYPE_CHECKING: from chatsky.core import Context, Message @@ -27,92 +28,6 @@ logger = logging.getLogger(__name__) -SlotName: TypeAlias = str -""" -A string to identify slots. - -Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. - -E.g. - -.. code:: python - - GroupSlot( - user=RegexpSlot(), - password=FunctionSlot, - ) - -Has two slots with names "user" and "password". - -For nested group slots use dots to separate names: - -.. code:: python - - GroupSlot( - user=GroupSlot( - name=FunctionSlot, - password=FunctionSlot, - ) - ) - -Has two slots with names "user.name" and "user.password". -""" - - -def recursive_getattr(obj, slot_name: SlotName): - def two_arg_getattr(__o, name): - # pydantic handles exception when accessing a non-existing extra-field on its own - # return None by default to avoid that - return getattr(__o, name, None) - - return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) - - -def recursive_setattr(obj, slot_name: SlotName, value): - parent_slot, sep, slot = slot_name.rpartition(".") - - if sep == ".": - parent_obj = recursive_getattr(obj, parent_slot) - else: - parent_obj = obj - - if isinstance(value, ExtractedGroupSlot): - getattr(parent_obj, slot).update(value) - else: - setattr(parent_obj, slot, value) - - -class SlotNotExtracted(Exception): - """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" - - pass - - -class ExtractedSlot(BaseModel, ABC): - """ - Represents value of an extracted slot. - - Instances of this class are managed by framework and - are stored in :py:attr:`~chatsky.core.context.FrameworkData.slot_manager`. 
- They can be accessed via the ``ctx.framework_data.slot_manager.get_extracted_slot`` method. - """ - - @property - @abstractmethod - def __slot_extracted__(self) -> bool: - """Whether the slot is extracted.""" - raise NotImplementedError - - def __unset__(self): - """Mark slot as not extracted and clear extracted data (except for default value).""" - raise NotImplementedError - - @abstractmethod - def __str__(self): - """String representation is used to fill templates.""" - raise NotImplementedError - - class ExtractedValueSlot(ExtractedSlot): """Value extracted from :py:class:`~.ValueSlot`.""" @@ -159,6 +74,7 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): + value_format: str = None __pydantic_extra__: Dict[ str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] ] @@ -171,8 +87,13 @@ def __unset__(self): for child in self.__pydantic_extra__.values(): child.__unset__() + # fill template here def __str__(self): - return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) + if self.value_format is not None: + # return str({key: Kwargs.format() for key, value in self.__pydantic_extra__.items()}) + return KwargOnlyFormatter().format(self.value_format, **self.__pydantic_extra__) + else: + return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) def update(self, old: "ExtractedGroupSlot"): """ @@ -193,26 +114,6 @@ def update(self, old: "ExtractedGroupSlot"): self.__pydantic_extra__[slot] = old_slot -class BaseSlot(BaseModel, frozen=True): - """ - BaseSlot is a base class for all slots. - """ - - @abstractmethod - async def get_value(self, ctx: Context) -> ExtractedSlot: - """ - Extract slot value from :py:class:`~.Context` and return an instance of :py:class:`~.ExtractedSlot`. - """ - raise NotImplementedError - - @abstractmethod - def init_value(self) -> ExtractedSlot: - """ - Provide an initial value to fill slot storage with. 
- """ - raise NotImplementedError - - class ValueSlot(BaseSlot, frozen=True): """ Value slot is a base class for all slots that are designed to extract concrete values. @@ -265,12 +166,11 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True): Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. """ + value_format: str = None __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] - allow_partial_extraction: bool = False - """If True, extraction returns only successfully extracted child slots.""" - def __init__(self, allow_partial_extraction=False, **kwargs): - super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs) + def __init__(self, **kwargs): # supress unexpected argument warnings + super().__init__(**kwargs) @model_validator(mode="after") def __check_extra_field_names__(self): @@ -286,12 +186,13 @@ def __check_extra_field_names__(self): async def get_value(self, ctx: Context) -> ExtractedGroupSlot: child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) - extracted_values = {} - for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()): - if child_value.__slot_extracted__ or not self.allow_partial_extraction: - extracted_values[child_name] = child_value - - return ExtractedGroupSlot(**extracted_values) + return ExtractedGroupSlot( + value_format=self.value_format, + **{ + child_name: child_value + for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()) + } + ) def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( @@ -321,141 +222,51 @@ async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: ) -class FunctionSlot(ValueSlot, frozen=True): - """ - A simpler version of :py:class:`~.ValueSlot`. - - Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. 
+# TODO: Change class and method descriptions. +class RegexpGroupSlot(GroupSlot, frozen=True): """ - - func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] - - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: - return await wrap_sync_function_in_async(self.func, ctx.last_request) - - -class SlotManager(BaseModel): - """ - Provides API for managing slots. - - An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. + RegexpGroupSlot is semantically equal to a GroupSlot of RegexpSlots. + You can pass a compiled or a non-compiled pattern to the `regexp` argument. + If you want to extract a particular group, but not the full match, + change the `match_group_idx` parameter. """ - slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) - """Slot storage. Stored inside ctx.framework_data.""" - root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) - """Slot configuration passed during pipeline initialization.""" - - def set_root_slot(self, root_slot: GroupSlot): - """ - Set root_slot configuration from pipeline. - Update extracted slots with the new configuration: - - New slots are added with their :py:meth:`~.BaseSlot.init_value`. - Old extracted slot values are preserved only if their configuration did not change. - That is if they are still present in the config and if their fundamental type did not change - (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). - - This method is called by pipeline and is not supposed to be used otherwise. - """ - self.root_slot = root_slot - new_slot_storage = root_slot.init_value() - new_slot_storage.update(self.slot_storage) - self.slot_storage = new_slot_storage - - def get_slot(self, slot_name: SlotName) -> BaseSlot: - """ - Get slot configuration from the slot name. - - :raises KeyError: If the slot with the specified name does not exist. 
- """ - slot = recursive_getattr(self.root_slot, slot_name) - if isinstance(slot, BaseSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: - """ - Extract slot `slot_name` and store extracted value in `slot_storage`. - - Extracted group slots update slot storage instead of overwriting it. - - :raises KeyError: If the slot with the specified name does not exist. - - :param slot_name: Name of the slot to extract. - :param ctx: Context. - :param success_only: Whether to store the value only if it is successfully extracted. - """ - slot = self.get_slot(slot_name) - value = await slot.get_value(ctx) - - if value.__slot_extracted__ or success_only is False: - recursive_setattr(self.slot_storage, slot_name, value) - - async def extract_all(self, ctx: Context): - """ - Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. - """ - self.slot_storage = await self.root_slot.get_value(ctx) - - def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: - """ - Retrieve extracted value from `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.slot_storage, slot_name) - if isinstance(slot, ExtractedSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - def is_slot_extracted(self, slot_name: str) -> bool: - """ - Return if the specified slot is extracted. - - :raises KeyError: If the slot with the specified name does not exist. - """ - return self.get_extracted_slot(slot_name).__slot_extracted__ + regexp: str + groups: dict[str, int] + "Index of the group to match." - def all_slots_extracted(self) -> bool: - """ - Return if all slots are extracted. 
- """ - return self.slot_storage.__slot_extracted__ + def __init__(self, **kwargs): # supress unexpected argument warnings + super().__init__(**kwargs) - def unset_slot(self, slot_name: SlotName) -> None: - """ - Mark specified slot as not extracted and clear extracted value. + async def get_value(self, ctx: Context) -> ExtractedGroupSlot: + child_values = await asyncio.gather( + *( + RegexpSlot(regexp=self.regexp, match_group_id=match_group).get_value(ctx) + for match_group in self.groups.values() + ) + ) + return ExtractedGroupSlot( + **{child_name: child_value for child_value, child_name in zip(child_values, self.groups.keys())} + ) - :raises KeyError: If the slot with the specified name does not exist. - """ - self.get_extracted_slot(slot_name).__unset__() + def init_value(self) -> ExtractedGroupSlot: + return ExtractedGroupSlot( + **{ + child_name: RegexpSlot(regexp=self.regexp, match_group_id=match_group).init_value() + for child_name, match_group in self.groups.items() + } + ) - def unset_all_slots(self) -> None: - """ - Mark all slots as not extracted and clear all extracted values. - """ - self.slot_storage.__unset__() - class KwargOnlyFormatter(Formatter): - def get_value(self, key, args, kwargs): - return super().get_value(str(key), args, kwargs) +class FunctionSlot(ValueSlot, frozen=True): + """ + A simpler version of :py:class:`~.ValueSlot`. - def fill_template(self, template: str) -> Optional[str]: - """ - Fill `template` string with extracted slot values and return a formatted string - or None if an exception has occurred while trying to fill template. + Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. + """ - `template` should be a format-string: + func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] - E.g. "Your username is {profile.username}". 
+ async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + return await wrap_sync_function_in_async(self.func, ctx.last_request) - For the example above, if ``profile.username`` slot has value "admin", - it would return the following text: - "Your username is admin". - """ - try: - return self.KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) - except Exception as exc: - logger.exception("An exception occurred during template filling.", exc_info=exc) - return None diff --git a/tests/slots/conftest.py b/tests/slots/conftest.py index 9cdcc4eec..c19d34519 100644 --- a/tests/slots/conftest.py +++ b/tests/slots/conftest.py @@ -1,7 +1,7 @@ import pytest from chatsky.core import Message, TRANSITIONS, RESPONSE, Context, Pipeline, Transition as Tr -from chatsky.slots.slots import SlotNotExtracted +from chatsky.slots.base_slots import SlotNotExtracted @pytest.fixture(scope="function", autouse=True) diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 83c6b0171..902b1253c 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -6,10 +6,10 @@ from chatsky import Context from chatsky.core import BaseResponse, Node from chatsky.core.message import MessageInitTypes, Message -from chatsky.slots.slots import ValueSlot, SlotNotExtracted, GroupSlot, SlotManager +from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted, GroupSlot, SlotManager from chatsky import conditions as cnd, responses as rsp, processing as proc from chatsky.processing.slots import logger as proc_logger -from chatsky.slots.slots import logger as slot_logger +from chatsky.slots.base_slots import logger as slot_logger from chatsky.responses.slots import logger as rsp_logger diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 33885b360..9ff50fcd9 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -1,6 
+1,6 @@ import pytest -from chatsky.slots.slots import ( +from chatsky.slots.base_slots import ( SlotManager, RegexpSlot, GroupSlot, diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index 234287c17..88f5ee5fe 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -1,5 +1,5 @@ from chatsky.slots import RegexpSlot, GroupSlot -from chatsky.slots.slots import SlotManager +from chatsky.slots.base_slots import SlotManager from chatsky.core import Message import pytest diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index a21cbd896..24f6c01d9 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -2,7 +2,7 @@ from pydantic import ValidationError from chatsky.core import Message -from chatsky.slots.slots import ( +from chatsky.slots.base_slots import ( RegexpSlot, GroupSlot, FunctionSlot, From 80500f6c030aadd666fd1c341ce2d6f7a2219b9c Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 22 Nov 2024 20:47:29 +0300 Subject: [PATCH 02/38] removed checking for funder names in slot names for GroupSlot + lint --- chatsky/conditions/slots.py | 2 +- chatsky/slots/base_slots.py | 134 +----------------------------- chatsky/slots/slots.py | 159 +++++++++++++++++++++++++++++++----- 3 files changed, 144 insertions(+), 151 deletions(-) diff --git a/chatsky/conditions/slots.py b/chatsky/conditions/slots.py index ee21111a5..8319c36af 100644 --- a/chatsky/conditions/slots.py +++ b/chatsky/conditions/slots.py @@ -59,4 +59,4 @@ def __init__(self, slot_name: SlotName, value: Any): async def call(self, ctx: Context) -> bool: manager = ctx.framework_data.slot_manager - return manager.get_extracted_slot(self.slot_name).self.value == value + return manager.get_extracted_slot(self.slot_name).self.value == self.value diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 2224f3c62..ed23cce7b 100644 --- 
a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -6,22 +6,17 @@ from __future__ import annotations -import asyncio -import re from abc import ABC, abstractmethod -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict -from typing_extensions import TypeAlias, Annotated +from typing import TYPE_CHECKING +from typing_extensions import TypeAlias import logging from functools import reduce from string import Formatter -from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator - -from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async -from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator +from pydantic import BaseModel if TYPE_CHECKING: - from chatsky.core import Context, Message + from chatsky.core import Context logger = logging.getLogger(__name__) @@ -131,124 +126,3 @@ def init_value(self) -> ExtractedSlot: Provide an initial value to fill slot storage with. """ raise NotImplementedError - - -class SlotManager(BaseModel): - """ - Provides API for managing slots. - - An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. - """ - - slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) - """Slot storage. Stored inside ctx.framework_data.""" - root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) - """Slot configuration passed during pipeline initialization.""" - - def set_root_slot(self, root_slot: GroupSlot): - """ - Set root_slot configuration from pipeline. - Update extracted slots with the new configuration: - - New slots are added with their :py:meth:`~.BaseSlot.init_value`. - Old extracted slot values are preserved only if their configuration did not change. - That is if they are still present in the config and if their fundamental type did not change - (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). 
- - This method is called by pipeline and is not supposed to be used otherwise. - """ - self.root_slot = root_slot - new_slot_storage = root_slot.init_value() - new_slot_storage.update(self.slot_storage) - self.slot_storage = new_slot_storage - - def get_slot(self, slot_name: SlotName) -> BaseSlot: - """ - Get slot configuration from the slot name. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.root_slot, slot_name) - if isinstance(slot, BaseSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: - """ - Extract slot `slot_name` and store extracted value in `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - - :param slot_name: Name of the slot to extract. - :param ctx: Context. - :param success_only: Whether to store the value only if it is successfully extracted. - """ - slot = self.get_slot(slot_name) - value = await slot.get_value(ctx) - - if value.__slot_extracted__ or success_only is False: - recursive_setattr(self.slot_storage, slot_name, value) - - async def extract_all(self, ctx: Context): - """ - Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. - """ - self.slot_storage = await self.root_slot.get_value(ctx) - - def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: - """ - Retrieve extracted value from `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.slot_storage, slot_name) - if isinstance(slot, ExtractedSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - def is_slot_extracted(self, slot_name: str) -> bool: - """ - Return if the specified slot is extracted. - - :raises KeyError: If the slot with the specified name does not exist. 
- """ - return self.get_extracted_slot(slot_name).__slot_extracted__ - - def all_slots_extracted(self) -> bool: - """ - Return if all slots are extracted. - """ - return self.slot_storage.__slot_extracted__ - - def unset_slot(self, slot_name: SlotName) -> None: - """ - Mark specified slot as not extracted and clear extracted value. - - :raises KeyError: If the slot with the specified name does not exist. - """ - self.get_extracted_slot(slot_name).__unset__() - - def unset_all_slots(self) -> None: - """ - Mark all slots as not extracted and clear all extracted values. - """ - self.slot_storage.__unset__() - - def fill_template(self, template: str) -> Optional[str]: - """ - Fill `template` string with extracted slot values and return a formatted string - or None if an exception has occurred while trying to fill template. - - `template` should be a format-string: - - E.g. "Your username is {profile.username}". - - For the example above, if ``profile.username`` slot has value "admin", - it would return the following text: - "Your username is admin". 
- """ - try: - return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) - except Exception as exc: - logger.exception("An exception occurred during template filling.", exc_info=exc) - return None diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 73f9f7737..0dc797034 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -8,18 +8,24 @@ import asyncio import re -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict -from typing_extensions import TypeAlias, Annotated +from typing_extensions import Annotated import logging -from functools import reduce -from string import Formatter from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator -from chatsky.slots.base_slots import ExtractedSlot, BaseSlot, SlotNotExtracted, KwargOnlyFormatter +from chatsky.slots.base_slots import ( + ExtractedSlot, + BaseSlot, + SlotNotExtracted, + KwargOnlyFormatter, + SlotName, + recursive_getattr, + recursive_setattr, +) if TYPE_CHECKING: from chatsky.core import Context, Message @@ -161,43 +167,36 @@ def init_value(self) -> ExtractedValueSlot: ) -class GroupSlot(BaseSlot, extra="allow", frozen=True): +class GroupSlot(BaseSlot, frozen=True): """ Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. 
""" value_format: str = None - __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) @model_validator(mode="after") - def __check_extra_field_names__(self): + def __check_slot_names__(self): """ - Extra field names cannot be dunder names or contain dots. + Slot names cannot contain dots. """ - for field in self.__pydantic_extra__.keys(): + for field in self.slots.keys(): if "." in field: raise ValueError(f"Extra field name cannot contain dots: {field!r}") - if field.startswith("__") and field.endswith("__"): - raise ValueError(f"Extra field names cannot be dunder: {field!r}") return self async def get_value(self, ctx: Context) -> ExtractedGroupSlot: - child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) + child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.slots.values())) return ExtractedGroupSlot( value_format=self.value_format, - **{ - child_name: child_value - for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()) - } + slots={child_name: child_value for child_value, child_name in zip(child_values, self.slots.keys())}, ) def init_value(self) -> ExtractedGroupSlot: - return ExtractedGroupSlot( - **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} - ) + return ExtractedGroupSlot(**{child_name: child.init_value() for child_name, child in self.slots.items()}) class RegexpSlot(ValueSlot, frozen=True): @@ -270,3 +269,123 @@ class FunctionSlot(ValueSlot, frozen=True): async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: return await wrap_sync_function_in_async(self.func, ctx.last_request) + +class SlotManager(BaseModel): + """ + Provides API for managing 
slots. + + An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. + """ + + slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) + """Slot storage. Stored inside ctx.framework_data.""" + root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) + """Slot configuration passed during pipeline initialization.""" + + def set_root_slot(self, root_slot: GroupSlot): + """ + Set root_slot configuration from pipeline. + Update extracted slots with the new configuration: + + New slots are added with their :py:meth:`~.BaseSlot.init_value`. + Old extracted slot values are preserved only if their configuration did not change. + That is if they are still present in the config and if their fundamental type did not change + (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). + + This method is called by pipeline and is not supposed to be used otherwise. + """ + self.root_slot = root_slot + new_slot_storage = root_slot.init_value() + new_slot_storage.update(self.slot_storage) + self.slot_storage = new_slot_storage + + def get_slot(self, slot_name: SlotName) -> BaseSlot: + """ + Get slot configuration from the slot name. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.root_slot, slot_name) + if isinstance(slot, BaseSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: + """ + Extract slot `slot_name` and store extracted value in `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + + :param slot_name: Name of the slot to extract. + :param ctx: Context. + :param success_only: Whether to store the value only if it is successfully extracted. 
+ """ + slot = self.get_slot(slot_name) + value = await slot.get_value(ctx) + + if value.__slot_extracted__ or success_only is False: + recursive_setattr(self.slot_storage, slot_name, value) + + async def extract_all(self, ctx: Context): + """ + Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. + """ + self.slot_storage = await self.root_slot.get_value(ctx) + + def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: + """ + Retrieve extracted value from `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.slot_storage, slot_name) + if isinstance(slot, ExtractedSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + def is_slot_extracted(self, slot_name: str) -> bool: + """ + Return if the specified slot is extracted. + + :raises KeyError: If the slot with the specified name does not exist. + """ + return self.get_extracted_slot(slot_name).__slot_extracted__ + + def all_slots_extracted(self) -> bool: + """ + Return if all slots are extracted. + """ + return self.slot_storage.__slot_extracted__ + + def unset_slot(self, slot_name: SlotName) -> None: + """ + Mark specified slot as not extracted and clear extracted value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + self.get_extracted_slot(slot_name).__unset__() + + def unset_all_slots(self) -> None: + """ + Mark all slots as not extracted and clear all extracted values. + """ + self.slot_storage.__unset__() + + def fill_template(self, template: str) -> Optional[str]: + """ + Fill `template` string with extracted slot values and return a formatted string + or None if an exception has occurred while trying to fill template. + + `template` should be a format-string: + + E.g. "Your username is {profile.username}". 
+ + For the example above, if ``profile.username`` slot has value "admin", + it would return the following text: + "Your username is admin". + """ + try: + return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) + except Exception as exc: + logger.exception("An exception occurred during template filling.", exc_info=exc) + return None From 03fed9e27b67016c4637b44fb736b27db946eaf6 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 25 Nov 2024 18:38:39 +0300 Subject: [PATCH 03/38] moved SlotManager into its own file, ValueSlot back into base_slots, changed some imports --- chatsky/__rebuild_pydantic_models__.py | 2 +- chatsky/core/context.py | 2 +- chatsky/core/pipeline.py | 2 +- chatsky/slots/__init__.py | 2 +- chatsky/slots/base_slots.py | 109 ++++++------- chatsky/slots/slot_manager.py | 202 +++++++++++++++++++++++ chatsky/slots/slots.py | 214 +++---------------------- 7 files changed, 282 insertions(+), 251 deletions(-) create mode 100644 chatsky/slots/slot_manager.py diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 2bed62871..2e89317bd 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -4,7 +4,7 @@ from chatsky.core import Context, Message, Script from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline -from chatsky.slots.base_slots import SlotManager, FunctionSlot +from chatsky.slots.slots import SlotManager, FunctionSlot from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent diff --git a/chatsky/core/context.py b/chatsky/core/context.py index 52c70ff5a..59a9899a2 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -25,7 +25,7 @@ from pydantic import BaseModel, Field from chatsky.core.message import Message, MessageInitTypes -from chatsky.slots.base_slots import SlotManager +from chatsky.slots.slots import SlotManager from
chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes if TYPE_CHECKING: diff --git a/chatsky/core/pipeline.py b/chatsky/core/pipeline.py index 90480d2aa..c97dafab1 100644 --- a/chatsky/core/pipeline.py +++ b/chatsky/core/pipeline.py @@ -21,7 +21,7 @@ from chatsky.messengers.console import CLIMessengerInterface from chatsky.messengers.common import MessengerInterface -from chatsky.slots.base_slots import GroupSlot +from chatsky.slots import GroupSlot from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler from .service import Service diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index cee35a178..6c929b9af 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1 +1 @@ -from chatsky.slots.base_slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot +from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index ed23cce7b..c4bc3db0e 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -7,76 +7,20 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING -from typing_extensions import TypeAlias +from typing import TYPE_CHECKING, Any, Union import logging -from functools import reduce -from string import Formatter from pydantic import BaseModel +from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async if TYPE_CHECKING: from chatsky.core import Context + from chatsky.slots.slots import ExtractedValueSlot logger = logging.getLogger(__name__) -SlotName: TypeAlias = str -""" -A string to identify slots. - -Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. - -E.g. - -.. 
code:: python - - GroupSlot( - user=RegexpSlot(), - password=FunctionSlot, - ) - -Has two slots with names "user" and "password". - -For nested group slots use dots to separate names: - -.. code:: python - - GroupSlot( - user=GroupSlot( - name=FunctionSlot, - password=FunctionSlot, - ) - ) - -Has two slots with names "user.name" and "user.password". -""" - - -def recursive_getattr(obj, slot_name: SlotName): - def two_arg_getattr(__o, name): - # pydantic handles exception when accessing a non-existing extra-field on its own - # return None by default to avoid that - return getattr(__o, name, None) - - return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) - - -def recursive_setattr(obj, slot_name: SlotName, value): - parent_slot, _, slot = slot_name.rpartition(".") - - if parent_slot: - setattr(recursive_getattr(obj, parent_slot), slot, value) - else: - setattr(obj, slot, value) - - -class KwargOnlyFormatter(Formatter): - def get_value(self, key, args, kwargs): - return super().get_value(str(key), args, kwargs) - - class SlotNotExtracted(Exception): """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" @@ -126,3 +70,50 @@ def init_value(self) -> ExtractedSlot: Provide an initial value to fill slot storage with. """ raise NotImplementedError + + +class ValueSlot(BaseSlot, frozen=True): + """ + Value slot is a base class for all slots that are designed to extract concrete values. + Subclass it, if you want to declare your own slot type. + """ + + default_value: Any = None + + @abstractmethod + async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + """ + Return value extracted from context. + + Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. + + Raising exceptions is also allowed and will result in an unsuccessful extraction as well. 
+ """ + raise NotImplementedError + + async def get_value(self, ctx: Context) -> ExtractedValueSlot: + """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" + extracted_value = SlotNotExtracted("Caught an exit exception.") + is_slot_extracted = False + + try: + extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) + is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) + except Exception as error: + logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) + extracted_value = error + finally: + if not is_slot_extracted: + logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") + return ExtractedValueSlot.model_construct( + is_slot_extracted=is_slot_extracted, + extracted_value=extracted_value, + default_value=self.default_value, + ) + + def init_value(self) -> ExtractedValueSlot: + return ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted("Initial slot extraction."), + default_value=self.default_value, + ) diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py new file mode 100644 index 000000000..98b4e5b30 --- /dev/null +++ b/chatsky/slots/slot_manager.py @@ -0,0 +1,202 @@ +""" +Slot Manager +----- +This module defines the SlotManager class, facilitating slot management for Pipeline. 
+""" + +from __future__ import annotations + +from typing_extensions import TypeAlias +from typing import Union, TYPE_CHECKING, Optional +import logging + +from functools import reduce +from string import Formatter + +from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator + +from chatsky.slots.base_slots import ( + ExtractedSlot, + BaseSlot, +) +from chatsky.slots.slots import ExtractedGroupSlot, GroupSlot, ExtractedValueSlot + +if TYPE_CHECKING: + from chatsky.core import Context + +logger = logging.getLogger(__name__) + +SlotName: TypeAlias = str +""" +A string to identify slots. + +Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. + +E.g. + +.. code:: python + + GroupSlot( + user=RegexpSlot(), + password=FunctionSlot, + ) + +Has two slots with names "user" and "password". + +For nested group slots use dots to separate names: + +.. code:: python + + GroupSlot( + user=GroupSlot( + name=FunctionSlot, + password=FunctionSlot, + ) + ) + +Has two slots with names "user.name" and "user.password". +""" + + +def recursive_getattr(obj, slot_name: SlotName): + def two_arg_getattr(__o, name): + # pydantic handles exception when accessing a non-existing extra-field on its own + # return None by default to avoid that + return getattr(__o, name, None) + + return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) + + +def recursive_setattr(obj, slot_name: SlotName, value): + parent_slot, _, slot = slot_name.rpartition(".") + + if parent_slot: + setattr(recursive_getattr(obj, parent_slot), slot, value) + else: + setattr(obj, slot, value) + + +class KwargOnlyFormatter(Formatter): + def get_value(self, key, args, kwargs): + return super().get_value(str(key), args, kwargs) + + +class SlotManager(BaseModel): + """ + Provides API for managing slots. + + An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. 
+ """ + + slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) + """Slot storage. Stored inside ctx.framework_data.""" + root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) + """Slot configuration passed during pipeline initialization.""" + + def set_root_slot(self, root_slot: GroupSlot): + """ + Set root_slot configuration from pipeline. + Update extracted slots with the new configuration: + + New slots are added with their :py:meth:`~.BaseSlot.init_value`. + Old extracted slot values are preserved only if their configuration did not change. + That is if they are still present in the config and if their fundamental type did not change + (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). + + This method is called by pipeline and is not supposed to be used otherwise. + """ + self.root_slot = root_slot + new_slot_storage = root_slot.init_value() + new_slot_storage.update(self.slot_storage) + self.slot_storage = new_slot_storage + + def get_slot(self, slot_name: SlotName) -> BaseSlot: + """ + Get slot configuration from the slot name. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.root_slot, slot_name) + if isinstance(slot, BaseSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: + """ + Extract slot `slot_name` and store extracted value in `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + + :param slot_name: Name of the slot to extract. + :param ctx: Context. + :param success_only: Whether to store the value only if it is successfully extracted. 
+ """ + slot = self.get_slot(slot_name) + value = await slot.get_value(ctx) + + if value.__slot_extracted__ or success_only is False: + recursive_setattr(self.slot_storage, slot_name, value) + + async def extract_all(self, ctx: Context): + """ + Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. + """ + self.slot_storage = await self.root_slot.get_value(ctx) + + def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: + """ + Retrieve extracted value from `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.slot_storage, slot_name) + if isinstance(slot, ExtractedSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + def is_slot_extracted(self, slot_name: str) -> bool: + """ + Return if the specified slot is extracted. + + :raises KeyError: If the slot with the specified name does not exist. + """ + return self.get_extracted_slot(slot_name).__slot_extracted__ + + def all_slots_extracted(self) -> bool: + """ + Return if all slots are extracted. + """ + return self.slot_storage.__slot_extracted__ + + def unset_slot(self, slot_name: SlotName) -> None: + """ + Mark specified slot as not extracted and clear extracted value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + self.get_extracted_slot(slot_name).__unset__() + + def unset_all_slots(self) -> None: + """ + Mark all slots as not extracted and clear all extracted values. + """ + self.slot_storage.__unset__() + + def fill_template(self, template: str) -> Optional[str]: + """ + Fill `template` string with extracted slot values and return a formatted string + or None if an exception has occurred while trying to fill template. + + `template` should be a format-string: + + E.g. "Your username is {profile.username}". 
+ + For the example above, if ``profile.username`` slot has value "admin", + it would return the following text: + "Your username is admin". + """ + try: + return KwargOnlyFormatter().format(template, **dict(self.slot_storage.slots.items())) + except Exception as exc: + logger.exception("An exception occurred during template filling.", exc_info=exc) + return None diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 0dc797034..af4f3acd6 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -8,8 +8,7 @@ import asyncio import re -from abc import abstractmethod -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict from typing_extensions import Annotated import logging @@ -20,12 +19,10 @@ from chatsky.slots.base_slots import ( ExtractedSlot, BaseSlot, + ValueSlot, SlotNotExtracted, - KwargOnlyFormatter, - SlotName, - recursive_getattr, - recursive_setattr, ) +from chatsky.slots.slot_manager import KwargOnlyFormatter if TYPE_CHECKING: from chatsky.core import Context, Message @@ -81,25 +78,24 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): value_format: str = None - __pydantic_extra__: Dict[ + slots: Dict[ str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] ] @property def __slot_extracted__(self) -> bool: - return all([slot.__slot_extracted__ for slot in self.__pydantic_extra__.values()]) + return all([slot.__slot_extracted__ for slot in self.slots.values()]) def __unset__(self): - for child in self.__pydantic_extra__.values(): + for child in self.slots.values(): child.__unset__() # fill template here def __str__(self): if self.value_format is not None: - # return str({key: Kwargs.format() for key, value in self.__pydantic_extra__.items()}) - return KwargOnlyFormatter().format(self.value_format, **self.__pydantic_extra__) + return 
KwargOnlyFormatter().format(self.value_format, **self.slots) else: - return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) + return str({key: str(value) for key, value in self.slots.items()}) def update(self, old: "ExtractedGroupSlot"): """ @@ -110,61 +106,14 @@ def update(self, old: "ExtractedGroupSlot"): :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. Extracted values will be transferred to this object. """ - for slot in old.__pydantic_extra__: - if slot in self.__pydantic_extra__: - new_slot = self.__pydantic_extra__[slot] - old_slot = old.__pydantic_extra__[slot] + for slot in old.slots: + if slot in self.slots: + new_slot = self.slots[slot] + old_slot = old.slots[slot] if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): new_slot.update(old_slot) if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): - self.__pydantic_extra__[slot] = old_slot - - -class ValueSlot(BaseSlot, frozen=True): - """ - Value slot is a base class for all slots that are designed to extract concrete values. - Subclass it, if you want to declare your own slot type. - """ - - default_value: Any = None - - @abstractmethod - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: - """ - Return value extracted from context. - - Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. - - Raising exceptions is also allowed and will result in an unsuccessful extraction as well. 
- """ - raise NotImplementedError - - async def get_value(self, ctx: Context) -> ExtractedValueSlot: - """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" - extracted_value = SlotNotExtracted("Caught an exit exception.") - is_slot_extracted = False - - try: - extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) - is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) - except Exception as error: - logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) - extracted_value = error - finally: - if not is_slot_extracted: - logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") - return ExtractedValueSlot.model_construct( - is_slot_extracted=is_slot_extracted, - extracted_value=extracted_value, - default_value=self.default_value, - ) - - def init_value(self) -> ExtractedValueSlot: - return ExtractedValueSlot.model_construct( - is_slot_extracted=False, - extracted_value=SlotNotExtracted("Initial slot extraction."), - default_value=self.default_value, - ) + self.slots[slot] = old_slot class GroupSlot(BaseSlot, frozen=True): @@ -188,6 +137,16 @@ def __check_slot_names__(self): raise ValueError(f"Extra field name cannot contain dots: {field!r}") return self + def _flatten_group_slot(self, slot, parent_key=""): + items = {} + for key, value in slot.__pydantic_extra__.items(): + new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, GroupSlot): + items.update(self.__flatten_llm_group_slot(value, new_key)) + else: + items[new_key] = value + return items + async def get_value(self, ctx: Context) -> ExtractedGroupSlot: child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.slots.values())) return ExtractedGroupSlot( @@ -196,7 +155,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: ) def init_value(self) -> ExtractedGroupSlot: - return ExtractedGroupSlot(**{child_name: 
child.init_value() for child_name, child in self.slots.items()}) + return ExtractedGroupSlot(slots={child_name: child.init_value() for child_name, child in self.slots.items()}) class RegexpSlot(ValueSlot, frozen=True): @@ -245,12 +204,12 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: ) ) return ExtractedGroupSlot( - **{child_name: child_value for child_value, child_name in zip(child_values, self.groups.keys())} + slots={child_name: child_value for child_value, child_name in zip(child_values, self.groups.keys())} ) def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( - **{ + slots={ child_name: RegexpSlot(regexp=self.regexp, match_group_id=match_group).init_value() for child_name, match_group in self.groups.items() } @@ -268,124 +227,3 @@ class FunctionSlot(ValueSlot, frozen=True): async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: return await wrap_sync_function_in_async(self.func, ctx.last_request) - - -class SlotManager(BaseModel): - """ - Provides API for managing slots. - - An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. - """ - - slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) - """Slot storage. Stored inside ctx.framework_data.""" - root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) - """Slot configuration passed during pipeline initialization.""" - - def set_root_slot(self, root_slot: GroupSlot): - """ - Set root_slot configuration from pipeline. - Update extracted slots with the new configuration: - - New slots are added with their :py:meth:`~.BaseSlot.init_value`. - Old extracted slot values are preserved only if their configuration did not change. - That is if they are still present in the config and if their fundamental type did not change - (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). - - This method is called by pipeline and is not supposed to be used otherwise. 
- """ - self.root_slot = root_slot - new_slot_storage = root_slot.init_value() - new_slot_storage.update(self.slot_storage) - self.slot_storage = new_slot_storage - - def get_slot(self, slot_name: SlotName) -> BaseSlot: - """ - Get slot configuration from the slot name. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.root_slot, slot_name) - if isinstance(slot, BaseSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: - """ - Extract slot `slot_name` and store extracted value in `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - - :param slot_name: Name of the slot to extract. - :param ctx: Context. - :param success_only: Whether to store the value only if it is successfully extracted. - """ - slot = self.get_slot(slot_name) - value = await slot.get_value(ctx) - - if value.__slot_extracted__ or success_only is False: - recursive_setattr(self.slot_storage, slot_name, value) - - async def extract_all(self, ctx: Context): - """ - Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. - """ - self.slot_storage = await self.root_slot.get_value(ctx) - - def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: - """ - Retrieve extracted value from `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.slot_storage, slot_name) - if isinstance(slot, ExtractedSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - def is_slot_extracted(self, slot_name: str) -> bool: - """ - Return if the specified slot is extracted. - - :raises KeyError: If the slot with the specified name does not exist. 
- """ - return self.get_extracted_slot(slot_name).__slot_extracted__ - - def all_slots_extracted(self) -> bool: - """ - Return if all slots are extracted. - """ - return self.slot_storage.__slot_extracted__ - - def unset_slot(self, slot_name: SlotName) -> None: - """ - Mark specified slot as not extracted and clear extracted value. - - :raises KeyError: If the slot with the specified name does not exist. - """ - self.get_extracted_slot(slot_name).__unset__() - - def unset_all_slots(self) -> None: - """ - Mark all slots as not extracted and clear all extracted values. - """ - self.slot_storage.__unset__() - - def fill_template(self, template: str) -> Optional[str]: - """ - Fill `template` string with extracted slot values and return a formatted string - or None if an exception has occurred while trying to fill template. - - `template` should be a format-string: - - E.g. "Your username is {profile.username}". - - For the example above, if ``profile.username`` slot has value "admin", - it would return the following text: - "Your username is admin". 
- """ - try: - return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) - except Exception as exc: - logger.exception("An exception occurred during template filling.", exc_info=exc) - return None From eb0a52800434dc57b58ff2237dc1a031068f47c2 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 25 Nov 2024 19:35:57 +0300 Subject: [PATCH 04/38] changed RegexpGroupSlot to use regex only once, updated slot imports across the codebase (to the new file structure) --- chatsky/__rebuild_pydantic_models__.py | 3 +- chatsky/conditions/slots.py | 2 +- chatsky/core/context.py | 2 +- chatsky/processing/slots.py | 2 +- chatsky/slots/__init__.py | 2 +- chatsky/slots/slot_manager.py | 10 ++---- chatsky/slots/slots.py | 38 +++++++++++++-------- tests/slots/test_slot_functions.py | 4 ++- tests/slots/test_slot_manager.py | 4 +-- tests/slots/test_slot_partial_extraction.py | 2 +- tests/slots/test_slot_types.py | 5 ++- 11 files changed, 40 insertions(+), 34 deletions(-) diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 2e89317bd..40c3c678c 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -4,7 +4,8 @@ from chatsky.core import Context, Message, Script from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline -from chatsky.slots.slots import SlotManager, FunctionSlot +from chatsky.slots.slots import FunctionSlot +from chatsky.slots.slot_manager import SlotManager from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent diff --git a/chatsky/conditions/slots.py b/chatsky/conditions/slots.py index 8319c36af..c83a21f60 100644 --- a/chatsky/conditions/slots.py +++ b/chatsky/conditions/slots.py @@ -8,7 +8,7 @@ from typing import Literal, List, Any from chatsky.core import Context, BaseCondition -from chatsky.slots.base_slots import SlotName +from chatsky.slots.slot_manager import 
SlotName class SlotsExtracted(BaseCondition): diff --git a/chatsky/core/context.py b/chatsky/core/context.py index 59a9899a2..8fc1bf7ca 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -25,7 +25,7 @@ from pydantic import BaseModel, Field from chatsky.core.message import Message, MessageInitTypes -from chatsky.slots.slots import SlotManager +from chatsky.slots.slot_manager import SlotManager from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes if TYPE_CHECKING: diff --git a/chatsky/processing/slots.py b/chatsky/processing/slots.py index 655f3d609..75b30a3fe 100644 --- a/chatsky/processing/slots.py +++ b/chatsky/processing/slots.py @@ -9,7 +9,7 @@ import logging from typing import List -from chatsky.slots.base_slots import SlotName +from chatsky.slots.slot_manager import SlotName from chatsky.core import Context, BaseProcessing from chatsky.responses.slots import FilledTemplate diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index 6c929b9af..6725feab3 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1 +1 @@ -from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot +from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, RegexpGroupSlot, FunctionSlot diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index 98b4e5b30..dc68790f0 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -11,15 +11,14 @@ import logging from functools import reduce -from string import Formatter -from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator +from pydantic import BaseModel, Field from chatsky.slots.base_slots import ( ExtractedSlot, BaseSlot, ) -from chatsky.slots.slots import ExtractedGroupSlot, GroupSlot, ExtractedValueSlot +from chatsky.slots.slots import ExtractedGroupSlot, GroupSlot, ExtractedValueSlot, KwargOnlyFormatter if TYPE_CHECKING: from chatsky.core import 
Context @@ -76,11 +75,6 @@ def recursive_setattr(obj, slot_name: SlotName, value): setattr(obj, slot, value) -class KwargOnlyFormatter(Formatter): - def get_value(self, key, args, kwargs): - return super().get_value(str(key), args, kwargs) - - class SlotManager(BaseModel): """ Provides API for managing slots. diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index af4f3acd6..0e1be1c13 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -11,8 +11,9 @@ from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict from typing_extensions import Annotated import logging +from string import Formatter -from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator +from pydantic import model_validator, Field, field_serializer, field_validator from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator @@ -22,7 +23,6 @@ ValueSlot, SlotNotExtracted, ) -from chatsky.slots.slot_manager import KwargOnlyFormatter if TYPE_CHECKING: from chatsky.core import Context, Message @@ -31,6 +31,11 @@ logger = logging.getLogger(__name__) +class KwargOnlyFormatter(Formatter): + def get_value(self, key, args, kwargs): + return super().get_value(str(key), args, kwargs) + + class ExtractedValueSlot(ExtractedSlot): """Value extracted from :py:class:`~.ValueSlot`.""" @@ -78,9 +83,7 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): value_format: str = None - slots: Dict[ - str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] - ] + slots: Dict[str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")]] @property def __slot_extracted__(self) -> bool: @@ -122,7 +125,7 @@ class GroupSlot(BaseSlot, frozen=True): """ value_format: str = None - slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], 
Field(union_mode="left_to_right")]] + slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = {} def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) @@ -197,15 +200,22 @@ def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) async def get_value(self, ctx: Context) -> ExtractedGroupSlot: - child_values = await asyncio.gather( - *( - RegexpSlot(regexp=self.regexp, match_group_id=match_group).get_value(ctx) - for match_group in self.groups.values() + request_text = ctx.last_request.text + search = re.search(self.regexp, request_text) + if search: + return ExtractedGroupSlot( + slots={ + child_name: search.group(match_group) + for child_name, match_group in zip(self.groups.keys(), self.groups.values()) + } + ) + else: + return ExtractedGroupSlot( + slots={ + child_name: SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") + for child_name in self.groups.keys() + } ) - ) - return ExtractedGroupSlot( - slots={child_name: child_value for child_value, child_name in zip(child_values, self.groups.keys())} - ) def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 902b1253c..1daacbd81 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -6,7 +6,9 @@ from chatsky import Context from chatsky.core import BaseResponse, Node from chatsky.core.message import MessageInitTypes, Message -from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted, GroupSlot, SlotManager +from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted +from chatsky.slots.slots import GroupSlot +from chatsky.slots.slot_manager import SlotManager from chatsky import conditions as cnd, responses as rsp, processing as proc from chatsky.processing.slots import logger as proc_logger from 
chatsky.slots.base_slots import logger as slot_logger diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 9ff50fcd9..e6d7342c1 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -1,7 +1,6 @@ import pytest -from chatsky.slots.base_slots import ( - SlotManager, +from chatsky.slots.slots import ( RegexpSlot, GroupSlot, FunctionSlot, @@ -9,6 +8,7 @@ ExtractedValueSlot, SlotNotExtracted, ) +from chatsky.slots.slot_manager import SlotManager from chatsky.core import Message, Context diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index 88f5ee5fe..6c611b6f2 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -1,5 +1,5 @@ from chatsky.slots import RegexpSlot, GroupSlot -from chatsky.slots.base_slots import SlotManager +from chatsky.slots.slot_manager import SlotManager from chatsky.core import Message import pytest diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 24f6c01d9..22bc7b003 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -2,12 +2,11 @@ from pydantic import ValidationError from chatsky.core import Message -from chatsky.slots.base_slots import ( +from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot +from chatsky.slots.slots import ( RegexpSlot, GroupSlot, FunctionSlot, - SlotNotExtracted, - ExtractedValueSlot, ExtractedGroupSlot, ) From 036cf07cd2d77e5ee7434642da921e5efabcb9fe Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 25 Nov 2024 20:04:34 +0300 Subject: [PATCH 05/38] renamed 'success_only' to 'save_on_failure' --- chatsky/processing/slots.py | 8 +- chatsky/slots/slot_manager.py | 6 +- tests/slots/test_slot_functions.py | 2 +- tests/slots/test_slot_manager.py | 4 +- tests/slots/test_slot_partial_extraction.py | 8 +- tutorials/slots/2_partial_extraction.py | 84 ++++++++++----------- 6 
files changed, 56 insertions(+), 56 deletions(-) diff --git a/chatsky/processing/slots.py b/chatsky/processing/slots.py index 75b30a3fe..98c0c3019 100644 --- a/chatsky/processing/slots.py +++ b/chatsky/processing/slots.py @@ -24,16 +24,16 @@ class Extract(BaseProcessing): slots: List[SlotName] """A list of slot names to extract.""" - success_only: bool = True + save_on_failure: bool = True """If set, only successfully extracted values will be stored in the slot storage.""" - def __init__(self, *slots: SlotName, success_only: bool = True): - super().__init__(slots=slots, success_only=success_only) + def __init__(self, *slots: SlotName, save_on_failure: bool = True): + super().__init__(slots=slots, save_on_failure=save_on_failure) async def call(self, ctx: Context): manager = ctx.framework_data.slot_manager results = await asyncio.gather( - *(manager.extract_slot(slot, ctx, self.success_only) for slot in self.slots), return_exceptions=True + *(manager.extract_slot(slot, ctx, self.save_on_failure) for slot in self.slots), return_exceptions=True ) for result in results: diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index dc68790f0..47c259f5b 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -115,7 +115,7 @@ def get_slot(self, slot_name: SlotName) -> BaseSlot: return slot raise KeyError(f"Could not find slot {slot_name!r}.") - async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: + async def extract_slot(self, slot_name: SlotName, ctx: Context, save_on_failure: bool) -> None: """ Extract slot `slot_name` and store extracted value in `slot_storage`. @@ -123,12 +123,12 @@ async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bo :param slot_name: Name of the slot to extract. :param ctx: Context. - :param success_only: Whether to store the value only if it is successfully extracted. 
+ :param save_on_failure: Whether to store the value only if it is successfully extracted. """ slot = self.get_slot(slot_name) value = await slot.get_value(ctx) - if value.__slot_extracted__ or success_only is False: + if value.__slot_extracted__ or save_on_failure is False: recursive_setattr(self.slot_storage, slot_name, value) async def extract_all(self, ctx: Context): diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 1daacbd81..09f65026a 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -67,7 +67,7 @@ async def test_basic_functions(context, manager, log_event_catcher): proc_logs = log_event_catcher(proc_logger, level=logging.ERROR) slot_logs = log_event_catcher(slot_logger, level=logging.ERROR) - await proc.Extract("0", "2", "err", success_only=False).wrapped_call(context) + await proc.Extract("0", "2", "err", save_on_failure=False).wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 assert manager.is_slot_extracted("1") is False diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index e6d7342c1..9466600aa 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -174,7 +174,7 @@ def test_get_slot_by_name(empty_slot_manager): ], ) async def test_slot_extraction(slot_name, expected_slot_storage, empty_slot_manager, context_with_request): - await empty_slot_manager.extract_slot(slot_name, context_with_request, success_only=False) + await empty_slot_manager.extract_slot(slot_name, context_with_request, save_on_failure=False) assert empty_slot_manager.slot_storage == expected_slot_storage @@ -206,7 +206,7 @@ async def test_slot_extraction(slot_name, expected_slot_storage, empty_slot_mana ], ) async def test_successful_extraction(slot_name, expected_slot_storage, empty_slot_manager, context_with_request): - await empty_slot_manager.extract_slot(slot_name, context_with_request, success_only=True) + await 
empty_slot_manager.extract_slot(slot_name, context_with_request, save_on_failure=True) assert empty_slot_manager.slot_storage == expected_slot_storage diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index 6c611b6f2..2d8c22a62 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -65,20 +65,20 @@ def get_extracted_slots(manager: SlotManager): [("1 2 3", ["1", "2"]), ("1 3 5", ["1", "5"]), ("3 4 5 6", ["3", "4", "5", "6"])], ) async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager): - await empty_slot_manager.extract_slot("root_slot", context_with_request(message), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request(message), save_on_failure=False) assert extracted == get_extracted_slots(empty_slot_manager) async def test_slot_storage_update(context_with_request, empty_slot_manager): - await empty_slot_manager.extract_slot("root_slot", context_with_request("1 3 5"), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request("1 3 5"), save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "5"] - await empty_slot_manager.extract_slot("root_slot", context_with_request("2 4 6"), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request("2 4 6"), save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "2", "5", "6"] - await empty_slot_manager.extract_slot("root_slot.nested_group", context_with_request("3 4"), success_only=False) + await empty_slot_manager.extract_slot("root_slot.nested_group", context_with_request("3 4"), save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "2", "3", "4", "5", "6"] diff --git a/tutorials/slots/2_partial_extraction.py b/tutorials/slots/2_partial_extraction.py index 87cc4bab4..5009ebcae 100644 --- 
a/tutorials/slots/2_partial_extraction.py +++ b/tutorials/slots/2_partial_extraction.py @@ -41,23 +41,23 @@ ## Success only extraction -The `Extract` function accepts `success_only` flag which makes it so +The `Extract` function accepts `save_on_failure` flag which makes it so that extracted value is not saved unless extraction is successful. This means that unsuccessfully trying to extract a slot will not overwrite its previously extracted value. -Note that `success_only` is `True` by default. +Note that `save_on_failure` is `True` by default. ## Partial group slot extraction A group slot marked with `allow_partial_extraction` only saves values of successfully extracted child slots. Extracting such group slot is equivalent to extracting every child slot -with the `success_only` flag. +with the `save_on_failure` flag. Partially extracted group slot is always considered successfully extracted -for the purposes of the `success_only` flag. +for the purposes of the `save_on_failure` flag. ## Code explanation @@ -68,16 +68,16 @@ Any slot in this group is saved if and only if that slot is successfully extracted. -Group `success_only_extraction` is extracted with the `success_only` +Group `save_on_failure_extraction` is extracted with the `save_on_failure` flag set to True. Any slot in this group is saved if and only if all of the slots in the group are successfully extracted within a single `Extract` call. -Group `success_only_false` is extracted with the `success_only` set to False. +Group `save_on_failure_false` is extracted with the `save_on_failure` set to False. Every slot in this group is saved (even if extraction was not successful). -Group `sub_slot_success_only_extraction` is extracted by passing all of its -child slots to the `Extract` method with the `success_only` flag set to True. +Group `sub_slot_save_on_failure_extraction` is extracted by passing all of its +child slots to the `Extract` method with the `save_on_failure` flag set to True. 
The behavior is equivalent to that of `partial_extraction`. """ @@ -97,13 +97,13 @@ **sub_slots, allow_partial_extraction=True, ), - "success_only_extraction": GroupSlot( + "save_on_failure_extraction": GroupSlot( **sub_slots, ), - "success_only_false": GroupSlot( + "save_on_failure_false": GroupSlot( **sub_slots, ), - "sub_slot_success_only_extraction": GroupSlot( + "sub_slot_save_on_failure_extraction": GroupSlot( **sub_slots, ), } @@ -126,30 +126,30 @@ PRE_RESPONSE: { "partial_extraction": proc.Extract("partial_extraction"), # partial extraction is always successful; - # success_only doesn't matter - "success_only_extraction": proc.Extract( - "success_only_extraction", success_only=True + # save_on_failure doesn't matter + "save_on_failure`_extraction": proc.Extract( + "save_on_failure_extraction", save_on_failure=True ), - # success_only is True by default - "success_only_false": proc.Extract( - "success_only_false", success_only=False + # save_on_failure is True by default + "save_on_failure_false": proc.Extract( + "save_on_failure_false", save_on_failure=False ), - "sub_slot_success_only_extraction": proc.Extract( - "sub_slot_success_only_extraction.email", - "sub_slot_success_only_extraction.date", - success_only=True, + "sub_slot_save_on_failure_extraction": proc.Extract( + "sub_slot_save_on_failure_extraction.email", + "sub_slot_save_on_failure_extraction.date", + save_on_failure=True, ), }, RESPONSE: rsp.FilledTemplate( "Extracted slots:\n" " Group with partial extraction:\n" " {partial_extraction}\n" - " Group with success_only:\n" - " {success_only_extraction}\n" - " Group without success_only:\n" - " {success_only_false}\n" - " Extracting sub-slots with success_only:\n" - " {sub_slot_success_only_extraction}" + " Group with save_on_failure:\n" + " {save_on_failure_extraction}\n" + " Group without save_on_failure:\n" + " {save_on_failure_false}\n" + " Extracting sub-slots with save_on_failure:\n" + " {sub_slot_save_on_failure_extraction}" ), }, }, @@ 
-162,11 +162,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': 'None', 'email': 'email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': 'None', 'email': 'email@email.com'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': 'None', 'email': 'email@email.com'}", ), ( @@ -174,11 +174,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '01.01.2024', 'email': 'email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '01.01.2024', 'email': 'None'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '01.01.2024', 'email': 'email@email.com'}", ), ( @@ -186,11 +186,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}", ), ( @@ -198,11 +198,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '03.01.2024', 'email': 'None'}\n" - " 
Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}", ), ( @@ -210,11 +210,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}", ), ] From 14c651c7c5e6c0da88a5019eae59d6828ca52539 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 25 Nov 2024 20:07:56 +0300 Subject: [PATCH 06/38] lint --- tutorials/slots/2_partial_extraction.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tutorials/slots/2_partial_extraction.py b/tutorials/slots/2_partial_extraction.py index 5009ebcae..aa8e7875e 100644 --- a/tutorials/slots/2_partial_extraction.py +++ b/tutorials/slots/2_partial_extraction.py @@ -73,7 +73,8 @@ Any slot in this group is saved if and only if all of the slots in the group are successfully extracted within a single `Extract` call. -Group `save_on_failure_false` is extracted with the `save_on_failure` set to False. +Group `save_on_failure_false` is extracted with the `save_on_failure` +flag set to False. Every slot in this group is saved (even if extraction was not successful). 
Group `sub_slot_save_on_failure_extraction` is extracted by passing all of its From 306cb84364970b8a4548c725529e0b1cad51355d Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 27 Nov 2024 19:02:31 +0300 Subject: [PATCH 07/38] refactored RegexGroupSlot, changed description for it --- chatsky/slots/slots.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 0e1be1c13..674b6a1d9 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -83,7 +83,9 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): value_format: str = None - slots: Dict[str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")]] + slots: Dict[ + str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] + ] = Field(default_factory=dict) @property def __slot_extracted__(self) -> bool: @@ -125,7 +127,9 @@ class GroupSlot(BaseSlot, frozen=True): """ value_format: str = None - slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = {} + slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = Field( + default_factory=dict + ) def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) @@ -183,18 +187,18 @@ async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: ) -# TODO: Change class and method descriptions. class RegexpGroupSlot(GroupSlot, frozen=True): """ - RegexpGroupSlot is semantically equal to a GroupSlot of RegexpSlots. - You can pass a compiled or a non-compiled pattern to the `regexp` argument. - If you want to extract a particular group, but not the full match, - change the `match_group_idx` parameter. + A slot type that applies a regex pattern once to extract values for + multiple child slots. 
Accepts a `regexp` pattern and a `groups` dictionary + mapping slot names to group indexes. Improves efficiency by performing a + single regex search for all specified groups, thus reducing the amount + of calls to your model. """ regexp: str groups: dict[str, int] - "Index of the group to match." + "A dictionary mapping slot names to match_group indexes." def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) @@ -205,7 +209,10 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: if search: return ExtractedGroupSlot( slots={ - child_name: search.group(match_group) + child_name: ExtractedValueSlot.model_construct( + is_slot_extracted=True, + extracted_value=search.group(match_group), + ) for child_name, match_group in zip(self.groups.keys(), self.groups.values()) } ) From 36a3d2b640ac68f6acebfdfbf7f05d96f8f4cd97 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 27 Nov 2024 19:13:22 +0300 Subject: [PATCH 08/38] refactor: moved Extracted slots and GroupSlot into base_slots.py --- chatsky/slots/base_slots.py | 144 ++++++++++++++++++++++++++++++++- chatsky/slots/slot_manager.py | 5 +- chatsky/slots/slots.py | 147 +--------------------------------- 3 files changed, 149 insertions(+), 147 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index c4bc3db0e..b4d62e78e 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -6,21 +6,30 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Union, Dict +from typing_extensions import Annotated import logging +from string import Formatter + +from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator -from pydantic import BaseModel from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async +from chatsky.utils.devel.json_serialization import pickle_serializer, 
pickle_validator if TYPE_CHECKING: from chatsky.core import Context - from chatsky.slots.slots import ExtractedValueSlot logger = logging.getLogger(__name__) +class KwargOnlyFormatter(Formatter): + def get_value(self, key, args, kwargs): + return super().get_value(str(key), args, kwargs) + + class SlotNotExtracted(Exception): """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" @@ -72,6 +81,135 @@ def init_value(self) -> ExtractedSlot: raise NotImplementedError +class ExtractedValueSlot(ExtractedSlot): + """Value extracted from :py:class:`~.ValueSlot`.""" + + is_slot_extracted: bool + extracted_value: Any + default_value: Any = None + + @field_serializer("extracted_value", "default_value", when_used="json") + def pickle_serialize_values(self, value): + """ + Cast values to string via pickle. + Allows storing arbitrary data in these fields when using context storages. + """ + if value is not None: + return pickle_serializer(value) + return value + + @field_validator("extracted_value", "default_value", mode="before") + @classmethod + def pickle_validate_values(cls, value): + """ + Restore values after being processed with + :py:meth:`pickle_serialize_values`. 
+ """ + if value is not None: + return pickle_validator(value) + return value + + @property + def __slot_extracted__(self) -> bool: + return self.is_slot_extracted + + def __unset__(self): + self.is_slot_extracted = False + self.extracted_value = SlotNotExtracted("Slot manually unset.") + + @property + def value(self): + """Extracted value or the default value if the slot is not extracted.""" + return self.extracted_value if self.is_slot_extracted else self.default_value + + def __str__(self): + return str(self.value) + + +class ExtractedGroupSlot(ExtractedSlot, extra="allow"): + value_format: str = None + slots: Dict[ + str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] + ] = Field(default_factory=dict) + + @property + def __slot_extracted__(self) -> bool: + return all([slot.__slot_extracted__ for slot in self.slots.values()]) + + def __unset__(self): + for child in self.slots.values(): + child.__unset__() + + # fill template here + def __str__(self): + if self.value_format is not None: + return KwargOnlyFormatter().format(self.value_format, **self.slots) + else: + return str({key: str(value) for key, value in self.slots.items()}) + + def update(self, old: "ExtractedGroupSlot"): + """ + Rebase this extracted groups slot on top of another one. + This is required to merge slot storage in-context + with a potentially different slot configuration passed to pipeline. + + :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. + Extracted values will be transferred to this object. 
+ """ + for slot in old.slots: + if slot in self.slots: + new_slot = self.slots[slot] + old_slot = old.slots[slot] + if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): + new_slot.update(old_slot) + if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): + self.slots[slot] = old_slot + + +class GroupSlot(BaseSlot, frozen=True): + """ + Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. + """ + + value_format: str = None + slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = Field( + default_factory=dict + ) + + def __init__(self, **kwargs): # supress unexpected argument warnings + super().__init__(**kwargs) + + @model_validator(mode="after") + def __check_slot_names__(self): + """ + Slot names cannot contain dots. + """ + for field in self.slots.keys(): + if "." in field: + raise ValueError(f"Extra field name cannot contain dots: {field!r}") + return self + + def _flatten_group_slot(self, slot, parent_key=""): + items = {} + for key, value in slot.__pydantic_extra__.items(): + new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, GroupSlot): + items.update(self.__flatten_llm_group_slot(value, new_key)) + else: + items[new_key] = value + return items + + async def get_value(self, ctx: Context) -> ExtractedGroupSlot: + child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.slots.values())) + return ExtractedGroupSlot( + value_format=self.value_format, + slots={child_name: child_value for child_value, child_name in zip(child_values, self.slots.keys())}, + ) + + def init_value(self) -> ExtractedGroupSlot: + return ExtractedGroupSlot(slots={child_name: child.init_value() for child_name, child in self.slots.items()}) + + class ValueSlot(BaseSlot, frozen=True): """ Value slot is a base class for all slots that are designed to extract concrete values. 
diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index 47c259f5b..e4399c478 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -17,8 +17,11 @@ from chatsky.slots.base_slots import ( ExtractedSlot, BaseSlot, + ExtractedValueSlot, + ExtractedGroupSlot, + GroupSlot, + KwargOnlyFormatter, ) -from chatsky.slots.slots import ExtractedGroupSlot, GroupSlot, ExtractedValueSlot, KwargOnlyFormatter if TYPE_CHECKING: from chatsky.core import Context diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py index 674b6a1d9..077ad9265 100644 --- a/chatsky/slots/slots.py +++ b/chatsky/slots/slots.py @@ -6,22 +6,17 @@ from __future__ import annotations -import asyncio import re -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict -from typing_extensions import Annotated +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union import logging -from string import Formatter - -from pydantic import model_validator, Field, field_serializer, field_validator from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async -from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator from chatsky.slots.base_slots import ( - ExtractedSlot, - BaseSlot, ValueSlot, SlotNotExtracted, + GroupSlot, + ExtractedGroupSlot, + ExtractedValueSlot, ) if TYPE_CHECKING: @@ -31,140 +26,6 @@ logger = logging.getLogger(__name__) -class KwargOnlyFormatter(Formatter): - def get_value(self, key, args, kwargs): - return super().get_value(str(key), args, kwargs) - - -class ExtractedValueSlot(ExtractedSlot): - """Value extracted from :py:class:`~.ValueSlot`.""" - - is_slot_extracted: bool - extracted_value: Any - default_value: Any = None - - @field_serializer("extracted_value", "default_value", when_used="json") - def pickle_serialize_values(self, value): - """ - Cast values to string via pickle. - Allows storing arbitrary data in these fields when using context storages. 
- """ - if value is not None: - return pickle_serializer(value) - return value - - @field_validator("extracted_value", "default_value", mode="before") - @classmethod - def pickle_validate_values(cls, value): - """ - Restore values after being processed with - :py:meth:`pickle_serialize_values`. - """ - if value is not None: - return pickle_validator(value) - return value - - @property - def __slot_extracted__(self) -> bool: - return self.is_slot_extracted - - def __unset__(self): - self.is_slot_extracted = False - self.extracted_value = SlotNotExtracted("Slot manually unset.") - - @property - def value(self): - """Extracted value or the default value if the slot is not extracted.""" - return self.extracted_value if self.is_slot_extracted else self.default_value - - def __str__(self): - return str(self.value) - - -class ExtractedGroupSlot(ExtractedSlot, extra="allow"): - value_format: str = None - slots: Dict[ - str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] - ] = Field(default_factory=dict) - - @property - def __slot_extracted__(self) -> bool: - return all([slot.__slot_extracted__ for slot in self.slots.values()]) - - def __unset__(self): - for child in self.slots.values(): - child.__unset__() - - # fill template here - def __str__(self): - if self.value_format is not None: - return KwargOnlyFormatter().format(self.value_format, **self.slots) - else: - return str({key: str(value) for key, value in self.slots.items()}) - - def update(self, old: "ExtractedGroupSlot"): - """ - Rebase this extracted groups slot on top of another one. - This is required to merge slot storage in-context - with a potentially different slot configuration passed to pipeline. - - :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. - Extracted values will be transferred to this object. 
- """ - for slot in old.slots: - if slot in self.slots: - new_slot = self.slots[slot] - old_slot = old.slots[slot] - if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): - new_slot.update(old_slot) - if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): - self.slots[slot] = old_slot - - -class GroupSlot(BaseSlot, frozen=True): - """ - Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. - """ - - value_format: str = None - slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = Field( - default_factory=dict - ) - - def __init__(self, **kwargs): # supress unexpected argument warnings - super().__init__(**kwargs) - - @model_validator(mode="after") - def __check_slot_names__(self): - """ - Slot names cannot contain dots. - """ - for field in self.slots.keys(): - if "." in field: - raise ValueError(f"Extra field name cannot contain dots: {field!r}") - return self - - def _flatten_group_slot(self, slot, parent_key=""): - items = {} - for key, value in slot.__pydantic_extra__.items(): - new_key = f"{parent_key}.{key}" if parent_key else key - if isinstance(value, GroupSlot): - items.update(self.__flatten_llm_group_slot(value, new_key)) - else: - items[new_key] = value - return items - - async def get_value(self, ctx: Context) -> ExtractedGroupSlot: - child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.slots.values())) - return ExtractedGroupSlot( - value_format=self.value_format, - slots={child_name: child_value for child_value, child_name in zip(child_values, self.slots.keys())}, - ) - - def init_value(self) -> ExtractedGroupSlot: - return ExtractedGroupSlot(slots={child_name: child.init_value() for child_name, child in self.slots.items()}) - - class RegexpSlot(ValueSlot, frozen=True): """ RegexpSlot is a slot type that extracts its value using a regular expression. 
From c18d01a87b1b00ede71325f2cb479da42b85d8e7 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 27 Nov 2024 19:24:21 +0300 Subject: [PATCH 09/38] refactor: updated slot imports across the codebase (according to the new slots' file structure) --- chatsky/slots/__init__.py | 3 ++- tests/slots/test_slot_functions.py | 3 +-- tests/slots/test_slot_manager.py | 5 ++--- tests/slots/test_slot_types.py | 9 ++------- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index 6725feab3..c350a5ea5 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1 +1,2 @@ -from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, RegexpGroupSlot, FunctionSlot +from chatsky.slots.slots import RegexpSlot, RegexpGroupSlot, FunctionSlot +from chatsky.slots.base_slots import GroupSlot, ValueSlot diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 09f65026a..55e84f82a 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -6,8 +6,7 @@ from chatsky import Context from chatsky.core import BaseResponse, Node from chatsky.core.message import MessageInitTypes, Message -from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted -from chatsky.slots.slots import GroupSlot +from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted, GroupSlot from chatsky.slots.slot_manager import SlotManager from chatsky import conditions as cnd, responses as rsp, processing as proc from chatsky.processing.slots import logger as proc_logger diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 9466600aa..2e65e123f 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -1,13 +1,12 @@ import pytest -from chatsky.slots.slots import ( - RegexpSlot, +from chatsky.slots.base_slots import ( GroupSlot, - FunctionSlot, ExtractedGroupSlot, ExtractedValueSlot, SlotNotExtracted, ) +from 
chatsky.slots.slots import RegexpSlot, FunctionSlot from chatsky.slots.slot_manager import SlotManager from chatsky.core import Message, Context diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 22bc7b003..5f5470207 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -2,13 +2,8 @@ from pydantic import ValidationError from chatsky.core import Message -from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot -from chatsky.slots.slots import ( - RegexpSlot, - GroupSlot, - FunctionSlot, - ExtractedGroupSlot, -) +from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot, GroupSlot, ExtractedGroupSlot +from chatsky.slots.slots import RegexpSlot, FunctionSlot @pytest.mark.parametrize( From a49984bab8a059d44642acc9bb824e50c1681e87 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 27 Nov 2024 19:56:45 +0300 Subject: [PATCH 10/38] drafted pipeline.yaml test files updates, test_script_parsing's TestImportPipeline tests all pass --- tests/core/script_parsing/pipeline.yaml | 22 +++++++++++-------- .../script_parsing/test_script_parsing.py | 8 +++++-- .../pipeline.yaml | 22 +++++++++++-------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/tests/core/script_parsing/pipeline.yaml b/tests/core/script_parsing/pipeline.yaml index 5b26e179e..dc798f5d4 100644 --- a/tests/core/script_parsing/pipeline.yaml +++ b/tests/core/script_parsing/pipeline.yaml @@ -17,12 +17,16 @@ fallback_label: - other_flow - other_node slots: - person: - likes: - chatsky.slots.RegexpSlot: - regexp: "I like (.+)" - match_group_idx: 1 - age: - chatsky.slots.RegexpSlot: - regexp: "I'm ([0-9]+) years old" - match_group_idx: 1 + chatsky.slots.GroupSlot: + slots: + person: + chatsky.slots.GroupSlot: + slots: + likes: + chatsky.slots.RegexpSlot: + regexp: "I like (.+)" + match_group_idx: 1 + age: + chatsky.slots.RegexpSlot: + regexp: "I'm ([0-9]+) years old" + match_group_idx: 1 diff --git 
a/tests/core/script_parsing/test_script_parsing.py b/tests/core/script_parsing/test_script_parsing.py index 9307f4f15..3098ab4a5 100644 --- a/tests/core/script_parsing/test_script_parsing.py +++ b/tests/core/script_parsing/test_script_parsing.py @@ -149,8 +149,12 @@ def test_normal_import(self): assert start_node.transitions[0].dst == chatsky.dst.Previous() assert start_node.transitions[0].cnd == chatsky.cnd.HasText("t") - assert pipeline.slots.person.likes == chatsky.slots.RegexpSlot(regexp="I like (.+)", match_group_idx=1) - assert pipeline.slots.person.age == chatsky.slots.RegexpSlot(regexp="I'm ([0-9]+) years old", match_group_idx=1) + assert pipeline.slots.slots["person"].slots["likes"] == chatsky.slots.RegexpSlot( + regexp="I like (.+)", match_group_idx=1 + ) + assert pipeline.slots.slots["person"].slots["age"] == chatsky.slots.RegexpSlot( + regexp="I'm ([0-9]+) years old", match_group_idx=1 + ) def test_import_json(self): pipeline = chatsky.Pipeline.from_file(current_dir / "pipeline.json", custom_dir=current_dir / "custom") diff --git a/utils/pipeline_yaml_import_example/pipeline.yaml b/utils/pipeline_yaml_import_example/pipeline.yaml index b7c638c45..ec692fa1e 100644 --- a/utils/pipeline_yaml_import_example/pipeline.yaml +++ b/utils/pipeline_yaml_import_example/pipeline.yaml @@ -96,15 +96,19 @@ fallback_label: - tech_flow - fallback_node slots: - person: - name: - chatsky.slots.RegexpSlot: - regexp: "My name is (.+)" - match_group_idx: 1 - age: - chatsky.slots.RegexpSlot: - regexp: "I'm ([0-9]+) years old" - match_group_idx: 1 + chatsky.slots.GroupSlot: + slots: + person: + chatsky.slots.GroupSlot: + slots: + name: + chatsky.slots.RegexpSlot: + regexp: "My name is (.+)" + match_group_idx: 1 + age: + chatsky.slots.RegexpSlot: + regexp: "I'm ([0-9]+) years old" + match_group_idx: 1 messenger_interface: chatsky.messengers.TelegramInterface: token: From c0ee05a5ca4963ceb439f9c871d95a4f330cc89c Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 2 Dec 2024 
18:32:09 +0300 Subject: [PATCH 11/38] renamed slots.py -> standard_slots.py --- chatsky/__rebuild_pydantic_models__.py | 2 +- chatsky/slots/__init__.py | 2 +- chatsky/slots/{slots.py => standard_slots.py} | 0 tests/slots/test_slot_manager.py | 2 +- tests/slots/test_slot_types.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename chatsky/slots/{slots.py => standard_slots.py} (100%) diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 40c3c678c..747f0d073 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -4,7 +4,7 @@ from chatsky.core import Context, Message, Script from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline -from chatsky.slots.slots import FunctionSlot +from chatsky.slots.standard_slots import FunctionSlot from chatsky.slots.slot_manager import SlotManager from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index c350a5ea5..bc1c33907 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1,2 +1,2 @@ -from chatsky.slots.slots import RegexpSlot, RegexpGroupSlot, FunctionSlot +from chatsky.slots.standard_slots import RegexpSlot, RegexpGroupSlot, FunctionSlot from chatsky.slots.base_slots import GroupSlot, ValueSlot diff --git a/chatsky/slots/slots.py b/chatsky/slots/standard_slots.py similarity index 100% rename from chatsky/slots/slots.py rename to chatsky/slots/standard_slots.py diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 2e65e123f..4694be3d2 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -6,7 +6,7 @@ ExtractedValueSlot, SlotNotExtracted, ) -from chatsky.slots.slots import RegexpSlot, FunctionSlot +from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot from chatsky.slots.slot_manager 
import SlotManager from chatsky.core import Message, Context diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 5f5470207..626840d53 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -3,7 +3,7 @@ from chatsky.core import Message from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot, GroupSlot, ExtractedGroupSlot -from chatsky.slots.slots import RegexpSlot, FunctionSlot +from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot @pytest.mark.parametrize( From ad6d90687858973e9e8c272450b0a7d896322e3a Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 2 Dec 2024 19:50:02 +0300 Subject: [PATCH 12/38] drafting dictionary validator for GorupSlot (unfinished) --- chatsky/slots/base_slots.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index b4d62e78e..3548cae85 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -179,6 +179,29 @@ class GroupSlot(BaseSlot, frozen=True): def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) + @model_validator(mode="before") + @classmethod + def __validate_from_keywords__(cls, data: Any): + """ + Add support for initializing slots from keywords instead of a dictionary. 
+ """ + if isinstance(data, dict): + if "slots" not in data: + return {"slots": data} + return data + return {"slots": dict(data)} + + """ + def inner_func(*args: Any, **kwargs: Any): + if len(kwargs) > 0: + return dict(kwargs) + + result = inner_func(data) + if result is not None: + return result + return data + """ + @model_validator(mode="after") def __check_slot_names__(self): """ From 4e0215e4f012b71cf5caac421d530ebfb94679e1 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 2 Dec 2024 20:22:32 +0300 Subject: [PATCH 13/38] drafted model_validator for GroupSlot --- chatsky/slots/base_slots.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 3548cae85..36eb24044 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -181,15 +181,29 @@ def __init__(self, **kwargs): # supress unexpected argument warnings @model_validator(mode="before") @classmethod - def __validate_from_keywords__(cls, data: Any): + def __validate_group_slot__(cls, data: Any): """ - Add support for initializing slots from keywords instead of a dictionary. + Add support for initializing slots from keywords or a dictionary. + A combination is possible, in that case keywords will take priority. 
""" if isinstance(data, dict): + result = {"slots": dict()} + if data.get("value_format") is not None: + result.update({"value_format": data.get("value_format")}) + result.update(data.get("value_format", dict())) + result["slots"].update(data.get("slots", dict())) + for key, value in data.items(): + if key not in ["value_format", "slots"]: + result["slots"][key] = value + return result + return data + + """ if "slots" not in data: return {"slots": data} return data return {"slots": dict(data)} + """ """ def inner_func(*args: Any, **kwargs: Any): From fe2a382e44ecb74d445019491090021bb6b6be8f Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 2 Dec 2024 20:37:01 +0300 Subject: [PATCH 14/38] reverted pipeline.yml changes --- tests/core/script_parsing/pipeline.yaml | 22 ++++++++----------- .../pipeline.yaml | 22 ++++++++----------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/tests/core/script_parsing/pipeline.yaml b/tests/core/script_parsing/pipeline.yaml index dc798f5d4..5b26e179e 100644 --- a/tests/core/script_parsing/pipeline.yaml +++ b/tests/core/script_parsing/pipeline.yaml @@ -17,16 +17,12 @@ fallback_label: - other_flow - other_node slots: - chatsky.slots.GroupSlot: - slots: - person: - chatsky.slots.GroupSlot: - slots: - likes: - chatsky.slots.RegexpSlot: - regexp: "I like (.+)" - match_group_idx: 1 - age: - chatsky.slots.RegexpSlot: - regexp: "I'm ([0-9]+) years old" - match_group_idx: 1 + person: + likes: + chatsky.slots.RegexpSlot: + regexp: "I like (.+)" + match_group_idx: 1 + age: + chatsky.slots.RegexpSlot: + regexp: "I'm ([0-9]+) years old" + match_group_idx: 1 diff --git a/utils/pipeline_yaml_import_example/pipeline.yaml b/utils/pipeline_yaml_import_example/pipeline.yaml index ec692fa1e..b7c638c45 100644 --- a/utils/pipeline_yaml_import_example/pipeline.yaml +++ b/utils/pipeline_yaml_import_example/pipeline.yaml @@ -96,19 +96,15 @@ fallback_label: - tech_flow - fallback_node slots: - chatsky.slots.GroupSlot: - slots: - person: - 
chatsky.slots.GroupSlot: - slots: - name: - chatsky.slots.RegexpSlot: - regexp: "My name is (.+)" - match_group_idx: 1 - age: - chatsky.slots.RegexpSlot: - regexp: "I'm ([0-9]+) years old" - match_group_idx: 1 + person: + name: + chatsky.slots.RegexpSlot: + regexp: "My name is (.+)" + match_group_idx: 1 + age: + chatsky.slots.RegexpSlot: + regexp: "I'm ([0-9]+) years old" + match_group_idx: 1 messenger_interface: chatsky.messengers.TelegramInterface: token: From 442385036071f204d06dd834eca2ac07abb9bce2 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 18:05:31 +0300 Subject: [PATCH 15/38] reverted changes for GroupSlot (extra fields are back), renamed value_format into string_format --- chatsky/slots/base_slots.py | 99 ++++++++----------- chatsky/slots/slot_manager.py | 2 +- chatsky/slots/standard_slots.py | 6 +- .../script_parsing/test_script_parsing.py | 8 +- 4 files changed, 47 insertions(+), 68 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 36eb24044..031574712 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -127,25 +127,25 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): - value_format: str = None - slots: Dict[ + string_format: str | None = None + __pydantic_extra__: Dict[ str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] - ] = Field(default_factory=dict) + ] @property def __slot_extracted__(self) -> bool: - return all([slot.__slot_extracted__ for slot in self.slots.values()]) + return all([slot.__slot_extracted__ for slot in self.__pydantic_extra__.values()]) def __unset__(self): - for child in self.slots.values(): + for child in self.__pydantic_extra__.values(): child.__unset__() # fill template here def __str__(self): - if self.value_format is not None: - return KwargOnlyFormatter().format(self.value_format, **self.slots) + if self.string_format is not None: + return 
KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) else: - return str({key: str(value) for key, value in self.slots.items()}) + return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) def update(self, old: "ExtractedGroupSlot"): """ @@ -156,14 +156,14 @@ def update(self, old: "ExtractedGroupSlot"): :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. Extracted values will be transferred to this object. """ - for slot in old.slots: - if slot in self.slots: - new_slot = self.slots[slot] - old_slot = old.slots[slot] + for slot in old.__pydantic_extra__: + if slot in self.__pydantic_extra__: + new_slot = self.__pydantic_extra__[slot] + old_slot = old.__pydantic_extra__[slot] if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): new_slot.update(old_slot) if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): - self.slots[slot] = old_slot + self.__pydantic_extra__[slot] = old_slot class GroupSlot(BaseSlot, frozen=True): @@ -171,59 +171,38 @@ class GroupSlot(BaseSlot, frozen=True): Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. 
""" - value_format: str = None - slots: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] = Field( - default_factory=dict - ) + string_format: str = None + __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + allow_partial_extraction: bool = False + """If True, extraction returns only successfully extracted child slots.""" - def __init__(self, **kwargs): # supress unexpected argument warnings - super().__init__(**kwargs) + def __init__(self, allow_partial_extraction=False, **kwargs): + super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs) @model_validator(mode="before") @classmethod - def __validate_group_slot__(cls, data: Any): + def __check_reserved_slot_names__(cls, data: Any): """ - Add support for initializing slots from keywords or a dictionary. - A combination is possible, in that case keywords will take priority. + Check that reserved names are used correctly and not taken by the user's slots. 
""" if isinstance(data, dict): - result = {"slots": dict()} - if data.get("value_format") is not None: - result.update({"value_format": data.get("value_format")}) - result.update(data.get("value_format", dict())) - result["slots"].update(data.get("slots", dict())) - for key, value in data.items(): - if key not in ["value_format", "slots"]: - result["slots"][key] = value - return result + reserved_names = ["string_format", "allow_partial_extraction"] + for name in reserved_names: + if data.get(name) is not None: + if isinstance(data.get(name), BaseSlot): + raise TypeError(f"Slots cannot be named '{name}', it's a reserved name.") return data - """ - if "slots" not in data: - return {"slots": data} - return data - return {"slots": dict(data)} - """ - - """ - def inner_func(*args: Any, **kwargs: Any): - if len(kwargs) > 0: - return dict(kwargs) - - result = inner_func(data) - if result is not None: - return result - return data - """ - @model_validator(mode="after") - def __check_slot_names__(self): + def __check_extra_field_names__(self): """ - Slot names cannot contain dots. + Extra field names cannot be dunder names or contain dots. """ - for field in self.slots.keys(): + for field in self.__pydantic_extra__.keys(): if "." 
in field: raise ValueError(f"Extra field name cannot contain dots: {field!r}") + if field.startswith("__") and field.endswith("__"): + raise ValueError(f"Extra field names cannot be dunder: {field!r}") return self def _flatten_group_slot(self, slot, parent_key=""): @@ -237,14 +216,18 @@ def _flatten_group_slot(self, slot, parent_key=""): return items async def get_value(self, ctx: Context) -> ExtractedGroupSlot: - child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.slots.values())) - return ExtractedGroupSlot( - value_format=self.value_format, - slots={child_name: child_value for child_value, child_name in zip(child_values, self.slots.keys())}, - ) + child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) + extracted_values = {} + for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()): + if child_value.__slot_extracted__ or not self.allow_partial_extraction: + extracted_values[child_name] = child_value + + return ExtractedGroupSlot(string_format=self.string_format, **extracted_values) def init_value(self) -> ExtractedGroupSlot: - return ExtractedGroupSlot(slots={child_name: child.init_value() for child_name, child in self.slots.items()}) + return ExtractedGroupSlot( + **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} + ) class ValueSlot(BaseSlot, frozen=True): diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index e4399c478..3dc023468 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -193,7 +193,7 @@ def fill_template(self, template: str) -> Optional[str]: "Your username is admin". 
""" try: - return KwargOnlyFormatter().format(template, **dict(self.slot_storage.slots.items())) + return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) except Exception as exc: logger.exception("An exception occurred during template filling.", exc_info=exc) return None diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index 077ad9265..8f43dd15c 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -69,7 +69,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: search = re.search(self.regexp, request_text) if search: return ExtractedGroupSlot( - slots={ + **{ child_name: ExtractedValueSlot.model_construct( is_slot_extracted=True, extracted_value=search.group(match_group), @@ -79,7 +79,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: ) else: return ExtractedGroupSlot( - slots={ + **{ child_name: SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") for child_name in self.groups.keys() } @@ -87,7 +87,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( - slots={ + **{ child_name: RegexpSlot(regexp=self.regexp, match_group_id=match_group).init_value() for child_name, match_group in self.groups.items() } diff --git a/tests/core/script_parsing/test_script_parsing.py b/tests/core/script_parsing/test_script_parsing.py index 3098ab4a5..9307f4f15 100644 --- a/tests/core/script_parsing/test_script_parsing.py +++ b/tests/core/script_parsing/test_script_parsing.py @@ -149,12 +149,8 @@ def test_normal_import(self): assert start_node.transitions[0].dst == chatsky.dst.Previous() assert start_node.transitions[0].cnd == chatsky.cnd.HasText("t") - assert pipeline.slots.slots["person"].slots["likes"] == chatsky.slots.RegexpSlot( - regexp="I like (.+)", match_group_idx=1 - ) - assert pipeline.slots.slots["person"].slots["age"] == 
chatsky.slots.RegexpSlot( - regexp="I'm ([0-9]+) years old", match_group_idx=1 - ) + assert pipeline.slots.person.likes == chatsky.slots.RegexpSlot(regexp="I like (.+)", match_group_idx=1) + assert pipeline.slots.person.age == chatsky.slots.RegexpSlot(regexp="I'm ([0-9]+) years old", match_group_idx=1) def test_import_json(self): pipeline = chatsky.Pipeline.from_file(current_dir / "pipeline.json", custom_dir=current_dir / "custom") From 2794f8f246fb026b4a233b39306e479e14282855 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 18:55:57 +0300 Subject: [PATCH 16/38] made string_format Optional[] --- chatsky/slots/base_slots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 031574712..87cd82826 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -8,7 +8,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Union, Dict +from typing import TYPE_CHECKING, Any, Union, Dict, Optional from typing_extensions import Annotated import logging from string import Formatter @@ -127,7 +127,7 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): - string_format: str | None = None + string_format: Optional[str] = None __pydantic_extra__: Dict[ str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] ] @@ -171,7 +171,7 @@ class GroupSlot(BaseSlot, frozen=True): Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. 
""" - string_format: str = None + string_format: Optional[str] = None __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] allow_partial_extraction: bool = False """If True, extraction returns only successfully extracted child slots.""" From a33a4f710187e619f21162bf928d97cd73d342c8 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 19:14:05 +0300 Subject: [PATCH 17/38] moved ValueSlot before GroupSlot --- chatsky/slots/base_slots.py | 100 ++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 87cd82826..326ce66fe 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -127,10 +127,10 @@ def __str__(self): class ExtractedGroupSlot(ExtractedSlot, extra="allow"): - string_format: Optional[str] = None __pydantic_extra__: Dict[ str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] ] + string_format: Optional[str] = None @property def __slot_extracted__(self) -> bool: @@ -166,13 +166,60 @@ def update(self, old: "ExtractedGroupSlot"): self.__pydantic_extra__[slot] = old_slot -class GroupSlot(BaseSlot, frozen=True): +class ValueSlot(BaseSlot, frozen=True): + """ + Value slot is a base class for all slots that are designed to extract concrete values. + Subclass it, if you want to declare your own slot type. + """ + + default_value: Any = None + + @abstractmethod + async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + """ + Return value extracted from context. + + Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. + + Raising exceptions is also allowed and will result in an unsuccessful extraction as well. 
+ """ + raise NotImplementedError + + async def get_value(self, ctx: Context) -> ExtractedValueSlot: + """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" + extracted_value = SlotNotExtracted("Caught an exit exception.") + is_slot_extracted = False + + try: + extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) + is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) + except Exception as error: + logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) + extracted_value = error + finally: + if not is_slot_extracted: + logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") + return ExtractedValueSlot.model_construct( + is_slot_extracted=is_slot_extracted, + extracted_value=extracted_value, + default_value=self.default_value, + ) + + def init_value(self) -> ExtractedValueSlot: + return ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted("Initial slot extraction."), + default_value=self.default_value, + ) + + +class GroupSlot(BaseSlot, extra="allow", frozen=True): """ Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. """ - string_format: Optional[str] = None __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + string_format: Optional[str] = None allow_partial_extraction: bool = False """If True, extraction returns only successfully extracted child slots.""" @@ -228,50 +275,3 @@ def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} ) - - -class ValueSlot(BaseSlot, frozen=True): - """ - Value slot is a base class for all slots that are designed to extract concrete values. - Subclass it, if you want to declare your own slot type. 
- """ - - default_value: Any = None - - @abstractmethod - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: - """ - Return value extracted from context. - - Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. - - Raising exceptions is also allowed and will result in an unsuccessful extraction as well. - """ - raise NotImplementedError - - async def get_value(self, ctx: Context) -> ExtractedValueSlot: - """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" - extracted_value = SlotNotExtracted("Caught an exit exception.") - is_slot_extracted = False - - try: - extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) - is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) - except Exception as error: - logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) - extracted_value = error - finally: - if not is_slot_extracted: - logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") - return ExtractedValueSlot.model_construct( - is_slot_extracted=is_slot_extracted, - extracted_value=extracted_value, - default_value=self.default_value, - ) - - def init_value(self) -> ExtractedValueSlot: - return ExtractedValueSlot.model_construct( - is_slot_extracted=False, - extracted_value=SlotNotExtracted("Initial slot extraction."), - default_value=self.default_value, - ) From ce8f606a01057a478255b01d4716dc6afddc7360 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 19:45:17 +0300 Subject: [PATCH 18/38] changed RegexGroupSlot base class to BaseSlot --- chatsky/slots/standard_slots.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index 8f43dd15c..b16f45833 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -12,11 +12,11 @@ from chatsky.utils.devel.async_helpers import 
wrap_sync_function_in_async from chatsky.slots.base_slots import ( - ValueSlot, SlotNotExtracted, - GroupSlot, - ExtractedGroupSlot, ExtractedValueSlot, + ExtractedGroupSlot, + ValueSlot, + BaseSlot, ) if TYPE_CHECKING: @@ -48,7 +48,7 @@ async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: ) -class RegexpGroupSlot(GroupSlot, frozen=True): +class RegexpGroupSlot(BaseSlot, frozen=True): """ A slot type that applies a regex pattern once to extract values for multiple child slots. Accepts a `regexp` pattern and a `groups` dictionary From d9aa6b25b0e3788a6339af279629389de730976e Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 19:58:20 +0300 Subject: [PATCH 19/38] fixed tests, they pass now --- chatsky/slots/slot_manager.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index 3dc023468..a2b000f51 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -70,12 +70,17 @@ def two_arg_getattr(__o, name): def recursive_setattr(obj, slot_name: SlotName, value): - parent_slot, _, slot = slot_name.rpartition(".") + parent_slot, sep, slot = slot_name.rpartition(".") - if parent_slot: - setattr(recursive_getattr(obj, parent_slot), slot, value) + if sep == ".": + parent_obj = recursive_getattr(obj, parent_slot) else: - setattr(obj, slot, value) + parent_obj = obj + + if isinstance(value, ExtractedGroupSlot): + getattr(parent_obj, slot).update(value) + else: + setattr(parent_obj, slot, value) class SlotManager(BaseModel): From b361ea006be83069887c62b4817835132baef1b4 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 4 Dec 2024 21:51:38 +0300 Subject: [PATCH 20/38] added default_values to RegexpGroupSlot, fixed an issue with get_value, finished the RegexpGroupSlot test, fixed a weird Pydantic inheritance issue --- chatsky/slots/base_slots.py | 11 ++--- chatsky/slots/standard_slots.py | 31 +++++++++++--- 
tests/slots/test_slot_types.py | 71 ++++++++++++++++++++++++++++++++- 3 files changed, 101 insertions(+), 12 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 326ce66fe..b9955fc67 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -245,11 +245,12 @@ def __check_extra_field_names__(self): """ Extra field names cannot be dunder names or contain dots. """ - for field in self.__pydantic_extra__.keys(): - if "." in field: - raise ValueError(f"Extra field name cannot contain dots: {field!r}") - if field.startswith("__") and field.endswith("__"): - raise ValueError(f"Extra field names cannot be dunder: {field!r}") + if self.__pydantic_extra__ is not None: + for field in self.__pydantic_extra__.keys(): + if "." in field: + raise ValueError(f"Extra field name cannot contain dots: {field!r}") + if field.startswith("__") and field.endswith("__"): + raise ValueError(f"Extra field names cannot be dunder: {field!r}") return self def _flatten_group_slot(self, slot, parent_key=""): diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index b16f45833..d2d7d9a5b 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -7,9 +7,12 @@ from __future__ import annotations import re -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict, Optional import logging +from pydantic import Field +from typing_extensions import Annotated + from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async from chatsky.slots.base_slots import ( SlotNotExtracted, @@ -17,6 +20,7 @@ ExtractedGroupSlot, ValueSlot, BaseSlot, + GroupSlot, ) if TYPE_CHECKING: @@ -48,7 +52,7 @@ async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: ) -class RegexpGroupSlot(BaseSlot, frozen=True): +class RegexpGroupSlot(GroupSlot, extra="forbid", frozen=True): """ A slot type that applies a 
regex pattern once to extract values for multiple child slots. Accepts a `regexp` pattern and a `groups` dictionary @@ -57,9 +61,15 @@ class RegexpGroupSlot(BaseSlot, frozen=True): of calls to your model. """ + # Parent fields repeated for Pydantic issues + __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + string_format: Optional[str] = None + allow_partial_extraction: bool = False + regexp: str groups: dict[str, int] "A dictionary mapping slot names to match_group indexes." + default_values: dict[str, Any] = Field(default_factory=dict) def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) @@ -69,28 +79,37 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: search = re.search(self.regexp, request_text) if search: return ExtractedGroupSlot( + string_format=self.string_format, **{ child_name: ExtractedValueSlot.model_construct( is_slot_extracted=True, extracted_value=search.group(match_group), ) for child_name, match_group in zip(self.groups.keys(), self.groups.values()) - } + }, ) else: return ExtractedGroupSlot( + string_format=self.string_format, **{ - child_name: SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") + child_name: ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + f"Failed to match pattern {self.regexp!r} in {request_text!r}." 
+ ), + default_value=self.default_values.get(child_name, None), + ) for child_name in self.groups.keys() - } + }, ) def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( + string_format=self.string_format, **{ child_name: RegexpSlot(regexp=self.regexp, match_group_id=match_group).init_value() for child_name, match_group in self.groups.items() - } + }, ) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 626840d53..85d1285ee 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -3,7 +3,7 @@ from chatsky.core import Message from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot, GroupSlot, ExtractedGroupSlot -from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot +from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot, RegexpGroupSlot @pytest.mark.parametrize( @@ -125,6 +125,75 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, assert result.__slot_extracted__ == is_extracted +@pytest.mark.parametrize( + ("user_request", "slot", "expected", "is_extracted"), + [ + ( + Message(text="I am Bot. I have a colleague, his name is Carl."), + # Message(text="I am Bot. My email is bot@bot"), + RegexpGroupSlot( + regexp=r"am (.+?)\..*name is (.+?)\.", + groups={"name_1": 1, "name_2": 2}, + ), + ExtractedGroupSlot( + name_1=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + name_2=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Carl", default_value=None + ), + ), + True, + ), + ( + Message(text="I am Bot. 
I won't tell you my email"), + RegexpGroupSlot( + regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + groups={"name": 1, "email": 2}, + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + request_text="I am Bot. I won't tell you my email", + ) + ), + default_value=None, + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + request_text="I am Bot. I won't tell you my email", + ) + ), + default_value=None, + ), + ), + False, + ), + ], +) +async def test_regex_group_slot_extraction(user_request, slot, expected, is_extracted, context): + context.add_request(user_request) + result = await slot.get_value(context) + assert result == expected + assert result.__slot_extracted__ == is_extracted + + +""" +# TODO: this test +async def test_group_slot_string_format(user_request, regexp, expected, context): + context.add_request(user_request) + slot = RegexpSlot(regexp=regexp) + result = await slot.get_value(context) + assert result == expected +""" + + @pytest.mark.parametrize("forbidden_name", ["__dunder__", "contains.dot"]) def test_group_subslot_name_validation(forbidden_name): with pytest.raises(ValidationError): From 7e4580ae67e3a672f8c662e84e09850971225ce2 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 11 Dec 2024 07:46:40 +0300 Subject: [PATCH 21/38] minor changes, lint, porting to laptop --- chatsky/slots/base_slots.py | 2 +- tests/slots/test_slot_types.py | 49 +++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index b9955fc67..9c2558d86 100644 --- a/chatsky/slots/base_slots.py 
+++ b/chatsky/slots/base_slots.py @@ -258,7 +258,7 @@ def _flatten_group_slot(self, slot, parent_key=""): for key, value in slot.__pydantic_extra__.items(): new_key = f"{parent_key}.{key}" if parent_key else key if isinstance(value, GroupSlot): - items.update(self.__flatten_llm_group_slot(value, new_key)) + items.update(self._flatten_group_slot(value, new_key)) else: items[new_key] = value return items diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 85d1285ee..dff3d0724 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -184,14 +184,55 @@ async def test_regex_group_slot_extraction(user_request, slot, expected, is_extr assert result.__slot_extracted__ == is_extracted -""" # TODO: this test -async def test_group_slot_string_format(user_request, regexp, expected, context): +@pytest.mark.parametrize( + ("user_request", "slot", "expected", "is_extracted"), + [ + ( + Message(text="I am Bot. My email is bot@bot"), + GroupSlot( + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="bot@bot", default_value=None + ), + ), + True, + ), + ( + Message(text="I am Bot. I won't tell you my email"), + GroupSlot( + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+", request_text="I am Bot. 
I won't tell you my email" + ) + ), + default_value=None, + ), + ), + False, + ), + ], +) +async def test_group_slot_string_format(user_request, slot, expected, is_extracted, context): context.add_request(user_request) - slot = RegexpSlot(regexp=regexp) result = await slot.get_value(context) assert result == expected -""" + assert result.__slot_extracted__ == is_extracted @pytest.mark.parametrize("forbidden_name", ["__dunder__", "contains.dot"]) From c53ffd16f38b983f3a8a674ab80caa2bd5abde92 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 11 Dec 2024 13:22:32 +0300 Subject: [PATCH 22/38] added the string_format test, added exception handling for string representation in ExtractedGroupSlot (SlotManager already does that, so I replicated that for coherence) --- chatsky/slots/base_slots.py | 6 +++++- tests/slots/test_slot_types.py | 34 ++++++++-------------------------- 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 9c2558d86..3763ec791 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -143,7 +143,11 @@ def __unset__(self): # fill template here def __str__(self): if self.string_format is not None: - return KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) + try: + return KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) + except Exception as exc: + logger.exception("An exception occurred during template filling.", exc_info=exc) + return None else: return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index dff3d0724..6089ad8da 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -130,7 +130,6 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, [ ( Message(text="I am Bot. I have a colleague, his name is Carl."), - # Message(text="I am Bot. 
My email is bot@bot"), RegexpGroupSlot( regexp=r"am (.+?)\..*name is (.+?)\.", groups={"name_1": 1, "name_2": 2}, @@ -184,54 +183,37 @@ async def test_regex_group_slot_extraction(user_request, slot, expected, is_extr assert result.__slot_extracted__ == is_extracted +# string_format="Your name is {name}. Your email is {email}.", # TODO: this test @pytest.mark.parametrize( - ("user_request", "slot", "expected", "is_extracted"), + ("user_request", "slot", "expected_str_format", "is_extracted"), [ ( Message(text="I am Bot. My email is bot@bot"), GroupSlot( + string_format="Your name is {name}. Your email is {email}.", name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), ), - ExtractedGroupSlot( - name=ExtractedValueSlot.model_construct( - is_slot_extracted=True, extracted_value="Bot", default_value=None - ), - email=ExtractedValueSlot.model_construct( - is_slot_extracted=True, extracted_value="bot@bot", default_value=None - ), - ), + "Your name is Bot. Your email is bot@bot.", True, ), ( Message(text="I am Bot. I won't tell you my email"), GroupSlot( + string_format="Your name is {name}. Your email is {email}.", name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), ), - ExtractedGroupSlot( - name=ExtractedValueSlot.model_construct( - is_slot_extracted=True, extracted_value="Bot", default_value=None - ), - email=ExtractedValueSlot.model_construct( - is_slot_extracted=False, - extracted_value=SlotNotExtracted( - "Failed to match pattern {regexp!r} in {request_text!r}.".format( - regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+", request_text="I am Bot. 
I won't tell you my email" - ) - ), - default_value=None, - ), - ), + None, False, ), ], ) -async def test_group_slot_string_format(user_request, slot, expected, is_extracted, context): +async def test_group_slot_string_format(user_request, slot, expected_str_format, is_extracted, context): context.add_request(user_request) result = await slot.get_value(context) - assert result == expected + assert str(result) == expected_str_format assert result.__slot_extracted__ == is_extracted From b254bd83a83300d637beef06e1d9ecd54a336ad0 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 11 Dec 2024 13:33:39 +0300 Subject: [PATCH 23/38] started tutorial for RegexpGroupSlot and string_format feature of GroupSlot --- .../2_regexgroupslot_and_string_format.py | 211 ++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 tutorials/slots/2_regexgroupslot_and_string_format.py diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py new file mode 100644 index 000000000..2c8305b34 --- /dev/null +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -0,0 +1,211 @@ +# %% [markdown] +""" +# 2. `RegexpGroupSlot` and `string_format` + +The following tutorial shows basic usage of `RegexpGroupSlot` +and `string format` feature of GroupSlot. +""" + +# %pip install chatsky + +# %% +from chatsky import ( + RESPONSE, + TRANSITIONS, + PRE_TRANSITION, + PRE_RESPONSE, + GLOBAL, + LOCAL, + Pipeline, + Transition as Tr, + conditions as cnd, + processing as proc, + responses as rsp, +) + +from chatsky.slots import RegexpSlot + +from chatsky.utils.testing import ( + check_happy_path, + is_interactive_mode, +) + +# %% [markdown] +""" +The slots fall into the following category groups: + +- Value slots can be used to extract slot values from user utterances. +- Group slots can be used to split value slots into groups + with an arbitrary level of nesting. 
+ +You can build the slot tree by passing the child slot instances as extra fields +of the parent slot. In the following cell, we define two slot groups: + + Group 1: person.username, person.email + Group 2: friend.first_name, friend.last_name + +Currently there are two types of value slots: + +- %mddoclink(api,slots.slots,RegexpSlot): + Extracts slot values via regexp. +- %mddoclink(api,slots.slots,FunctionSlot): + Extracts slot values with the help of a user-defined function. +""" + +# %% +SLOTS = { + "person": { + "username": RegexpSlot( + regexp=r"username is ([a-zA-Z]+)", + match_group_idx=1, + ), + "email": RegexpSlot( + regexp=r"email is ([a-z@\.A-Z]+)", + match_group_idx=1, + ), + }, + "friend": { + "first_name": RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )"), + "last_name": RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+"), + }, +} + +# %% [markdown] +""" +The slots module provides several functions for managing slots in-script: + +- %mddoclink(api,conditions.slots,SlotsExtracted): + Condition for checking if specified slots are extracted. +- %mddoclink(api,processing.slots,Extract): + A processing function that extracts specified slots. +- %mddoclink(api,processing.slots,Unset): + A processing function that marks specified slots as not extracted, + effectively resetting their state. +- %mddoclink(api,processing.slots,UnsetAll): + A processing function that marks all slots as not extracted. +- %mddoclink(api,processing.slots,FillTemplate): + A processing function that fills the `response` + Message text with extracted slot values. +- %mddoclink(api,responses.slots,FilledTemplate): + A response function that takes a Message with a + format-string text and returns Message + with its text string filled with extracted slot values. 
+ +The usage of all the above functions is shown in the following script: +""" + +# %% +script = { + GLOBAL: { + TRANSITIONS: [ + Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart")) + ] + }, + "username_flow": { + LOCAL: { + PRE_TRANSITION: {"get_slot": proc.Extract("person.username")}, + TRANSITIONS: [ + Tr( + dst=("email_flow", "ask"), + cnd=cnd.SlotsExtracted("person.username"), + priority=1.2, + ), + Tr(dst=("username_flow", "repeat_question"), priority=0.8), + ], + }, + "ask": { + RESPONSE: "Write your username (my username is ...):", + }, + "repeat_question": { + RESPONSE: "Please, type your username again (my username is ...):", + }, + }, + "email_flow": { + LOCAL: { + PRE_TRANSITION: {"get_slot": proc.Extract("person.email")}, + TRANSITIONS: [ + Tr( + dst=("friend_flow", "ask"), + cnd=cnd.SlotsExtracted("person.username", "person.email"), + priority=1.2, + ), + Tr(dst=("email_flow", "repeat_question"), priority=0.8), + ], + }, + "ask": { + RESPONSE: "Write your email (my email is ...):", + }, + "repeat_question": { + RESPONSE: "Please, write your email again (my email is ...):", + }, + }, + "friend_flow": { + LOCAL: { + PRE_TRANSITION: {"get_slots": proc.Extract("friend")}, + TRANSITIONS: [ + Tr( + dst=("root", "utter"), + cnd=cnd.SlotsExtracted( + "friend.first_name", "friend.last_name", mode="any" + ), + priority=1.2, + ), + Tr(dst=("friend_flow", "repeat_question"), priority=0.8), + ], + }, + "ask": {RESPONSE: "Please, name me one of your friends: (John Doe)"}, + "repeat_question": { + RESPONSE: "Please, name me one of your friends again: (John Doe)" + }, + }, + "root": { + "start": { + TRANSITIONS: [Tr(dst=("username_flow", "ask"))], + }, + "fallback": { + RESPONSE: "Finishing query", + TRANSITIONS: [Tr(dst=("username_flow", "ask"))], + }, + "utter": { + RESPONSE: rsp.FilledTemplate( + "Your friend is {friend.first_name} {friend.last_name}" + ), + TRANSITIONS: [Tr(dst=("root", "utter_alternative"))], + }, + "utter_alternative": { + RESPONSE: 
"Your username is {person.username}. " + "Your email is {person.email}.", + PRE_RESPONSE: {"fill": proc.FillTemplate()}, + }, + }, +} + +# %% +HAPPY_PATH = [ + ("hi", "Write your username (my username is ...):"), + ("my username is groot", "Write your email (my email is ...):"), + ( + "my email is groot@gmail.com", + "Please, name me one of your friends: (John Doe)", + ), + ("Bob Page", "Your friend is Bob Page"), + ("ok", "Your username is groot. Your email is groot@gmail.com."), + ("ok", "Finishing query"), +] + +# %% +pipeline = Pipeline( + script=script, + start_label=("root", "start"), + fallback_label=("root", "fallback"), + slots=SLOTS, +) + +if __name__ == "__main__": + check_happy_path( + pipeline, HAPPY_PATH, printout=True + ) # This is a function for automatic tutorial running + # (testing) with HAPPY_PATH + + if is_interactive_mode(): + pipeline.run() From 89eee15fb737cf74f57e0b86cb955246a69b4b03 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 11 Dec 2024 16:44:08 +0300 Subject: [PATCH 24/38] fixed test for string_format, works now --- tests/slots/test_slot_types.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 6089ad8da..2f8ffde9c 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -183,8 +183,6 @@ async def test_regex_group_slot_extraction(user_request, slot, expected, is_extr assert result.__slot_extracted__ == is_extracted -# string_format="Your name is {name}. Your email is {email}.", -# TODO: this test @pytest.mark.parametrize( ("user_request", "slot", "expected_str_format", "is_extracted"), [ @@ -205,7 +203,7 @@ async def test_regex_group_slot_extraction(user_request, slot, expected, is_extr name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), ), - None, + "Your name is Bot. 
Your email is None.", False, ), ], From 453446619c1807580d59918efbb34ca1a3fb9d82 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Wed, 11 Dec 2024 16:59:11 +0300 Subject: [PATCH 25/38] added SlotValueEquals in test_slot_functions.py, this particular change needs to be reviewed --- chatsky/conditions/__init__.py | 2 +- tests/slots/test_slot_functions.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/chatsky/conditions/__init__.py b/chatsky/conditions/__init__.py index 0d02477dd..bffcecdb5 100644 --- a/chatsky/conditions/__init__.py +++ b/chatsky/conditions/__init__.py @@ -9,5 +9,5 @@ Not, HasCallbackQuery, ) -from chatsky.conditions.slots import SlotsExtracted +from chatsky.conditions.slots import SlotsExtracted, SlotValueEquals from chatsky.conditions.service import ServiceFinished diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 55e84f82a..6127475d2 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -69,6 +69,8 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err", save_on_failure=False).wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 + assert await cnd.SlotValueEquals("0", "5") is False + assert await cnd.SlotValueEquals("0", "4") is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, RuntimeError) From 729d4493e75521fb21da96c3f2d6720349fbb7c0 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Thu, 12 Dec 2024 15:54:58 +0300 Subject: [PATCH 26/38] small review done in meeting (reverted exception handling in GroupSlot, test#3 tweaked(SlotEqualsValue) ), added SlotEqualsValue condition to existing tutorial --- chatsky/slots/base_slots.py | 12 +++++++----- tests/slots/test_slot_functions.py | 6 ++++-- tutorials/slots/1_basic_example.py | 2 ++ 3 files changed, 13 insertions(+), 7 deletions(-) diff --git 
a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 3763ec791..5f1d05438 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -143,11 +143,7 @@ def __unset__(self): # fill template here def __str__(self): if self.string_format is not None: - try: - return KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) - except Exception as exc: - logger.exception("An exception occurred during template filling.", exc_info=exc) - return None + return KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) else: return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) @@ -257,7 +253,13 @@ def __check_extra_field_names__(self): raise ValueError(f"Extra field names cannot be dunder: {field!r}") return self +# TODO: rewrite docs def _flatten_group_slot(self, slot, parent_key=""): + """ + Flattens GroupSlot from nested into a single dictionary. + + Helper method for reimplementing GroupSlots. + """ items = {} for key, value in slot.__pydantic_extra__.items(): new_key = f"{parent_key}.{key}" if parent_key else key diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 6127475d2..df4d99ad1 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -60,6 +60,8 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err").wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 + assert await cnd.SlotValueEquals("0", 5) is False + assert await cnd.SlotValueEquals("0", 4) is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, SlotNotExtracted) @@ -69,8 +71,8 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err", save_on_failure=False).wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 - assert await 
cnd.SlotValueEquals("0", "5") is False - assert await cnd.SlotValueEquals("0", "4") is True + assert await cnd.SlotValueEquals("0", 5) is False + assert await cnd.SlotValueEquals("0", 4) is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, RuntimeError) diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py index dcddfbade..56a1c19da 100644 --- a/tutorials/slots/1_basic_example.py +++ b/tutorials/slots/1_basic_example.py @@ -76,6 +76,8 @@ - %mddoclink(api,conditions.slots,SlotsExtracted): Condition for checking if specified slots are extracted. +- %mddoclink(api,conditions.slots,SlotValueEquals): + Condition for checking if the specified slots' value is equal to a given value. - %mddoclink(api,processing.slots,Extract): A processing function that extracts specified slots. - %mddoclink(api,processing.slots,Unset): From 01731df22e3f91f1b78193aa8a062eca7df89e60 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 13 Dec 2024 19:34:25 +0300 Subject: [PATCH 27/38] added docs for _flatten_group_slot --- chatsky/slots/base_slots.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 5f1d05438..776df64b0 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -253,12 +253,11 @@ def __check_extra_field_names__(self): raise ValueError(f"Extra field names cannot be dunder: {field!r}") return self -# TODO: rewrite docs - def _flatten_group_slot(self, slot, parent_key=""): + def _flatten_group_slot(self, slot, parent_key="") -> dict: """ - Flattens GroupSlot from nested into a single dictionary. + Unpacks a GroupSlot's nested `Slots` into a single dictionary. - Helper method for reimplementing GroupSlots. + Intended as a helper method for making GroupSlot child classes. 
""" items = {} for key, value in slot.__pydantic_extra__.items(): From 14cc4fd769774109ebf14d212391646da5a0119c Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 13 Dec 2024 20:01:20 +0300 Subject: [PATCH 28/38] lint + SlotEqualsValue condition fixed (tests pass now) --- chatsky/conditions/slots.py | 2 +- chatsky/slots/standard_slots.py | 1 - tests/slots/test_slot_functions.py | 8 ++++---- tutorials/slots/1_basic_example.py | 3 ++- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/chatsky/conditions/slots.py b/chatsky/conditions/slots.py index c83a21f60..98f881d42 100644 --- a/chatsky/conditions/slots.py +++ b/chatsky/conditions/slots.py @@ -59,4 +59,4 @@ def __init__(self, slot_name: SlotName, value: Any): async def call(self, ctx: Context) -> bool: manager = ctx.framework_data.slot_manager - return manager.get_extracted_slot(self.slot_name).self.value == self.value + return manager.get_extracted_slot(self.slot_name).value == self.value diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index d2d7d9a5b..fc1c921b8 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -19,7 +19,6 @@ ExtractedValueSlot, ExtractedGroupSlot, ValueSlot, - BaseSlot, GroupSlot, ) diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index df4d99ad1..5961869bf 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -60,8 +60,8 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err").wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 - assert await cnd.SlotValueEquals("0", 5) is False - assert await cnd.SlotValueEquals("0", 4) is True + assert await cnd.SlotValueEquals("0", 5).wrapped_call(context) is False + assert await cnd.SlotValueEquals("0", 4).wrapped_call(context) is True assert manager.is_slot_extracted("1") is False assert 
isinstance(manager.get_extracted_slot("err").extracted_value, SlotNotExtracted) @@ -71,8 +71,8 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err", save_on_failure=False).wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 - assert await cnd.SlotValueEquals("0", 5) is False - assert await cnd.SlotValueEquals("0", 4) is True + assert await cnd.SlotValueEquals("0", 5).wrapped_call(context) is False + assert await cnd.SlotValueEquals("0", 4).wrapped_call(context) is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, RuntimeError) diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py index 56a1c19da..c5973904c 100644 --- a/tutorials/slots/1_basic_example.py +++ b/tutorials/slots/1_basic_example.py @@ -77,7 +77,8 @@ - %mddoclink(api,conditions.slots,SlotsExtracted): Condition for checking if specified slots are extracted. - %mddoclink(api,conditions.slots,SlotValueEquals): - Condition for checking if the specified slots' value is equal to a given value. + Condition for checking if the specified slots' value + is equal to a given value. - %mddoclink(api,processing.slots,Extract): A processing function that extracts specified slots. - %mddoclink(api,processing.slots,Unset): From 9890179612187f49ecc3538c49edf29fde4f594e Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 13 Dec 2024 20:12:14 +0300 Subject: [PATCH 29/38] minor fix for doc building (title underscores added) --- chatsky/slots/base_slots.py | 2 +- chatsky/slots/slot_manager.py | 2 +- chatsky/slots/standard_slots.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 776df64b0..448f9cf63 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -1,6 +1,6 @@ """ Base Slots ------ +---------- This module defines base classes for slots. 
""" diff --git a/chatsky/slots/slot_manager.py b/chatsky/slots/slot_manager.py index a2b000f51..633c9ed00 100644 --- a/chatsky/slots/slot_manager.py +++ b/chatsky/slots/slot_manager.py @@ -1,6 +1,6 @@ """ Slot Manager ------ +------------ This module defines the SlotManager class, facilitating slot management for Pipeline. """ diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index fc1c921b8..364940b1b 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -1,6 +1,6 @@ """ -Slots ------ +Standard Slots +-------------- This module defines some concrete implementations of slots. """ From fc5bc58d2b3285cda9ee37ff62f8068467ba4239 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 13 Dec 2024 20:31:47 +0300 Subject: [PATCH 30/38] Updating API reference --- tutorials/slots/1_basic_example.py | 4 ++-- tutorials/slots/2_regexgroupslot_and_string_format.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py index c5973904c..3573f6a5c 100644 --- a/tutorials/slots/1_basic_example.py +++ b/tutorials/slots/1_basic_example.py @@ -46,9 +46,9 @@ Currently there are two types of value slots: -- %mddoclink(api,slots.slots,RegexpSlot): +- %mddoclink(api,slots.standard_slots,RegexpSlot): Extracts slot values via regexp. -- %mddoclink(api,slots.slots,FunctionSlot): +- %mddoclink(api,slots.standard_slots,FunctionSlot): Extracts slot values with the help of a user-defined function. 
""" diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index 2c8305b34..e919fc106 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -46,9 +46,9 @@ Currently there are two types of value slots: -- %mddoclink(api,slots.slots,RegexpSlot): +- %mddoclink(api,slots.standard_slots,RegexpSlot): Extracts slot values via regexp. -- %mddoclink(api,slots.slots,FunctionSlot): +- %mddoclink(api,slots.standard_slots,FunctionSlot): Extracts slot values with the help of a user-defined function. """ From 4d9f790448f2a9da7ff6d8eed381e455d7caf693 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 13 Dec 2024 21:12:30 +0300 Subject: [PATCH 31/38] started the RegexpGroupSlot and string_format tutorial --- chatsky/slots/standard_slots.py | 1 + .../2_regexgroupslot_and_string_format.py | 28 +++++++++++++++++++ ..._extraction.py => 3_partial_extraction.py} | 2 +- 3 files changed, 30 insertions(+), 1 deletion(-) rename tutorials/slots/{2_partial_extraction.py => 3_partial_extraction.py} (99%) diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index 364940b1b..f361c9817 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -69,6 +69,7 @@ class RegexpGroupSlot(GroupSlot, extra="forbid", frozen=True): groups: dict[str, int] "A dictionary mapping slot names to match_group indexes." 
default_values: dict[str, Any] = Field(default_factory=dict) + # TODO: write docstring, could copy from tutorial def __init__(self, **kwargs): # supress unexpected argument warnings super().__init__(**kwargs) diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index e919fc106..91826fd66 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -32,6 +32,34 @@ # %% [markdown] """ +## RegexpGroupSlot extraction + +The `RegexpGroupSlot` class reuses one regex.search() call to save on +execution time in specific cases like LLM, where the amount of get_value() +calls is important. + +## RegexpGroupSlot arguments + +* `regexp` - the regular expression to match with the `ctx.last_request.text`. +* `groups` - a dictionary mapping slot names to group indexes, where numbers + mean the index of the capture group that was found with `re.search()` +* `default_values` - a list of functions that run after the service. + You can read more about the handlers in this [tutorial] + +This means higher efficiency than a GroupSlot of RegexpSlots, +because only a single regex search is performed. It is saved and reused +for all specified groups, thus reducing the amount of calls to your model, +for example an LLM model. + +The `RegexpGroupSlot` class is derived from `GroupSlot` class, inheriting +its `string_format()` feature, which will be explained later in this tutorial. +Though, `partial extraction` is turned off for this class. + +This means that unsuccessfully trying to extract a slot will not overwrite +its previously extracted value. + +Note that `save_on_failure` is `True` by default. + The slots fall into the following category groups: - Value slots can be used to extract slot values from user utterances. 
diff --git a/tutorials/slots/2_partial_extraction.py b/tutorials/slots/3_partial_extraction.py similarity index 99% rename from tutorials/slots/2_partial_extraction.py rename to tutorials/slots/3_partial_extraction.py index aa8e7875e..7f4d7c4fa 100644 --- a/tutorials/slots/2_partial_extraction.py +++ b/tutorials/slots/3_partial_extraction.py @@ -1,6 +1,6 @@ # %% [markdown] """ -# 2. Partial slot extraction +# 3. Partial slot extraction This tutorial shows advanced options for slot extraction allowing to extract only some of the slots. From 41ef51799628d0f148db2daaea3d69ab3bf331ba Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 16 Dec 2024 18:43:05 +0300 Subject: [PATCH 32/38] model_validator added for RegexpGroupSlot (checking that groups() dictionary doesn't request a capture group that's out of bounds) --- chatsky/slots/standard_slots.py | 40 ++++++++++++++++--- .../2_regexgroupslot_and_string_format.py | 1 - 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index f361c9817..fa9e3f272 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -7,10 +7,11 @@ from __future__ import annotations import re +from re import Pattern from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict, Optional import logging -from pydantic import Field +from pydantic import Field, model_validator from typing_extensions import Annotated from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async @@ -38,6 +39,7 @@ class RegexpSlot(ValueSlot, frozen=True): """ regexp: str + "The regexp to search for in ctx.last_request.text" match_group_idx: int = 0 "Index of the group to match." 
@@ -64,15 +66,41 @@ class RegexpGroupSlot(GroupSlot, extra="forbid", frozen=True): __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] string_format: Optional[str] = None allow_partial_extraction: bool = False + "Unlike in `GroupSlot` this field has no effect in this class. If a slot doesn't" - regexp: str + regexp: Pattern + "The regexp to search for in ctx.last_request.text" groups: dict[str, int] "A dictionary mapping slot names to match_group indexes." default_values: dict[str, Any] = Field(default_factory=dict) - # TODO: write docstring, could copy from tutorial - - def __init__(self, **kwargs): # supress unexpected argument warnings - super().__init__(**kwargs) + "A dictionary with default values for each slot name in case a slot's extraction fails." + + @model_validator(mode="after") + def validate_groups(self): + for elem in self.groups.values(): + if elem > self.regexp.groups: + raise ValueError("Requested group number is too high, there aren't that many capture groups!") + if elem < 0: + raise ValueError("Requested capture group number cannot be negative.") + return self + + def __init__(self, regexp: Union[str, Pattern], groups: dict[str, int], default_values: dict[str, Any] = None, + string_format: str = None, flags: int = 0): + init_dict = { + "regexp": re.compile(regexp, flags), + "groups": groups, + "default_values": default_values, + "string_format": string_format, + } + empty_fields = set() + for k, v in init_dict.items(): + if k not in self.model_fields: + raise NotImplementedError("Init method contains a field not in model fields.") + if v is None: + empty_fields.add(k) + for field in empty_fields: + del init_dict[field] + super().__init__(**init_dict) async def get_value(self, ctx: Context) -> ExtractedGroupSlot: request_text = ctx.last_request.text diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index 
91826fd66..23b49a5d8 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -53,7 +53,6 @@ The `RegexpGroupSlot` class is derived from `GroupSlot` class, inheriting its `string_format()` feature, which will be explained later in this tutorial. -Though, `partial extraction` is turned off for this class. This means that unsuccessfully trying to extract a slot will not overwrite its previously extracted value. From 575ad22e510e806fb2e251ead9625b186164ca7d Mon Sep 17 00:00:00 2001 From: ZergLev Date: Mon, 16 Dec 2024 19:08:43 +0300 Subject: [PATCH 33/38] fixed weird test bug, will discuss during call --- tests/slots/test_slot_types.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index 2f8ffde9c..dbbddb951 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -1,3 +1,5 @@ +import re + import pytest from pydantic import ValidationError @@ -147,7 +149,7 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, ( Message(text="I am Bot. I won't tell you my email"), RegexpGroupSlot( - regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + regexp=r"am (.+?)\..*email is (.+?)\.", groups={"name": 1, "email": 2}, ), ExtractedGroupSlot( @@ -155,7 +157,7 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, is_slot_extracted=False, extracted_value=SlotNotExtracted( "Failed to match pattern {regexp!r} in {request_text!r}.".format( - regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + regexp=re.compile(r"am (.+?)\..*email is (.+?)\."), request_text="I am Bot. 
I won't tell you my email", ) ), @@ -165,7 +167,7 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, is_slot_extracted=False, extracted_value=SlotNotExtracted( "Failed to match pattern {regexp!r} in {request_text!r}.".format( - regexp=r"(?<=am ).+?(?=\.).*[a-zA-Z\.]+@[a-zA-Z\.]+", + regexp=re.compile(r"am (.+?)\..*email is (.+?)\."), request_text="I am Bot. I won't tell you my email", ) ), From efdb7f1b4e3320d13fe8191d5d70782d677c5d30 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Thu, 9 Jan 2025 14:55:54 +0300 Subject: [PATCH 34/38] lint + pydantic fields validation changed at the work call, tutorial is being worked on --- chatsky/slots/base_slots.py | 1 + chatsky/slots/standard_slots.py | 24 ++-- .../2_regexgroupslot_and_string_format.py | 119 +++++++----------- 3 files changed, 63 insertions(+), 81 deletions(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 448f9cf63..66e121bb2 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -220,6 +220,7 @@ class GroupSlot(BaseSlot, extra="allow", frozen=True): __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] string_format: Optional[str] = None + """Makes the str() representation formatted with slot values being the keywords.""" allow_partial_extraction: bool = False """If True, extraction returns only successfully extracted child slots.""" diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index fa9e3f272..995118376 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -55,25 +55,23 @@ async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: class RegexpGroupSlot(GroupSlot, extra="forbid", frozen=True): """ - A slot type that applies a regex pattern once to extract values for - multiple child slots. Accepts a `regexp` pattern and a `groups` dictionary - mapping slot names to group indexes. 
Improves efficiency by performing a - single regex search for all specified groups, thus reducing the amount - of calls to your model. + A slot type that reuses one regex.search() call for several slots + to save on execution time in specific cases like LLM, where the amount of + get_value() calls is important. """ # Parent fields repeated for Pydantic issues __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] string_format: Optional[str] = None allow_partial_extraction: bool = False - "Unlike in `GroupSlot` this field has no effect in this class. If a slot doesn't" + "Unlike in `GroupSlot` this field has no effect in this class." regexp: Pattern "The regexp to search for in ctx.last_request.text" groups: dict[str, int] - "A dictionary mapping slot names to match_group indexes." + "A dictionary mapping slot names to capture group indexes (like `match_group` in RegexpSlot)." default_values: dict[str, Any] = Field(default_factory=dict) - "A dictionary with default values for each slot name in case a slot's extraction fails." + "A dictionary with default values for each slot name in case the regexp search fails." 
@model_validator(mode="after") def validate_groups(self): @@ -84,8 +82,14 @@ def validate_groups(self): raise ValueError("Requested capture group number cannot be negative.") return self - def __init__(self, regexp: Union[str, Pattern], groups: dict[str, int], default_values: dict[str, Any] = None, - string_format: str = None, flags: int = 0): + def __init__( + self, + regexp: Union[str, Pattern], + groups: dict[str, int], + default_values: dict[str, Any] = None, + string_format: str = None, + flags: int = 0, + ): init_dict = { "regexp": re.compile(regexp, flags), "groups": groups, diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index 23b49a5d8..70fca4caf 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -23,7 +23,7 @@ responses as rsp, ) -from chatsky.slots import RegexpSlot +from chatsky.slots import RegexpSlot, RegexpGroupSlot, GroupSlot from chatsky.utils.testing import ( check_happy_path, @@ -34,100 +34,77 @@ """ ## RegexpGroupSlot extraction -The `RegexpGroupSlot` class reuses one regex.search() call to save on -execution time in specific cases like LLM, where the amount of get_value() -calls is important. +The `RegexpGroupSlot` is a slot type that reuses one regex.search() call for +several slots to save on execution time in specific cases like LLM, where the +amount of get_value() calls is important. ## RegexpGroupSlot arguments * `regexp` - the regular expression to match with the `ctx.last_request.text`. * `groups` - a dictionary mapping slot names to group indexes, where numbers mean the index of the capture group that was found with `re.search()` -* `default_values` - a list of functions that run after the service. - You can read more about the handlers in this [tutorial] - -This means higher efficiency than a GroupSlot of RegexpSlots, -because only a single regex search is performed. 
It is saved and reused -for all specified groups, thus reducing the amount of calls to your model, -for example an LLM model. + (like `match_group` in RegexpSlot). +* `default_values` - a dictionary with default values for each slot name in + case the regexp search fails. The `RegexpGroupSlot` class is derived from `GroupSlot` class, inheriting -its `string_format()` feature, which will be explained later in this tutorial. - -This means that unsuccessfully trying to extract a slot will not overwrite -its previously extracted value. - -Note that `save_on_failure` is `True` by default. +its `string_format()` feature. -The slots fall into the following category groups: +## `string_format` usage -- Value slots can be used to extract slot values from user utterances. -- Group slots can be used to split value slots into groups - with an arbitrary level of nesting. +You can set `string_format` to change the `__str__` representation of the +`ExtractedValueSlot`. `string_format` can be set to a string, which will +be formatted with Python's `str.format()`, using extracted slots names and +their values as keyword arguments. -You can build the slot tree by passing the child slot instances as extra fields -of the parent slot. In the following cell, we define two slot groups: +The reason this exists at all is so you don't have to specify each and every +child slot name and can now represent a GroupSlot in a pre-determined way. - Group 1: person.username, person.email - Group 2: friend.first_name, friend.last_name - -Currently there are two types of value slots: - -- %mddoclink(api,slots.standard_slots,RegexpSlot): - Extracts slot values via regexp. -- %mddoclink(api,slots.standard_slots,FunctionSlot): - Extracts slot values with the help of a user-defined function. 
+Here are some examples of `RegexpGroupSlot` and `string_format `use: """ # %% -SLOTS = { - "person": { - "username": RegexpSlot( - regexp=r"username is ([a-zA-Z]+)", - match_group_idx=1, - ), - "email": RegexpSlot( - regexp=r"email is ([a-z@\.A-Z]+)", - match_group_idx=1, - ), - }, - "friend": { - "first_name": RegexpSlot(regexp=r"^[A-Z][a-z]+?(?= )"), - "last_name": RegexpSlot(regexp=r"(?<= )[A-Z][a-z]+"), - }, +two_numbers_regexp_group_slot = RegexpGroupSlot( + string_format="Second number is {second_number}," + "first_number is {first_number}.", + regexp=r"first number is (\d+)\D*second number is (\d+)", + groups={"first_number": 1, "second_number": 2}, +) + +# date -> RegexpGroupSlot. +sub_slots_for_group_slot = { + "date": RegexpSlot( + regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]" + r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})", + ), + "email": RegexpSlot( + regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}", + ), } +string_format_group_slot = GroupSlot( + string_format="Date is {date}, email is {email}", **sub_slots_for_group_slot +) -# %% [markdown] -""" -The slots module provides several functions for managing slots in-script: - -- %mddoclink(api,conditions.slots,SlotsExtracted): - Condition for checking if specified slots are extracted. -- %mddoclink(api,processing.slots,Extract): - A processing function that extracts specified slots. -- %mddoclink(api,processing.slots,Unset): - A processing function that marks specified slots as not extracted, - effectively resetting their state. -- %mddoclink(api,processing.slots,UnsetAll): - A processing function that marks all slots as not extracted. -- %mddoclink(api,processing.slots,FillTemplate): - A processing function that fills the `response` - Message text with extracted slot values. -- %mddoclink(api,responses.slots,FilledTemplate): - A response function that takes a Message with a - format-string text and returns Message - with its text string filled with extracted slot values. 
- -The usage of all the above functions is shown in the following script: -""" +SLOTS = { + "two_numbers_slot": two_numbers_regexp_group_slot, + "string_format_group_slot": string_format_group_slot, +} -# %% script = { GLOBAL: { TRANSITIONS: [ Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart")) ] }, + "two_numbers_flow": { + "ask": { + RESPONSE: "Write two numbers: ", + PRE_TRANSITION: {"get_slot": proc.Extract("two_numbers_slot")}, + }, + "answer_node": { + PRE_RESPONSE: {"get_slot": proc.Extract("two_numbers_slot")} + }, + }, "username_flow": { LOCAL: { PRE_TRANSITION: {"get_slot": proc.Extract("person.username")}, From 0f70b0fa52d09e8ab66f58d0db07940e287cb253 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 10 Jan 2025 12:45:11 +0300 Subject: [PATCH 35/38] tutorial code written, testing now --- .../2_regexgroupslot_and_string_format.py | 135 ++++++------------ 1 file changed, 41 insertions(+), 94 deletions(-) diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index 70fca4caf..3346fd39e 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -9,18 +9,19 @@ # %pip install chatsky # %% +import re + from chatsky import ( RESPONSE, TRANSITIONS, PRE_TRANSITION, - PRE_RESPONSE, GLOBAL, - LOCAL, Pipeline, Transition as Tr, conditions as cnd, processing as proc, responses as rsp, + destinations as dst, ) from chatsky.slots import RegexpSlot, RegexpGroupSlot, GroupSlot @@ -60,141 +61,87 @@ The reason this exists at all is so you don't have to specify each and every child slot name and can now represent a GroupSlot in a pre-determined way. 
-Here are some examples of `RegexpGroupSlot` and `string_format `use: +Below are some examples of `RegexpGroupSlot` and `string_format `use: """ # %% -two_numbers_regexp_group_slot = RegexpGroupSlot( - string_format="Second number is {second_number}," - "first_number is {first_number}.", - regexp=r"first number is (\d+)\D*second number is (\d+)", - groups={"first_number": 1, "second_number": 2}, -) - -# date -> RegexpGroupSlot. sub_slots_for_group_slot = { - "date": RegexpSlot( + "date": RegexpGroupSlot( + string_format="{day}/{month}/{year}", regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]" r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})", + groups={"day": 1, "month": 2, "year": 3}, ), "email": RegexpSlot( regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}", ), } -string_format_group_slot = GroupSlot( - string_format="Date is {date}, email is {email}", **sub_slots_for_group_slot +date_and_email = GroupSlot( + string_format="Your date of birth is {date}, email is {email}", + **sub_slots_for_group_slot ) SLOTS = { - "two_numbers_slot": two_numbers_regexp_group_slot, - "string_format_group_slot": string_format_group_slot, + "date_and_email": date_and_email, } script = { GLOBAL: { TRANSITIONS: [ - Tr(dst=("username_flow", "ask"), cnd=cnd.Regexp(r"^[sS]tart")) + Tr( + dst=("date_and_email_flow", "ask_email"), + cnd=cnd.Regexp(r"^[sS]tart"), + ), ] }, - "two_numbers_flow": { - "ask": { - RESPONSE: "Write two numbers: ", - PRE_TRANSITION: {"get_slot": proc.Extract("two_numbers_slot")}, - }, - "answer_node": { - PRE_RESPONSE: {"get_slot": proc.Extract("two_numbers_slot")} + "date_and_email_flow": { + "start": { + TRANSITIONS: [Tr(dst=("date_and_email_flow", "ask_date"))], }, - }, - "username_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slot": proc.Extract("person.username")}, + "fallback": { + RESPONSE: "Finishing query", TRANSITIONS: [ + Tr(dst=("date_and_email_flow", "ask_email")), Tr( - dst=("email_flow", "ask"), - cnd=cnd.SlotsExtracted("person.username"), - priority=1.2, + dst=dst.Backward(), + 
cnd=cnd.Regexp(r"back", flags=re.IGNORECASE), ), - Tr(dst=("username_flow", "repeat_question"), priority=0.8), ], }, - "ask": { - RESPONSE: "Write your username (my username is ...):", - }, - "repeat_question": { - RESPONSE: "Please, type your username again (my username is ...):", - }, - }, - "email_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slot": proc.Extract("person.email")}, + "ask_email": { + RESPONSE: "Write your email (my email is ...):", + PRE_TRANSITION: {"get_slot": proc.Extract("date_and_email.email")}, TRANSITIONS: [ Tr( - dst=("friend_flow", "ask"), - cnd=cnd.SlotsExtracted("person.username", "person.email"), - priority=1.2, - ), - Tr(dst=("email_flow", "repeat_question"), priority=0.8), + dst="ask_email", + cnd=cnd.SlotsExtracted("date_and_email.email"), + ) ], }, - "ask": { - RESPONSE: "Write your email (my email is ...):", - }, - "repeat_question": { - RESPONSE: "Please, write your email again (my email is ...):", - }, - }, - "friend_flow": { - LOCAL: { - PRE_TRANSITION: {"get_slots": proc.Extract("friend")}, + "ask_date": { + RESPONSE: "Write your date of birth:", + PRE_TRANSITION: {"get_slot": proc.Extract("date_and_email.date")}, TRANSITIONS: [ Tr( - dst=("root", "utter"), - cnd=cnd.SlotsExtracted( - "friend.first_name", "friend.last_name", mode="any" - ), - priority=1.2, - ), - Tr(dst=("friend_flow", "repeat_question"), priority=0.8), + dst="answer_node", + cnd=cnd.SlotsExtracted("date_and_email.date"), + ) ], }, - "ask": {RESPONSE: "Please, name me one of your friends: (John Doe)"}, - "repeat_question": { - RESPONSE: "Please, name me one of your friends again: (John Doe)" - }, - }, - "root": { - "start": { - TRANSITIONS: [Tr(dst=("username_flow", "ask"))], - }, - "fallback": { - RESPONSE: "Finishing query", - TRANSITIONS: [Tr(dst=("username_flow", "ask"))], - }, - "utter": { - RESPONSE: rsp.FilledTemplate( - "Your friend is {friend.first_name} {friend.last_name}" - ), - TRANSITIONS: [Tr(dst=("root", "utter_alternative"))], - }, - 
"utter_alternative": { - RESPONSE: "Your username is {person.username}. " - "Your email is {person.email}.", - PRE_RESPONSE: {"fill": proc.FillTemplate()}, - }, + "answer_node": {RESPONSE: rsp.FilledTemplate("{date_and_email}")}, }, } # %% HAPPY_PATH = [ - ("hi", "Write your username (my username is ...):"), - ("my username is groot", "Write your email (my email is ...):"), + ("hi", "Write your email (my email is ...):"), + ("my email is groot@gmail.com", "Write your date of birth:"), ( - "my email is groot@gmail.com", - "Please, name me one of your friends: (John Doe)", + "my date of birth is 06/10/1984", + "Your date of birth is 06/10/1984, email is groot@gmail.com", ), - ("Bob Page", "Your friend is Bob Page"), - ("ok", "Your username is groot. Your email is groot@gmail.com."), ("ok", "Finishing query"), + ("start", "Write your email (my email is ...):"), ] # %% From 550a437e19b7f11c22bce3496a85b4f07ad5f532 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 10 Jan 2025 13:53:53 +0300 Subject: [PATCH 36/38] tutorial fixed and slightly simplified, made a minor change to RegexpSlot validation --- chatsky/slots/standard_slots.py | 2 +- .../2_regexgroupslot_and_string_format.py | 38 +++++++++---------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py index 995118376..15d3cd768 100644 --- a/chatsky/slots/standard_slots.py +++ b/chatsky/slots/standard_slots.py @@ -38,7 +38,7 @@ class RegexpSlot(ValueSlot, frozen=True): change the `match_group_idx` parameter. """ - regexp: str + regexp: str | Pattern "The regexp to search for in ctx.last_request.text" match_group_idx: int = 0 "Index of the group to match." 
diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py index 3346fd39e..984bf62e2 100644 --- a/tutorials/slots/2_regexgroupslot_and_string_format.py +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -24,7 +24,7 @@ destinations as dst, ) -from chatsky.slots import RegexpSlot, RegexpGroupSlot, GroupSlot +from chatsky.slots import RegexpSlot, RegexpGroupSlot from chatsky.utils.testing import ( check_happy_path, @@ -65,25 +65,17 @@ """ # %% -sub_slots_for_group_slot = { +SLOTS = { "date": RegexpGroupSlot( - string_format="{day}/{month}/{year}", regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]" r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})", groups={"day": 1, "month": 2, "year": 3}, + string_format="{day}/{month}/{year}", ), "email": RegexpSlot( regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}", ), } -date_and_email = GroupSlot( - string_format="Your date of birth is {date}, email is {email}", - **sub_slots_for_group_slot -) - -SLOTS = { - "date_and_email": date_and_email, -} script = { GLOBAL: { @@ -96,39 +88,43 @@ }, "date_and_email_flow": { "start": { - TRANSITIONS: [Tr(dst=("date_and_email_flow", "ask_date"))], + TRANSITIONS: [Tr(dst=("date_and_email_flow", "ask_email"))], }, "fallback": { RESPONSE: "Finishing query", TRANSITIONS: [ - Tr(dst=("date_and_email_flow", "ask_email")), Tr( dst=dst.Backward(), cnd=cnd.Regexp(r"back", flags=re.IGNORECASE), ), + Tr(dst=("date_and_email_flow", "ask_email"), priority=0.8), ], }, "ask_email": { RESPONSE: "Write your email (my email is ...):", - PRE_TRANSITION: {"get_slot": proc.Extract("date_and_email.email")}, + PRE_TRANSITION: {"get_slot": proc.Extract("email")}, TRANSITIONS: [ Tr( - dst="ask_email", - cnd=cnd.SlotsExtracted("date_and_email.email"), + dst="ask_date", + cnd=cnd.SlotsExtracted("email"), ) ], }, "ask_date": { RESPONSE: "Write your date of birth:", - PRE_TRANSITION: {"get_slot": proc.Extract("date_and_email.date")}, + PRE_TRANSITION: {"get_slot": 
proc.Extract("date")}, TRANSITIONS: [ Tr( dst="answer_node", - cnd=cnd.SlotsExtracted("date_and_email.date"), + cnd=cnd.SlotsExtracted("date"), ) ], }, - "answer_node": {RESPONSE: rsp.FilledTemplate("{date_and_email}")}, + "answer_node": { + RESPONSE: rsp.FilledTemplate( + "Your date of birth is {date}, email is {email}" + ) + }, }, } @@ -147,8 +143,8 @@ # %% pipeline = Pipeline( script=script, - start_label=("root", "start"), - fallback_label=("root", "fallback"), + start_label=("date_and_email_flow", "start"), + fallback_label=("date_and_email_flow", "fallback"), slots=SLOTS, ) From d38207d5788f5b744066e3184c4134b9aa93a92c Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 10 Jan 2025 21:14:46 +0300 Subject: [PATCH 37/38] string_format fixed --- chatsky/slots/base_slots.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index 66e121bb2..c33116dae 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -280,5 +280,6 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( + string_format=self.string_format, **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} ) From 870481dfc4f372fa456c6960039df6b5b3f18721 Mon Sep 17 00:00:00 2001 From: ZergLev Date: Fri, 10 Jan 2025 21:26:43 +0300 Subject: [PATCH 38/38] lint --- chatsky/slots/base_slots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py index c33116dae..e1893e6b5 100644 --- a/chatsky/slots/base_slots.py +++ b/chatsky/slots/base_slots.py @@ -281,5 +281,5 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot: def init_value(self) -> ExtractedGroupSlot: return ExtractedGroupSlot( string_format=self.string_format, - **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} + **{child_name: 
child.init_value() for child_name, child in self.__pydantic_extra__.items()}, )