diff --git a/chatsky/__rebuild_pydantic_models__.py b/chatsky/__rebuild_pydantic_models__.py index 1da4126a9..747f0d073 100644 --- a/chatsky/__rebuild_pydantic_models__.py +++ b/chatsky/__rebuild_pydantic_models__.py @@ -1,10 +1,11 @@ # flake8: noqa: F401 from chatsky.core.service.types import ExtraHandlerRuntimeInfo, ComponentExecutionState -from chatsky.core import Context, Script +from chatsky.core import Context, Message, Script from chatsky.core.script import Node from chatsky.core.pipeline import Pipeline -from chatsky.slots.slots import SlotManager +from chatsky.slots.standard_slots import FunctionSlot +from chatsky.slots.slot_manager import SlotManager from chatsky.core.context import FrameworkData, ServiceState from chatsky.core.service import PipelineComponent @@ -15,3 +16,4 @@ ExtraHandlerRuntimeInfo.model_rebuild() FrameworkData.model_rebuild() ServiceState.model_rebuild() +FunctionSlot.model_rebuild() diff --git a/chatsky/conditions/__init__.py b/chatsky/conditions/__init__.py index 0d02477dd..bffcecdb5 100644 --- a/chatsky/conditions/__init__.py +++ b/chatsky/conditions/__init__.py @@ -9,5 +9,5 @@ Not, HasCallbackQuery, ) -from chatsky.conditions.slots import SlotsExtracted +from chatsky.conditions.slots import SlotsExtracted, SlotValueEquals from chatsky.conditions.service import ServiceFinished diff --git a/chatsky/conditions/slots.py b/chatsky/conditions/slots.py index eaddd3140..98f881d42 100644 --- a/chatsky/conditions/slots.py +++ b/chatsky/conditions/slots.py @@ -5,10 +5,10 @@ """ from __future__ import annotations -from typing import Literal, List +from typing import Literal, List, Any from chatsky.core import Context, BaseCondition -from chatsky.slots.slots import SlotName +from chatsky.slots.slot_manager import SlotName class SlotsExtracted(BaseCondition): @@ -36,3 +36,27 @@ async def call(self, ctx: Context) -> bool: return all(manager.is_slot_extracted(slot) for slot in self.slots) elif self.mode == "any": return 
any(manager.is_slot_extracted(slot) for slot in self.slots) + + +class SlotValueEquals(BaseCondition): + """ + Check if :py:attr:`.slot_name`'s extracted value is equal to a given value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + + slot_name: SlotName + """ + Name of the slot that needs to be checked. + """ + value: Any + """ + The value which the slot's extracted value is supposed to be checked against. + """ + + def __init__(self, slot_name: SlotName, value: Any): + super().__init__(slot_name=slot_name, value=value) + + async def call(self, ctx: Context) -> bool: + manager = ctx.framework_data.slot_manager + return manager.get_extracted_slot(self.slot_name).value == self.value diff --git a/chatsky/core/context.py b/chatsky/core/context.py index 59a9899a2..8fc1bf7ca 100644 --- a/chatsky/core/context.py +++ b/chatsky/core/context.py @@ -25,7 +25,7 @@ from pydantic import BaseModel, Field from chatsky.core.message import Message, MessageInitTypes -from chatsky.slots.slots import SlotManager +from chatsky.slots.slot_manager import SlotManager from chatsky.core.node_label import AbsoluteNodeLabel, AbsoluteNodeLabelInitTypes if TYPE_CHECKING: diff --git a/chatsky/core/pipeline.py b/chatsky/core/pipeline.py index 20caa74ea..c97dafab1 100644 --- a/chatsky/core/pipeline.py +++ b/chatsky/core/pipeline.py @@ -21,7 +21,7 @@ from chatsky.messengers.console import CLIMessengerInterface from chatsky.messengers.common import MessengerInterface -from chatsky.slots.slots import GroupSlot +from chatsky.slots import GroupSlot from chatsky.core.service.group import ServiceGroup, ServiceGroupInitTypes from chatsky.core.service.extra import ComponentExtraHandlerInitTypes, BeforeHandler, AfterHandler from .service import Service diff --git a/chatsky/processing/slots.py b/chatsky/processing/slots.py index 898195e94..98c0c3019 100644 --- a/chatsky/processing/slots.py +++ b/chatsky/processing/slots.py @@ -9,7 +9,7 @@ import logging from typing import 
List -from chatsky.slots.slots import SlotName +from chatsky.slots.slot_manager import SlotName from chatsky.core import Context, BaseProcessing from chatsky.responses.slots import FilledTemplate @@ -24,16 +24,16 @@ class Extract(BaseProcessing): slots: List[SlotName] """A list of slot names to extract.""" - success_only: bool = True + save_on_failure: bool = True """If set, only successfully extracted values will be stored in the slot storage.""" - def __init__(self, *slots: SlotName, success_only: bool = True): - super().__init__(slots=slots, success_only=success_only) + def __init__(self, *slots: SlotName, save_on_failure: bool = True): + super().__init__(slots=slots, save_on_failure=save_on_failure) async def call(self, ctx: Context): manager = ctx.framework_data.slot_manager results = await asyncio.gather( - *(manager.extract_slot(slot, ctx, self.success_only) for slot in self.slots), return_exceptions=True + *(manager.extract_slot(slot, ctx, self.save_on_failure) for slot in self.slots), return_exceptions=True ) for result in results: diff --git a/chatsky/slots/__init__.py b/chatsky/slots/__init__.py index 6c929b9af..bc1c33907 100644 --- a/chatsky/slots/__init__.py +++ b/chatsky/slots/__init__.py @@ -1 +1,2 @@ -from chatsky.slots.slots import GroupSlot, ValueSlot, RegexpSlot, FunctionSlot +from chatsky.slots.standard_slots import RegexpSlot, RegexpGroupSlot, FunctionSlot +from chatsky.slots.base_slots import GroupSlot, ValueSlot diff --git a/chatsky/slots/base_slots.py b/chatsky/slots/base_slots.py new file mode 100644 index 000000000..e1893e6b5 --- /dev/null +++ b/chatsky/slots/base_slots.py @@ -0,0 +1,285 @@ +""" +Base Slots +---------- +This module defines base classes for slots. 
+""" + +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Union, Dict, Optional +from typing_extensions import Annotated +import logging +from string import Formatter + +from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator + +from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async +from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator + +if TYPE_CHECKING: + from chatsky.core import Context + + +logger = logging.getLogger(__name__) + + +class KwargOnlyFormatter(Formatter): + def get_value(self, key, args, kwargs): + return super().get_value(str(key), args, kwargs) + + +class SlotNotExtracted(Exception): + """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" + + pass + + +class ExtractedSlot(BaseModel, ABC): + """ + Represents value of an extracted slot. + + Instances of this class are managed by framework and + are stored in :py:attr:`~chatsky.core.context.FrameworkData.slot_manager`. + They can be accessed via the ``ctx.framework_data.slot_manager.get_extracted_slot`` method. + """ + + @property + @abstractmethod + def __slot_extracted__(self) -> bool: + """Whether the slot is extracted.""" + raise NotImplementedError + + def __unset__(self): + """Mark slot as not extracted and clear extracted data (except for default value).""" + raise NotImplementedError + + @abstractmethod + def __str__(self): + """String representation is used to fill templates.""" + raise NotImplementedError + + +class BaseSlot(BaseModel, frozen=True): + """ + BaseSlot is a base class for all slots. + """ + + @abstractmethod + async def get_value(self, ctx: Context) -> ExtractedSlot: + """ + Extract slot value from :py:class:`~.Context` and return an instance of :py:class:`~.ExtractedSlot`. 
+ """ + raise NotImplementedError + + @abstractmethod + def init_value(self) -> ExtractedSlot: + """ + Provide an initial value to fill slot storage with. + """ + raise NotImplementedError + + +class ExtractedValueSlot(ExtractedSlot): + """Value extracted from :py:class:`~.ValueSlot`.""" + + is_slot_extracted: bool + extracted_value: Any + default_value: Any = None + + @field_serializer("extracted_value", "default_value", when_used="json") + def pickle_serialize_values(self, value): + """ + Cast values to string via pickle. + Allows storing arbitrary data in these fields when using context storages. + """ + if value is not None: + return pickle_serializer(value) + return value + + @field_validator("extracted_value", "default_value", mode="before") + @classmethod + def pickle_validate_values(cls, value): + """ + Restore values after being processed with + :py:meth:`pickle_serialize_values`. + """ + if value is not None: + return pickle_validator(value) + return value + + @property + def __slot_extracted__(self) -> bool: + return self.is_slot_extracted + + def __unset__(self): + self.is_slot_extracted = False + self.extracted_value = SlotNotExtracted("Slot manually unset.") + + @property + def value(self): + """Extracted value or the default value if the slot is not extracted.""" + return self.extracted_value if self.is_slot_extracted else self.default_value + + def __str__(self): + return str(self.value) + + +class ExtractedGroupSlot(ExtractedSlot, extra="allow"): + __pydantic_extra__: Dict[ + str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] + ] + string_format: Optional[str] = None + + @property + def __slot_extracted__(self) -> bool: + return all([slot.__slot_extracted__ for slot in self.__pydantic_extra__.values()]) + + def __unset__(self): + for child in self.__pydantic_extra__.values(): + child.__unset__() + + # fill template here + def __str__(self): + if self.string_format is not None: + return 
KwargOnlyFormatter().format(self.string_format, **self.__pydantic_extra__) + else: + return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) + + def update(self, old: "ExtractedGroupSlot"): + """ + Rebase this extracted groups slot on top of another one. + This is required to merge slot storage in-context + with a potentially different slot configuration passed to pipeline. + + :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. + Extracted values will be transferred to this object. + """ + for slot in old.__pydantic_extra__: + if slot in self.__pydantic_extra__: + new_slot = self.__pydantic_extra__[slot] + old_slot = old.__pydantic_extra__[slot] + if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): + new_slot.update(old_slot) + if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): + self.__pydantic_extra__[slot] = old_slot + + +class ValueSlot(BaseSlot, frozen=True): + """ + Value slot is a base class for all slots that are designed to extract concrete values. + Subclass it, if you want to declare your own slot type. + """ + + default_value: Any = None + + @abstractmethod + async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + """ + Return value extracted from context. + + Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. + + Raising exceptions is also allowed and will result in an unsuccessful extraction as well. 
+ """ + raise NotImplementedError + + async def get_value(self, ctx: Context) -> ExtractedValueSlot: + """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" + extracted_value = SlotNotExtracted("Caught an exit exception.") + is_slot_extracted = False + + try: + extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) + is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) + except Exception as error: + logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) + extracted_value = error + finally: + if not is_slot_extracted: + logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") + return ExtractedValueSlot.model_construct( + is_slot_extracted=is_slot_extracted, + extracted_value=extracted_value, + default_value=self.default_value, + ) + + def init_value(self) -> ExtractedValueSlot: + return ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted("Initial slot extraction."), + default_value=self.default_value, + ) + + +class GroupSlot(BaseSlot, extra="allow", frozen=True): + """ + Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. + """ + + __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + string_format: Optional[str] = None + """Makes the str() representation formatted with slot values being the keywords.""" + allow_partial_extraction: bool = False + """If True, extraction returns only successfully extracted child slots.""" + + def __init__(self, allow_partial_extraction=False, **kwargs): + super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs) + + @model_validator(mode="before") + @classmethod + def __check_reserved_slot_names__(cls, data: Any): + """ + Check that reserved names are used correctly and not taken by the user's slots. 
+ """ + if isinstance(data, dict): + reserved_names = ["string_format", "allow_partial_extraction"] + for name in reserved_names: + if data.get(name) is not None: + if isinstance(data.get(name), BaseSlot): + raise TypeError(f"Slots cannot be named '{name}', it's a reserved name.") + return data + + @model_validator(mode="after") + def __check_extra_field_names__(self): + """ + Extra field names cannot be dunder names or contain dots. + """ + if self.__pydantic_extra__ is not None: + for field in self.__pydantic_extra__.keys(): + if "." in field: + raise ValueError(f"Extra field name cannot contain dots: {field!r}") + if field.startswith("__") and field.endswith("__"): + raise ValueError(f"Extra field names cannot be dunder: {field!r}") + return self + + def _flatten_group_slot(self, slot, parent_key="") -> dict: + """ + Unpacks a GroupSlot's nested `Slots` into a single dictionary. + + Intended as a helper method for making GroupSlot child classes. + """ + items = {} + for key, value in slot.__pydantic_extra__.items(): + new_key = f"{parent_key}.{key}" if parent_key else key + if isinstance(value, GroupSlot): + items.update(self._flatten_group_slot(value, new_key)) + else: + items[new_key] = value + return items + + async def get_value(self, ctx: Context) -> ExtractedGroupSlot: + child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) + extracted_values = {} + for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()): + if child_value.__slot_extracted__ or not self.allow_partial_extraction: + extracted_values[child_name] = child_value + + return ExtractedGroupSlot(string_format=self.string_format, **extracted_values) + + def init_value(self) -> ExtractedGroupSlot: + return ExtractedGroupSlot( + string_format=self.string_format, + **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()}, + ) diff --git a/chatsky/slots/slot_manager.py 
b/chatsky/slots/slot_manager.py new file mode 100644 index 000000000..633c9ed00 --- /dev/null +++ b/chatsky/slots/slot_manager.py @@ -0,0 +1,204 @@ +""" +Slot Manager +------------ +This module defines the SlotManager class, facilitating slot management for Pipeline. +""" + +from __future__ import annotations + +from typing_extensions import TypeAlias +from typing import Union, TYPE_CHECKING, Optional +import logging + +from functools import reduce + +from pydantic import BaseModel, Field + +from chatsky.slots.base_slots import ( + ExtractedSlot, + BaseSlot, + ExtractedValueSlot, + ExtractedGroupSlot, + GroupSlot, + KwargOnlyFormatter, +) + +if TYPE_CHECKING: + from chatsky.core import Context + +logger = logging.getLogger(__name__) + +SlotName: TypeAlias = str +""" +A string to identify slots. + +Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. + +E.g. + +.. code:: python + + GroupSlot( + user=RegexpSlot(), + password=FunctionSlot, + ) + +Has two slots with names "user" and "password". + +For nested group slots use dots to separate names: + +.. code:: python + + GroupSlot( + user=GroupSlot( + name=FunctionSlot, + password=FunctionSlot, + ) + ) + +Has two slots with names "user.name" and "user.password". +""" + + +def recursive_getattr(obj, slot_name: SlotName): + def two_arg_getattr(__o, name): + # pydantic handles exception when accessing a non-existing extra-field on its own + # return None by default to avoid that + return getattr(__o, name, None) + + return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) + + +def recursive_setattr(obj, slot_name: SlotName, value): + parent_slot, sep, slot = slot_name.rpartition(".") + + if sep == ".": + parent_obj = recursive_getattr(obj, parent_slot) + else: + parent_obj = obj + + if isinstance(value, ExtractedGroupSlot): + getattr(parent_obj, slot).update(value) + else: + setattr(parent_obj, slot, value) + + +class SlotManager(BaseModel): + """ + Provides API for managing slots. 
+ + An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. + """ + + slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) + """Slot storage. Stored inside ctx.framework_data.""" + root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) + """Slot configuration passed during pipeline initialization.""" + + def set_root_slot(self, root_slot: GroupSlot): + """ + Set root_slot configuration from pipeline. + Update extracted slots with the new configuration: + + New slots are added with their :py:meth:`~.BaseSlot.init_value`. + Old extracted slot values are preserved only if their configuration did not change. + That is if they are still present in the config and if their fundamental type did not change + (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). + + This method is called by pipeline and is not supposed to be used otherwise. + """ + self.root_slot = root_slot + new_slot_storage = root_slot.init_value() + new_slot_storage.update(self.slot_storage) + self.slot_storage = new_slot_storage + + def get_slot(self, slot_name: SlotName) -> BaseSlot: + """ + Get slot configuration from the slot name. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.root_slot, slot_name) + if isinstance(slot, BaseSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + async def extract_slot(self, slot_name: SlotName, ctx: Context, save_on_failure: bool) -> None: + """ + Extract slot `slot_name` and store extracted value in `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + + :param slot_name: Name of the slot to extract. + :param ctx: Context. + :param save_on_failure: Whether to store the value only if it is successfully extracted. 
+ """ + slot = self.get_slot(slot_name) + value = await slot.get_value(ctx) + + if value.__slot_extracted__ or save_on_failure is False: + recursive_setattr(self.slot_storage, slot_name, value) + + async def extract_all(self, ctx: Context): + """ + Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. + """ + self.slot_storage = await self.root_slot.get_value(ctx) + + def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: + """ + Retrieve extracted value from `slot_storage`. + + :raises KeyError: If the slot with the specified name does not exist. + """ + slot = recursive_getattr(self.slot_storage, slot_name) + if isinstance(slot, ExtractedSlot): + return slot + raise KeyError(f"Could not find slot {slot_name!r}.") + + def is_slot_extracted(self, slot_name: str) -> bool: + """ + Return if the specified slot is extracted. + + :raises KeyError: If the slot with the specified name does not exist. + """ + return self.get_extracted_slot(slot_name).__slot_extracted__ + + def all_slots_extracted(self) -> bool: + """ + Return if all slots are extracted. + """ + return self.slot_storage.__slot_extracted__ + + def unset_slot(self, slot_name: SlotName) -> None: + """ + Mark specified slot as not extracted and clear extracted value. + + :raises KeyError: If the slot with the specified name does not exist. + """ + self.get_extracted_slot(slot_name).__unset__() + + def unset_all_slots(self) -> None: + """ + Mark all slots as not extracted and clear all extracted values. + """ + self.slot_storage.__unset__() + + def fill_template(self, template: str) -> Optional[str]: + """ + Fill `template` string with extracted slot values and return a formatted string + or None if an exception has occurred while trying to fill template. + + `template` should be a format-string: + + E.g. "Your username is {profile.username}". 
+ + For the example above, if ``profile.username`` slot has value "admin", + it would return the following text: + "Your username is admin". + """ + try: + return KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) + except Exception as exc: + logger.exception("An exception occurred during template filling.", exc_info=exc) + return None diff --git a/chatsky/slots/slots.py b/chatsky/slots/slots.py deleted file mode 100644 index 3cadd9205..000000000 --- a/chatsky/slots/slots.py +++ /dev/null @@ -1,461 +0,0 @@ -""" -Slots ------ -This module defines base classes for slots and some concrete implementations of them. -""" - -from __future__ import annotations - -import asyncio -import re -from abc import ABC, abstractmethod -from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Optional, Dict -from typing_extensions import TypeAlias, Annotated -import logging -from functools import reduce -from string import Formatter - -from pydantic import BaseModel, model_validator, Field, field_serializer, field_validator - -from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async -from chatsky.utils.devel.json_serialization import pickle_serializer, pickle_validator - -if TYPE_CHECKING: - from chatsky.core import Context, Message - - -logger = logging.getLogger(__name__) - - -SlotName: TypeAlias = str -""" -A string to identify slots. - -Top-level slots are identified by their key in a :py:class:`~.GroupSlot`. - -E.g. - -.. code:: python - - GroupSlot( - user=RegexpSlot(), - password=FunctionSlot, - ) - -Has two slots with names "user" and "password". - -For nested group slots use dots to separate names: - -.. code:: python - - GroupSlot( - user=GroupSlot( - name=FunctionSlot, - password=FunctionSlot, - ) - ) - -Has two slots with names "user.name" and "user.password". 
-""" - - -def recursive_getattr(obj, slot_name: SlotName): - def two_arg_getattr(__o, name): - # pydantic handles exception when accessing a non-existing extra-field on its own - # return None by default to avoid that - return getattr(__o, name, None) - - return reduce(two_arg_getattr, [obj, *slot_name.split(".")]) - - -def recursive_setattr(obj, slot_name: SlotName, value): - parent_slot, sep, slot = slot_name.rpartition(".") - - if sep == ".": - parent_obj = recursive_getattr(obj, parent_slot) - else: - parent_obj = obj - - if isinstance(value, ExtractedGroupSlot): - getattr(parent_obj, slot).update(value) - else: - setattr(parent_obj, slot, value) - - -class SlotNotExtracted(Exception): - """This exception can be returned or raised by slot extractor if slot extraction is unsuccessful.""" - - pass - - -class ExtractedSlot(BaseModel, ABC): - """ - Represents value of an extracted slot. - - Instances of this class are managed by framework and - are stored in :py:attr:`~chatsky.core.context.FrameworkData.slot_manager`. - They can be accessed via the ``ctx.framework_data.slot_manager.get_extracted_slot`` method. - """ - - @property - @abstractmethod - def __slot_extracted__(self) -> bool: - """Whether the slot is extracted.""" - raise NotImplementedError - - def __unset__(self): - """Mark slot as not extracted and clear extracted data (except for default value).""" - raise NotImplementedError - - @abstractmethod - def __str__(self): - """String representation is used to fill templates.""" - raise NotImplementedError - - -class ExtractedValueSlot(ExtractedSlot): - """Value extracted from :py:class:`~.ValueSlot`.""" - - is_slot_extracted: bool - extracted_value: Any - default_value: Any = None - - @field_serializer("extracted_value", "default_value", when_used="json") - def pickle_serialize_values(self, value): - """ - Cast values to string via pickle. - Allows storing arbitrary data in these fields when using context storages. 
- """ - if value is not None: - return pickle_serializer(value) - return value - - @field_validator("extracted_value", "default_value", mode="before") - @classmethod - def pickle_validate_values(cls, value): - """ - Restore values after being processed with - :py:meth:`pickle_serialize_values`. - """ - if value is not None: - return pickle_validator(value) - return value - - @property - def __slot_extracted__(self) -> bool: - return self.is_slot_extracted - - def __unset__(self): - self.is_slot_extracted = False - self.extracted_value = SlotNotExtracted("Slot manually unset.") - - @property - def value(self): - """Extracted value or the default value if the slot is not extracted.""" - return self.extracted_value if self.is_slot_extracted else self.default_value - - def __str__(self): - return str(self.value) - - -class ExtractedGroupSlot(ExtractedSlot, extra="allow"): - __pydantic_extra__: Dict[ - str, Annotated[Union["ExtractedGroupSlot", "ExtractedValueSlot"], Field(union_mode="left_to_right")] - ] - - @property - def __slot_extracted__(self) -> bool: - return all([slot.__slot_extracted__ for slot in self.__pydantic_extra__.values()]) - - def __unset__(self): - for child in self.__pydantic_extra__.values(): - child.__unset__() - - def __str__(self): - return str({key: str(value) for key, value in self.__pydantic_extra__.items()}) - - def update(self, old: "ExtractedGroupSlot"): - """ - Rebase this extracted groups slot on top of another one. - This is required to merge slot storage in-context - with a potentially different slot configuration passed to pipeline. - - :param old: An instance of :py:class:`~.ExtractedGroupSlot` stored in-context. - Extracted values will be transferred to this object. 
- """ - for slot in old.__pydantic_extra__: - if slot in self.__pydantic_extra__: - new_slot = self.__pydantic_extra__[slot] - old_slot = old.__pydantic_extra__[slot] - if isinstance(new_slot, ExtractedGroupSlot) and isinstance(old_slot, ExtractedGroupSlot): - new_slot.update(old_slot) - if isinstance(new_slot, ExtractedValueSlot) and isinstance(old_slot, ExtractedValueSlot): - self.__pydantic_extra__[slot] = old_slot - - -class BaseSlot(BaseModel, frozen=True): - """ - BaseSlot is a base class for all slots. - """ - - @abstractmethod - async def get_value(self, ctx: Context) -> ExtractedSlot: - """ - Extract slot value from :py:class:`~.Context` and return an instance of :py:class:`~.ExtractedSlot`. - """ - raise NotImplementedError - - @abstractmethod - def init_value(self) -> ExtractedSlot: - """ - Provide an initial value to fill slot storage with. - """ - raise NotImplementedError - - -class ValueSlot(BaseSlot, frozen=True): - """ - Value slot is a base class for all slots that are designed to extract concrete values. - Subclass it, if you want to declare your own slot type. - """ - - default_value: Any = None - - @abstractmethod - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: - """ - Return value extracted from context. - - Return :py:exc:`~.SlotNotExtracted` to mark extraction as unsuccessful. - - Raising exceptions is also allowed and will result in an unsuccessful extraction as well. 
- """ - raise NotImplementedError - - async def get_value(self, ctx: Context) -> ExtractedValueSlot: - """Wrapper for :py:meth:`~.ValueSlot.extract_value` to handle exceptions.""" - extracted_value = SlotNotExtracted("Caught an exit exception.") - is_slot_extracted = False - - try: - extracted_value = await wrap_sync_function_in_async(self.extract_value, ctx) - is_slot_extracted = not isinstance(extracted_value, SlotNotExtracted) - except Exception as error: - logger.exception(f"Exception occurred during {self.__class__.__name__!r} extraction.", exc_info=error) - extracted_value = error - finally: - if not is_slot_extracted: - logger.debug(f"Slot {self.__class__.__name__!r} was not extracted: {extracted_value}") - return ExtractedValueSlot.model_construct( - is_slot_extracted=is_slot_extracted, - extracted_value=extracted_value, - default_value=self.default_value, - ) - - def init_value(self) -> ExtractedValueSlot: - return ExtractedValueSlot.model_construct( - is_slot_extracted=False, - extracted_value=SlotNotExtracted("Initial slot extraction."), - default_value=self.default_value, - ) - - -class GroupSlot(BaseSlot, extra="allow", frozen=True): - """ - Base class for :py:class:`~.RootSlot` and :py:class:`~.GroupSlot`. - """ - - __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] - allow_partial_extraction: bool = False - """If True, extraction returns only successfully extracted child slots.""" - - def __init__(self, allow_partial_extraction=False, **kwargs): - super().__init__(allow_partial_extraction=allow_partial_extraction, **kwargs) - - @model_validator(mode="after") - def __check_extra_field_names__(self): - """ - Extra field names cannot be dunder names or contain dots. - """ - for field in self.__pydantic_extra__.keys(): - if "." 
in field: - raise ValueError(f"Extra field name cannot contain dots: {field!r}") - if field.startswith("__") and field.endswith("__"): - raise ValueError(f"Extra field names cannot be dunder: {field!r}") - return self - - async def get_value(self, ctx: Context) -> ExtractedGroupSlot: - child_values = await asyncio.gather(*(child.get_value(ctx) for child in self.__pydantic_extra__.values())) - extracted_values = {} - for child_value, child_name in zip(child_values, self.__pydantic_extra__.keys()): - if child_value.__slot_extracted__ or not self.allow_partial_extraction: - extracted_values[child_name] = child_value - - return ExtractedGroupSlot(**extracted_values) - - def init_value(self) -> ExtractedGroupSlot: - return ExtractedGroupSlot( - **{child_name: child.init_value() for child_name, child in self.__pydantic_extra__.items()} - ) - - -class RegexpSlot(ValueSlot, frozen=True): - """ - RegexpSlot is a slot type that extracts its value using a regular expression. - You can pass a compiled or a non-compiled pattern to the `regexp` argument. - If you want to extract a particular group, but not the full match, - change the `match_group_idx` parameter. - """ - - regexp: str - match_group_idx: int = 0 - "Index of the group to match." - - async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: - request_text = ctx.last_request.text - search = re.search(self.regexp, request_text) - return ( - search.group(self.match_group_idx) - if search - else SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") - ) - - -class FunctionSlot(ValueSlot, frozen=True): - """ - A simpler version of :py:class:`~.ValueSlot`. - - Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. 
- """ - - func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] - - async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: - return await wrap_sync_function_in_async(self.func, ctx.last_request) - - -class SlotManager(BaseModel): - """ - Provides API for managing slots. - - An instance of this class can be accessed via ``ctx.framework_data.slot_manager``. - """ - - slot_storage: ExtractedGroupSlot = Field(default_factory=ExtractedGroupSlot) - """Slot storage. Stored inside ctx.framework_data.""" - root_slot: GroupSlot = Field(default_factory=GroupSlot, exclude=True) - """Slot configuration passed during pipeline initialization.""" - - def set_root_slot(self, root_slot: GroupSlot): - """ - Set root_slot configuration from pipeline. - Update extracted slots with the new configuration: - - New slots are added with their :py:meth:`~.BaseSlot.init_value`. - Old extracted slot values are preserved only if their configuration did not change. - That is if they are still present in the config and if their fundamental type did not change - (i.e. `GroupSlot` did not turn into a `ValueSlot` or vice versa). - - This method is called by pipeline and is not supposed to be used otherwise. - """ - self.root_slot = root_slot - new_slot_storage = root_slot.init_value() - new_slot_storage.update(self.slot_storage) - self.slot_storage = new_slot_storage - - def get_slot(self, slot_name: SlotName) -> BaseSlot: - """ - Get slot configuration from the slot name. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.root_slot, slot_name) - if isinstance(slot, BaseSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - async def extract_slot(self, slot_name: SlotName, ctx: Context, success_only: bool) -> None: - """ - Extract slot `slot_name` and store extracted value in `slot_storage`. 
- - Extracted group slots update slot storage instead of overwriting it. - - :raises KeyError: If the slot with the specified name does not exist. - - :param slot_name: Name of the slot to extract. - :param ctx: Context. - :param success_only: Whether to store the value only if it is successfully extracted. - """ - slot = self.get_slot(slot_name) - value = await slot.get_value(ctx) - - if value.__slot_extracted__ or success_only is False: - recursive_setattr(self.slot_storage, slot_name, value) - - async def extract_all(self, ctx: Context): - """ - Extract all slots from slot configuration `root_slot` and set `slot_storage` to the extracted value. - """ - self.slot_storage = await self.root_slot.get_value(ctx) - - def get_extracted_slot(self, slot_name: SlotName) -> Union[ExtractedValueSlot, ExtractedGroupSlot]: - """ - Retrieve extracted value from `slot_storage`. - - :raises KeyError: If the slot with the specified name does not exist. - """ - slot = recursive_getattr(self.slot_storage, slot_name) - if isinstance(slot, ExtractedSlot): - return slot - raise KeyError(f"Could not find slot {slot_name!r}.") - - def is_slot_extracted(self, slot_name: str) -> bool: - """ - Return if the specified slot is extracted. - - :raises KeyError: If the slot with the specified name does not exist. - """ - return self.get_extracted_slot(slot_name).__slot_extracted__ - - def all_slots_extracted(self) -> bool: - """ - Return if all slots are extracted. - """ - return self.slot_storage.__slot_extracted__ - - def unset_slot(self, slot_name: SlotName) -> None: - """ - Mark specified slot as not extracted and clear extracted value. - - :raises KeyError: If the slot with the specified name does not exist. - """ - self.get_extracted_slot(slot_name).__unset__() - - def unset_all_slots(self) -> None: - """ - Mark all slots as not extracted and clear all extracted values. 
- """ - self.slot_storage.__unset__() - - class KwargOnlyFormatter(Formatter): - def get_value(self, key, args, kwargs): - return super().get_value(str(key), args, kwargs) - - def fill_template(self, template: str) -> Optional[str]: - """ - Fill `template` string with extracted slot values and return a formatted string - or None if an exception has occurred while trying to fill template. - - `template` should be a format-string: - - E.g. "Your username is {profile.username}". - - For the example above, if ``profile.username`` slot has value "admin", - it would return the following text: - "Your username is admin". - """ - try: - return self.KwargOnlyFormatter().format(template, **dict(self.slot_storage.__pydantic_extra__.items())) - except Exception as exc: - logger.exception("An exception occurred during template filling.", exc_info=exc) - return None diff --git a/chatsky/slots/standard_slots.py b/chatsky/slots/standard_slots.py new file mode 100644 index 000000000..15d3cd768 --- /dev/null +++ b/chatsky/slots/standard_slots.py @@ -0,0 +1,158 @@ +""" +Standard Slots +-------------- +This module defines some concrete implementations of slots. +""" + +from __future__ import annotations + +import re +from re import Pattern +from typing import Callable, Any, Awaitable, TYPE_CHECKING, Union, Dict, Optional +import logging + +from pydantic import Field, model_validator +from typing_extensions import Annotated + +from chatsky.utils.devel.async_helpers import wrap_sync_function_in_async +from chatsky.slots.base_slots import ( + SlotNotExtracted, + ExtractedValueSlot, + ExtractedGroupSlot, + ValueSlot, + GroupSlot, +) + +if TYPE_CHECKING: + from chatsky.core import Context, Message + + +logger = logging.getLogger(__name__) + + +class RegexpSlot(ValueSlot, frozen=True): + """ + RegexpSlot is a slot type that extracts its value using a regular expression. + You can pass a compiled or a non-compiled pattern to the `regexp` argument. 
+ If you want to extract a particular group, but not the full match, + change the `match_group_idx` parameter. + """ + + regexp: Union[str, Pattern] + "The regexp to search for in ctx.last_request.text" + match_group_idx: int = 0 + "Index of the group to match." + + async def extract_value(self, ctx: Context) -> Union[str, SlotNotExtracted]: + request_text = ctx.last_request.text + search = re.search(self.regexp, request_text) + return ( + search.group(self.match_group_idx) + if search + else SlotNotExtracted(f"Failed to match pattern {self.regexp!r} in {request_text!r}.") + ) + + +class RegexpGroupSlot(GroupSlot, extra="forbid", frozen=True): + """ + A slot type that reuses one re.search() call for several slots + to save on execution time in specific cases like LLM, where the amount of + get_value() calls is important. + """ + + # Parent fields repeated for Pydantic issues + __pydantic_extra__: Dict[str, Annotated[Union["GroupSlot", "ValueSlot"], Field(union_mode="left_to_right")]] + string_format: Optional[str] = None + allow_partial_extraction: bool = False + "Unlike in `GroupSlot` this field has no effect in this class." + + regexp: Pattern + "The regexp to search for in ctx.last_request.text" + groups: dict[str, int] + "A dictionary mapping slot names to capture group indexes (like `match_group_idx` in RegexpSlot)." + default_values: dict[str, Any] = Field(default_factory=dict) + "A dictionary with default values for each slot name in case the regexp search fails."
+ + @model_validator(mode="after") + def validate_groups(self): + for elem in self.groups.values(): + if elem > self.regexp.groups: + raise ValueError("Requested group number is too high, there aren't that many capture groups!") + if elem < 0: + raise ValueError("Requested capture group number cannot be negative.") + return self + + def __init__( + self, + regexp: Union[str, Pattern], + groups: dict[str, int], + default_values: Optional[dict[str, Any]] = None, + string_format: Optional[str] = None, + flags: int = 0, + ): + init_dict = { + "regexp": re.compile(regexp, flags), + "groups": groups, + "default_values": default_values, + "string_format": string_format, + } + empty_fields = set() + for k, v in init_dict.items(): + if k not in self.model_fields: + raise NotImplementedError("Init method contains a field not in model fields.") + if v is None: + empty_fields.add(k) + for field in empty_fields: + del init_dict[field] + super().__init__(**init_dict) + + async def get_value(self, ctx: Context) -> ExtractedGroupSlot: + request_text = ctx.last_request.text + search = re.search(self.regexp, request_text) + if search: + return ExtractedGroupSlot( + string_format=self.string_format, + **{ + child_name: ExtractedValueSlot.model_construct( + is_slot_extracted=True, + extracted_value=search.group(match_group), + ) + for child_name, match_group in self.groups.items() + }, + ) + else: + return ExtractedGroupSlot( + string_format=self.string_format, + **{ + child_name: ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + f"Failed to match pattern {self.regexp!r} in {request_text!r}."
+ ), + default_value=self.default_values.get(child_name), + ) + for child_name in self.groups + }, + ) + + def init_value(self) -> ExtractedGroupSlot: + return ExtractedGroupSlot( + string_format=self.string_format, + **{ + child_name: RegexpSlot(regexp=self.regexp, match_group_idx=match_group).init_value() + for child_name, match_group in self.groups.items() + }, + ) + + +class FunctionSlot(ValueSlot, frozen=True): + """ + A simpler version of :py:class:`~.ValueSlot`. + + Uses a user-defined `func` to extract slot value from the :py:attr:`~.Context.last_request` Message. + """ + + func: Callable[[Message], Union[Awaitable[Union[Any, SlotNotExtracted]], Any, SlotNotExtracted]] + + async def extract_value(self, ctx: Context) -> Union[Any, SlotNotExtracted]: + return await wrap_sync_function_in_async(self.func, ctx.last_request) diff --git a/tests/slots/conftest.py b/tests/slots/conftest.py index 9cdcc4eec..c19d34519 100644 --- a/tests/slots/conftest.py +++ b/tests/slots/conftest.py @@ -1,7 +1,7 @@ import pytest from chatsky.core import Message, TRANSITIONS, RESPONSE, Context, Pipeline, Transition as Tr -from chatsky.slots.slots import SlotNotExtracted +from chatsky.slots.base_slots import SlotNotExtracted @pytest.fixture(scope="function", autouse=True) diff --git a/tests/slots/test_slot_functions.py b/tests/slots/test_slot_functions.py index 83c6b0171..5961869bf 100644 --- a/tests/slots/test_slot_functions.py +++ b/tests/slots/test_slot_functions.py @@ -6,10 +6,11 @@ from chatsky import Context from chatsky.core import BaseResponse, Node from chatsky.core.message import MessageInitTypes, Message -from chatsky.slots.slots import ValueSlot, SlotNotExtracted, GroupSlot, SlotManager +from chatsky.slots.base_slots import ValueSlot, SlotNotExtracted, GroupSlot +from chatsky.slots.slot_manager import SlotManager from chatsky import conditions as cnd, responses as rsp, processing as proc from chatsky.processing.slots import logger as proc_logger -from
chatsky.slots.slots import logger as slot_logger +from chatsky.slots.base_slots import logger as slot_logger from chatsky.responses.slots import logger as rsp_logger @@ -59,15 +60,19 @@ async def test_basic_functions(context, manager, log_event_catcher): await proc.Extract("0", "2", "err").wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 + assert await cnd.SlotValueEquals("0", 5).wrapped_call(context) is False + assert await cnd.SlotValueEquals("0", 4).wrapped_call(context) is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, SlotNotExtracted) proc_logs = log_event_catcher(proc_logger, level=logging.ERROR) slot_logs = log_event_catcher(slot_logger, level=logging.ERROR) - await proc.Extract("0", "2", "err", success_only=False).wrapped_call(context) + await proc.Extract("0", "2", "err", save_on_failure=False).wrapped_call(context) assert manager.get_extracted_slot("0").value == 4 + assert await cnd.SlotValueEquals("0", 5).wrapped_call(context) is False + assert await cnd.SlotValueEquals("0", 4).wrapped_call(context) is True assert manager.is_slot_extracted("1") is False assert isinstance(manager.get_extracted_slot("err").extracted_value, RuntimeError) diff --git a/tests/slots/test_slot_manager.py b/tests/slots/test_slot_manager.py index 33885b360..4694be3d2 100644 --- a/tests/slots/test_slot_manager.py +++ b/tests/slots/test_slot_manager.py @@ -1,14 +1,13 @@ import pytest -from chatsky.slots.slots import ( - SlotManager, - RegexpSlot, +from chatsky.slots.base_slots import ( GroupSlot, - FunctionSlot, ExtractedGroupSlot, ExtractedValueSlot, SlotNotExtracted, ) +from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot +from chatsky.slots.slot_manager import SlotManager from chatsky.core import Message, Context @@ -174,7 +173,7 @@ def test_get_slot_by_name(empty_slot_manager): ], ) async def test_slot_extraction(slot_name, expected_slot_storage, empty_slot_manager, 
context_with_request): - await empty_slot_manager.extract_slot(slot_name, context_with_request, success_only=False) + await empty_slot_manager.extract_slot(slot_name, context_with_request, save_on_failure=False) assert empty_slot_manager.slot_storage == expected_slot_storage @@ -206,7 +205,7 @@ async def test_slot_extraction(slot_name, expected_slot_storage, empty_slot_mana ], ) async def test_successful_extraction(slot_name, expected_slot_storage, empty_slot_manager, context_with_request): - await empty_slot_manager.extract_slot(slot_name, context_with_request, success_only=True) + await empty_slot_manager.extract_slot(slot_name, context_with_request, save_on_failure=True) assert empty_slot_manager.slot_storage == expected_slot_storage diff --git a/tests/slots/test_slot_partial_extraction.py b/tests/slots/test_slot_partial_extraction.py index 234287c17..2d8c22a62 100644 --- a/tests/slots/test_slot_partial_extraction.py +++ b/tests/slots/test_slot_partial_extraction.py @@ -1,5 +1,5 @@ from chatsky.slots import RegexpSlot, GroupSlot -from chatsky.slots.slots import SlotManager +from chatsky.slots.slot_manager import SlotManager from chatsky.core import Message import pytest @@ -65,20 +65,20 @@ def get_extracted_slots(manager: SlotManager): [("1 2 3", ["1", "2"]), ("1 3 5", ["1", "5"]), ("3 4 5 6", ["3", "4", "5", "6"])], ) async def test_partial_extraction(message, extracted, context_with_request, empty_slot_manager): - await empty_slot_manager.extract_slot("root_slot", context_with_request(message), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request(message), save_on_failure=False) assert extracted == get_extracted_slots(empty_slot_manager) async def test_slot_storage_update(context_with_request, empty_slot_manager): - await empty_slot_manager.extract_slot("root_slot", context_with_request("1 3 5"), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request("1 3 5"), 
save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "5"] - await empty_slot_manager.extract_slot("root_slot", context_with_request("2 4 6"), success_only=False) + await empty_slot_manager.extract_slot("root_slot", context_with_request("2 4 6"), save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "2", "5", "6"] - await empty_slot_manager.extract_slot("root_slot.nested_group", context_with_request("3 4"), success_only=False) + await empty_slot_manager.extract_slot("root_slot.nested_group", context_with_request("3 4"), save_on_failure=False) assert get_extracted_slots(empty_slot_manager) == ["1", "2", "3", "4", "5", "6"] diff --git a/tests/slots/test_slot_types.py b/tests/slots/test_slot_types.py index a21cbd896..dbbddb951 100644 --- a/tests/slots/test_slot_types.py +++ b/tests/slots/test_slot_types.py @@ -1,15 +1,11 @@ +import re + import pytest from pydantic import ValidationError from chatsky.core import Message -from chatsky.slots.slots import ( - RegexpSlot, - GroupSlot, - FunctionSlot, - SlotNotExtracted, - ExtractedValueSlot, - ExtractedGroupSlot, -) +from chatsky.slots.base_slots import SlotNotExtracted, ExtractedValueSlot, GroupSlot, ExtractedGroupSlot +from chatsky.slots.standard_slots import RegexpSlot, FunctionSlot, RegexpGroupSlot @pytest.mark.parametrize( @@ -131,6 +127,96 @@ async def test_group_slot_extraction(user_request, slot, expected, is_extracted, assert result.__slot_extracted__ == is_extracted +@pytest.mark.parametrize( + ("user_request", "slot", "expected", "is_extracted"), + [ + ( + Message(text="I am Bot. 
I have a colleague, his name is Carl."), + RegexpGroupSlot( + regexp=r"am (.+?)\..*name is (.+?)\.", + groups={"name_1": 1, "name_2": 2}, + ), + ExtractedGroupSlot( + name_1=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Bot", default_value=None + ), + name_2=ExtractedValueSlot.model_construct( + is_slot_extracted=True, extracted_value="Carl", default_value=None + ), + ), + True, + ), + ( + Message(text="I am Bot. I won't tell you my email"), + RegexpGroupSlot( + regexp=r"am (.+?)\..*email is (.+?)\.", + groups={"name": 1, "email": 2}, + ), + ExtractedGroupSlot( + name=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=re.compile(r"am (.+?)\..*email is (.+?)\."), + request_text="I am Bot. I won't tell you my email", + ) + ), + default_value=None, + ), + email=ExtractedValueSlot.model_construct( + is_slot_extracted=False, + extracted_value=SlotNotExtracted( + "Failed to match pattern {regexp!r} in {request_text!r}.".format( + regexp=re.compile(r"am (.+?)\..*email is (.+?)\."), + request_text="I am Bot. I won't tell you my email", + ) + ), + default_value=None, + ), + ), + False, + ), + ], +) +async def test_regex_group_slot_extraction(user_request, slot, expected, is_extracted, context): + context.add_request(user_request) + result = await slot.get_value(context) + assert result == expected + assert result.__slot_extracted__ == is_extracted + + +@pytest.mark.parametrize( + ("user_request", "slot", "expected_str_format", "is_extracted"), + [ + ( + Message(text="I am Bot. My email is bot@bot"), + GroupSlot( + string_format="Your name is {name}. Your email is {email}.", + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + "Your name is Bot. Your email is bot@bot.", + True, + ), + ( + Message(text="I am Bot. 
I won't tell you my email"), + GroupSlot( + string_format="Your name is {name}. Your email is {email}.", + name=RegexpSlot(regexp=r"(?<=am ).+?(?=\.)"), + email=RegexpSlot(regexp=r"[a-zA-Z\.]+@[a-zA-Z\.]+"), + ), + "Your name is Bot. Your email is None.", + False, + ), + ], +) +async def test_group_slot_string_format(user_request, slot, expected_str_format, is_extracted, context): + context.add_request(user_request) + result = await slot.get_value(context) + assert str(result) == expected_str_format + assert result.__slot_extracted__ == is_extracted + + @pytest.mark.parametrize("forbidden_name", ["__dunder__", "contains.dot"]) def test_group_subslot_name_validation(forbidden_name): with pytest.raises(ValidationError): diff --git a/tutorials/slots/1_basic_example.py b/tutorials/slots/1_basic_example.py index dcddfbade..3573f6a5c 100644 --- a/tutorials/slots/1_basic_example.py +++ b/tutorials/slots/1_basic_example.py @@ -46,9 +46,9 @@ Currently there are two types of value slots: -- %mddoclink(api,slots.slots,RegexpSlot): +- %mddoclink(api,slots.standard_slots,RegexpSlot): Extracts slot values via regexp. -- %mddoclink(api,slots.slots,FunctionSlot): +- %mddoclink(api,slots.standard_slots,FunctionSlot): Extracts slot values with the help of a user-defined function. """ @@ -76,6 +76,9 @@ - %mddoclink(api,conditions.slots,SlotsExtracted): Condition for checking if specified slots are extracted. +- %mddoclink(api,conditions.slots,SlotValueEquals): + Condition for checking if the specified slots' value + is equal to a given value. - %mddoclink(api,processing.slots,Extract): A processing function that extracts specified slots. - %mddoclink(api,processing.slots,Unset): diff --git a/tutorials/slots/2_regexgroupslot_and_string_format.py b/tutorials/slots/2_regexgroupslot_and_string_format.py new file mode 100644 index 000000000..984bf62e2 --- /dev/null +++ b/tutorials/slots/2_regexgroupslot_and_string_format.py @@ -0,0 +1,158 @@ +# %% [markdown] +""" +# 2. 
`RegexpGroupSlot` and `string_format` + +The following tutorial shows basic usage of `RegexpGroupSlot` +and the `string_format` feature of `GroupSlot`. +""" + +# %pip install chatsky + +# %% +import re + +from chatsky import ( + RESPONSE, + TRANSITIONS, + PRE_TRANSITION, + GLOBAL, + Pipeline, + Transition as Tr, + conditions as cnd, + processing as proc, + responses as rsp, + destinations as dst, +) + +from chatsky.slots import RegexpSlot, RegexpGroupSlot + +from chatsky.utils.testing import ( + check_happy_path, + is_interactive_mode, +) + +# %% [markdown] +""" +## RegexpGroupSlot extraction + +The `RegexpGroupSlot` is a slot type that reuses one `re.search()` call for +several slots to save on execution time in specific cases like LLM, where the +amount of get_value() calls is important. + +## RegexpGroupSlot arguments + +* `regexp` - the regular expression to match with the `ctx.last_request.text`. +* `groups` - a dictionary mapping slot names to group indexes, where numbers + mean the index of the capture group that was found with `re.search()` + (like `match_group_idx` in `RegexpSlot`). +* `default_values` - a dictionary with default values for each slot name in + case the regexp search fails. + +The `RegexpGroupSlot` class is derived from `GroupSlot` class, inheriting +its `string_format` feature. + +## `string_format` usage + +You can set `string_format` to change the `__str__` representation of the +`ExtractedGroupSlot`. `string_format` can be set to a string, which will +be formatted with Python's `str.format()`, using extracted slot names and +their values as keyword arguments. + +The reason this exists at all is so you don't have to specify each and every +child slot name and can now represent a GroupSlot in a pre-determined way.
+ +Below are some examples of `RegexpGroupSlot` and `string_format` use: +""" + +# %% +SLOTS = { + "date": RegexpGroupSlot( + regexp=r"(0?[1-9]|(?:1|2)[0-9]|3[0-1])[\.\/]" + r"(0?[1-9]|1[0-2])[\.\/](\d{4}|\d{2})", + groups={"day": 1, "month": 2, "year": 3}, + string_format="{day}/{month}/{year}", + ), + "email": RegexpSlot( + regexp=r"[\w\.-]+@[\w\.-]+\.\w{2,4}", + ), +} + +script = { + GLOBAL: { + TRANSITIONS: [ + Tr( + dst=("date_and_email_flow", "ask_email"), + cnd=cnd.Regexp(r"^[sS]tart"), + ), + ] + }, + "date_and_email_flow": { + "start": { + TRANSITIONS: [Tr(dst=("date_and_email_flow", "ask_email"))], + }, + "fallback": { + RESPONSE: "Finishing query", + TRANSITIONS: [ + Tr( + dst=dst.Backward(), + cnd=cnd.Regexp(r"back", flags=re.IGNORECASE), + ), + Tr(dst=("date_and_email_flow", "ask_email"), priority=0.8), + ], + }, + "ask_email": { + RESPONSE: "Write your email (my email is ...):", + PRE_TRANSITION: {"get_slot": proc.Extract("email")}, + TRANSITIONS: [ + Tr( + dst="ask_date", + cnd=cnd.SlotsExtracted("email"), + ) + ], + }, + "ask_date": { + RESPONSE: "Write your date of birth:", + PRE_TRANSITION: {"get_slot": proc.Extract("date")}, + TRANSITIONS: [ + Tr( + dst="answer_node", + cnd=cnd.SlotsExtracted("date"), + ) + ], + }, + "answer_node": { + RESPONSE: rsp.FilledTemplate( + "Your date of birth is {date}, email is {email}" + ) + }, + }, +} + +# %% +HAPPY_PATH = [ + ("hi", "Write your email (my email is ...):"), + ("my email is groot@gmail.com", "Write your date of birth:"), + ( + "my date of birth is 06/10/1984", + "Your date of birth is 06/10/1984, email is groot@gmail.com", + ), + ("ok", "Finishing query"), + ("start", "Write your email (my email is ...):"), +] + +# %% +pipeline = Pipeline( + script=script, + start_label=("date_and_email_flow", "start"), + fallback_label=("date_and_email_flow", "fallback"), + slots=SLOTS, +) + +if __name__ == "__main__": + check_happy_path( + pipeline, HAPPY_PATH, printout=True + ) # This is a function for automatic
tutorial running + # (testing) with HAPPY_PATH + + if is_interactive_mode(): + pipeline.run() diff --git a/tutorials/slots/2_partial_extraction.py b/tutorials/slots/3_partial_extraction.py similarity index 70% rename from tutorials/slots/2_partial_extraction.py rename to tutorials/slots/3_partial_extraction.py index 87cc4bab4..7f4d7c4fa 100644 --- a/tutorials/slots/2_partial_extraction.py +++ b/tutorials/slots/3_partial_extraction.py @@ -1,6 +1,6 @@ # %% [markdown] """ -# 2. Partial slot extraction +# 3. Partial slot extraction This tutorial shows advanced options for slot extraction allowing to extract only some of the slots. @@ -41,23 +41,23 @@ ## Success only extraction -The `Extract` function accepts `success_only` flag which makes it so +The `Extract` function accepts `save_on_failure` flag which makes it so that extracted value is not saved unless extraction is successful. This means that unsuccessfully trying to extract a slot will not overwrite its previously extracted value. -Note that `success_only` is `True` by default. +Note that `save_on_failure` is `True` by default. ## Partial group slot extraction A group slot marked with `allow_partial_extraction` only saves values of successfully extracted child slots. Extracting such group slot is equivalent to extracting every child slot -with the `success_only` flag. +with the `save_on_failure` flag. Partially extracted group slot is always considered successfully extracted -for the purposes of the `success_only` flag. +for the purposes of the `save_on_failure` flag. ## Code explanation @@ -68,16 +68,17 @@ Any slot in this group is saved if and only if that slot is successfully extracted. -Group `success_only_extraction` is extracted with the `success_only` +Group `save_on_failure_extraction` is extracted with the `save_on_failure` flag set to True. Any slot in this group is saved if and only if all of the slots in the group are successfully extracted within a single `Extract` call. 
-Group `success_only_false` is extracted with the `success_only` set to False. +Group `save_on_failure_false` is extracted with the `save_on_failure` +flag set to False. Every slot in this group is saved (even if extraction was not successful). -Group `sub_slot_success_only_extraction` is extracted by passing all of its -child slots to the `Extract` method with the `success_only` flag set to True. +Group `sub_slot_save_on_failure_extraction` is extracted by passing all of its +child slots to the `Extract` method with the `save_on_failure` flag set to True. The behavior is equivalent to that of `partial_extraction`. """ @@ -97,13 +98,13 @@ **sub_slots, allow_partial_extraction=True, ), - "success_only_extraction": GroupSlot( + "save_on_failure_extraction": GroupSlot( **sub_slots, ), - "success_only_false": GroupSlot( + "save_on_failure_false": GroupSlot( **sub_slots, ), - "sub_slot_success_only_extraction": GroupSlot( + "sub_slot_save_on_failure_extraction": GroupSlot( **sub_slots, ), } @@ -126,30 +127,30 @@ PRE_RESPONSE: { "partial_extraction": proc.Extract("partial_extraction"), # partial extraction is always successful; - # success_only doesn't matter - "success_only_extraction": proc.Extract( - "success_only_extraction", success_only=True + # save_on_failure doesn't matter + "save_on_failure_extraction": proc.Extract( + "save_on_failure_extraction", save_on_failure=True ), - # success_only is True by default - "success_only_false": proc.Extract( - "success_only_false", success_only=False + # save_on_failure is True by default + "save_on_failure_false": proc.Extract( + "save_on_failure_false", save_on_failure=False ), - "sub_slot_success_only_extraction": proc.Extract( - "sub_slot_success_only_extraction.email", - "sub_slot_success_only_extraction.date", - success_only=True, + "sub_slot_save_on_failure_extraction": proc.Extract( + "sub_slot_save_on_failure_extraction.email", + "sub_slot_save_on_failure_extraction.date", + save_on_failure=True, ), }, RESPONSE:
rsp.FilledTemplate( "Extracted slots:\n" " Group with partial extraction:\n" " {partial_extraction}\n" - " Group with success_only:\n" - " {success_only_extraction}\n" - " Group without success_only:\n" - " {success_only_false}\n" - " Extracting sub-slots with success_only:\n" - " {sub_slot_success_only_extraction}" + " Group with save_on_failure:\n" + " {save_on_failure_extraction}\n" + " Group without save_on_failure:\n" + " {save_on_failure_false}\n" + " Extracting sub-slots with save_on_failure:\n" + " {sub_slot_save_on_failure_extraction}" ), }, }, @@ -162,11 +163,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': 'None', 'email': 'email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': 'None', 'email': 'email@email.com'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': 'None', 'email': 'email@email.com'}", ), ( @@ -174,11 +175,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '01.01.2024', 'email': 'email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '01.01.2024', 'email': 'None'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '01.01.2024', 'email': 'email@email.com'}", ), ( @@ -186,11 +187,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " 
Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}", ), ( @@ -198,11 +199,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': '03.01.2024', 'email': 'None'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}", ), ( @@ -210,11 +211,11 @@ "Extracted slots:\n" " Group with partial extraction:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}\n" - " Group with success_only:\n" + " Group with save_on_failure:\n" " {'date': '02.01.2024', 'email': 'another_email@email.com'}\n" - " Group without success_only:\n" + " Group without save_on_failure:\n" " {'date': 'None', 'email': 'None'}\n" - " Extracting sub-slots with success_only:\n" + " Extracting sub-slots with save_on_failure:\n" " {'date': '03.01.2024', 'email': 'another_email@email.com'}", ), ]