diff --git a/poetry.lock b/poetry.lock index 611d643..1b869e2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4805,6 +4805,72 @@ typing-extensions = "*" [package.extras] dev = ["black", "build", "flake8", "flake8-black", "isort", "jupyter-console", "mkdocs", "mkdocs-include-markdown-plugin", "mkdocstrings[python]", "pytest", "pytest-asyncio", "pytest-trio", "sphinx", "toml", "tox", "trio", "trio", "trio-typing", "twine", "twisted", "validate-pyproject[all]"] +[[package]] +name = "pygame" +version = "2.5.2" +description = "Python Game Development" +optional = false +python-versions = ">=3.6" +files = [ + {file = "pygame-2.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a0769eb628c818761755eb0a0ca8216b95270ea8cbcbc82227e39ac9644643da"}, + {file = "pygame-2.5.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed9a3d98adafa0805ccbaaff5d2996a2b5795381285d8437a4a5d248dbd12b4a"}, + {file = "pygame-2.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f30d1618672a55e8c6669281ba264464b3ab563158e40d89e8c8b3faa0febebd"}, + {file = "pygame-2.5.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:39690e9be9baf58b7359d1f3b2336e1fd6f92fedbbce42987be5df27f8d30718"}, + {file = "pygame-2.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03879ec299c9f4ba23901b2649a96b2143f0a5d787f0b6c39469989e2320caf1"}, + {file = "pygame-2.5.2-cp310-cp310-win32.whl", hash = "sha256:74e1d6284100e294f445832e6f6343be4fe4748decc4f8a51131ae197dae8584"}, + {file = "pygame-2.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:485239c7d32265fd35b76ae8f64f34b0637ae11e69d76de15710c4b9edcc7c8d"}, + {file = "pygame-2.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:34646ca20e163dc6f6cf8170f1e12a2e41726780112594ac061fa448cf7ccd75"}, + {file = "pygame-2.5.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3b8a6e351665ed26ea791f0e1fd649d3f483e8681892caef9d471f488f9ea5ee"}, + {file = "pygame-2.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc346965847aef00013fa2364f41a64f068cd096dcc7778fc306ca3735f0eedf"}, + {file = "pygame-2.5.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35632035fd81261f2d797fa810ea8c46111bd78ceb6089d52b61ed7dc3c5d05f"}, + {file = "pygame-2.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e24d05184e4195fe5ebcdce8b18ecb086f00182b9ae460a86682d312ce8d31f"}, + {file = "pygame-2.5.2-cp311-cp311-win32.whl", hash = "sha256:f02c1c7505af18d426d355ac9872bd5c916b27f7b0fe224749930662bea47a50"}, + {file = "pygame-2.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6d58c8cf937815d3b7cdc0fa9590c5129cb2c9658b72d00e8a4568dea2ff1d42"}, + {file = "pygame-2.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:1a2a43802bb5e89ce2b3b775744e78db4f9a201bf8d059b946c61722840ceea8"}, + {file = "pygame-2.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1c289f2613c44fe70a1e40769de4a49c5ab5a29b9376f1692bb1a15c9c1c9bfa"}, + {file = "pygame-2.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:074aa6c6e110c925f7f27f00c7733c6303407edc61d738882985091d1eb2ef17"}, + {file = "pygame-2.5.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe0228501ec616779a0b9c4299e837877783e18df294dd690b9ab0eed3d8aaab"}, + {file = "pygame-2.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31648d38ecdc2335ffc0e38fb18a84b3339730521505dac68514f83a1092e3f4"}, + {file = "pygame-2.5.2-cp312-cp312-win32.whl", hash = "sha256:224c308856334bc792f696e9278e50d099a87c116f7fc314cd6aa3ff99d21592"}, + {file = "pygame-2.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:dd2d2650faf54f9a0f5bd0db8409f79609319725f8f08af6507a0609deadcad4"}, + {file = "pygame-2.5.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9b30bc1220c457169571aac998e54b013aaeb732d2fd8744966cb1cfab1f61d1"}, + {file = "pygame-2.5.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78fcd7643358b886a44127ff7dec9041c056c212b3a98977674f83f99e9b12d3"}, + {file = "pygame-2.5.2-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35cf093a51cb294ede56c29d4acf41538c00f297fcf78a9b186fb7d23c0577b6"}, + {file = "pygame-2.5.2-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fe323acbf53a0195c8c98b1b941eba7ac24e3e2b28ae48e8cda566f15fc4945"}, + {file = "pygame-2.5.2-cp36-cp36m-win32.whl", hash = "sha256:5697528266b4716d9cdd44a5a1d210f4d86ef801d0f64ca5da5d0816704009d9"}, + {file = "pygame-2.5.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edda1f7cff4806a4fa39e0e8ccd75f38d1d340fa5fc52d8582ade87aca247d92"}, + {file = "pygame-2.5.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:9bd738fd4ecc224769d0b4a719f96900a86578e26e0105193658a32966df2aae"}, + {file = "pygame-2.5.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30a8d7cf12363b4140bf2f93b5eec4028376ca1d0fe4b550588f836279485308"}, + {file = "pygame-2.5.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bc12e4dea3e88ea8a553de6d56a37b704dbe2aed95105889f6afeb4b96e62097"}, + {file = "pygame-2.5.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b34c73cb328024f8db3cb6487a37e54000148988275d8d6e5adf99d9323c937"}, + {file = "pygame-2.5.2-cp37-cp37m-win32.whl", hash = "sha256:7d0a2794649defa57ef50b096a99f7113d3d0c2e32d1426cafa7d618eadce4c7"}, + {file = "pygame-2.5.2-cp37-cp37m-win_amd64.whl", hash = "sha256:41f8779f52e0f6e6e6ccb8f0b5536e432bf386ee29c721a1c22cada7767b0cef"}, + {file = "pygame-2.5.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:677e37bc0ea7afd89dde5a88ced4458aa8656159c70a576eea68b5622ee1997b"}, + {file = "pygame-2.5.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47a8415d2bd60e6909823b5643a1d4ef5cc29417d817f2a214b255f6fa3a1e4c"}, + {file = "pygame-2.5.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ff21201df6278b8ca2e948fb148ffe88f5481fd03760f381dd61e45954c7dff"}, + {file = "pygame-2.5.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d29a84b2e02814b9ba925357fd2e1df78efe5e1aa64dc3051eaed95d2b96eafd"}, + {file = "pygame-2.5.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d78485c4d21133d6b2fbb504cd544ca655e50b6eb551d2995b3aa6035928adda"}, + {file = "pygame-2.5.2-cp38-cp38-win32.whl", hash = "sha256:d851247239548aa357c4a6840fb67adc2d570ce7cb56988d036a723d26b48bff"}, + {file = "pygame-2.5.2-cp38-cp38-win_amd64.whl", hash = "sha256:88d1cdacc2d3471eceab98bf0c93c14d3a8461f93e58e3d926f20d4de3a75554"}, + {file = "pygame-2.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4f1559e7efe4efb9dc19d2d811d702f325d9605f9f6f9ececa39ee6890c798f5"}, + {file = "pygame-2.5.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cf2191b756ceb0e8458a761d0c665b0c70b538570449e0d39b75a5ba94ac5cf0"}, + {file = "pygame-2.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cf2257447ce7f2d6de37e5fb019d2bbe32ed05a5721ace8bc78c2d9beaf3aee"}, + {file = "pygame-2.5.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d75cbbfaba2b81434d62631d0b08b85fab16cf4a36e40b80298d3868927e1299"}, + {file = "pygame-2.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:daca456d5b9f52e088e06a127dec182b3638a775684fb2260f25d664351cf1ae"}, + {file = "pygame-2.5.2-cp39-cp39-win32.whl", hash = "sha256:3b3e619e33d11c297d7a57a82db40681f9c2c3ae1d5bf06003520b4fe30c435d"}, + {file = "pygame-2.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:1822d534bb7fe756804647b6da2c9ea5d7a62d8796b2e15d172d3be085de28c6"}, + {file = "pygame-2.5.2-pp36-pypy36_pp73-win32.whl", hash = "sha256:e708fc8f709a0fe1d1876489345f2e443d47f3976d33455e2e1e937f972f8677"}, + {file = "pygame-2.5.2-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c13edebc43c240fb0532969e914f0ccefff5ae7e50b0b788d08ad2c15ef793e4"}, + {file = "pygame-2.5.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:263b4a7cbfc9fe2055abc21b0251cc17dea6dff750f0e1c598919ff350cdbffe"}, + {file = "pygame-2.5.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:e58e2b0c791041e4bccafa5bd7650623ba1592b8fe62ae0a276b7d0ecb314b6c"}, + {file = "pygame-2.5.2-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0bd67426c02ffe6c9827fc4bcbda9442fbc451d29b17c83a3c088c56fef2c90"}, + {file = "pygame-2.5.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dcff6cbba1584cf7732ce1dbdd044406cd4f6e296d13bcb7fba963fb4aeefc9"}, + {file = "pygame-2.5.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ce4b6c0bfe44d00bb0998a6517bd0cf9455f642f30f91bc671ad41c05bf6f6ae"}, + {file = "pygame-2.5.2-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:68c4e8e60b725ffc7a6c6ecd9bb5fcc5ed2d6e0e2a2c4a29a8454856ef16ad63"}, + {file = "pygame-2.5.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f3849f97372a3381c66955f99a0d58485ccd513c3d00c030b869094ce6997a6"}, + {file = "pygame-2.5.2.tar.gz", hash = "sha256:c1b89eb5d539e7ac5cf75513125fb5f2f0a2d918b1fd6e981f23bf0ac1b1c24a"}, +] + [[package]] name = "pygments" version = "2.17.2" @@ -7301,4 +7367,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "bf741a7361c0684fc8911d8ae4115e1aa5aa1e149d9abcb827a83bbcaabb8e8e" +content-hash = "9d80ddefc39388a05bb2804c2df6d57c09834376f7357e50e2deedf9c2d059ae" diff --git a/pyproject.toml b/pyproject.toml index 44b6e82..bff020e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] # https://python-poetry.org/docs/pyproject/ name = "dspygen" -version = "2024.4.17" +version = "2024.4.26" description = "A Ruby on Rails style framework for the DSPy (Demonstrate, Search, Predict) project for Language Models like GPT, BERT, and LLama." authors = ["Sean Chatman "] readme = "README.md" @@ -59,6 +59,7 @@ chromadb = "^0.4.24" anyio = "^4.3.0" docutils = "0.21" transitions = "^0.9.0" +pygame = "^2.5.2" [tool.poetry.group.test.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/ coverage = { extras = ["toml"], version = ">=7.2.5" } diff --git a/src/dspygen/modules/gen_pydantic_instance_module.py b/src/dspygen/modules/gen_pydantic_instance_module.py new file mode 100644 index 0000000..8651e5a --- /dev/null +++ b/src/dspygen/modules/gen_pydantic_instance_module.py @@ -0,0 +1,151 @@ +import ast +import dspy +import inspect +import logging +from typing import Optional, TypeVar, cast + +from pydantic import BaseModel, ValidationError + +from dspy import Assert, ChainOfThought, InputField, OutputField, Signature + +logger = logging.getLogger(__name__) +logger.setLevel(logging.ERROR) + + +def eval_dict_str(dict_str: str) -> dict: + """Safely convert str to dict""" + return ast.literal_eval(dict_str) + + +class PromptToPydanticInstanceSignature(Signature): + """Synthesize the prompt into the kwargs to fit the model. + Do not duplicate the field descriptions + """ + + root_pydantic_model_class_name = InputField( + desc="The class name of the pydantic model to receive the kwargs" + ) + pydantic_model_definitions = InputField( + desc="Pydantic model class definitions as a string" + ) + prompt = InputField( + desc="The prompt to be synthesized into data. Do not duplicate descriptions" + ) + root_model_kwargs_dict = OutputField( + prefix="kwargs_dict: dict = ", + desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.", + ) + + +class PromptToPydanticInstanceErrorSignature(Signature): + """Synthesize the prompt into the kwargs fit the model""" + + error = InputField(desc="Error message to fix the kwargs") + + root_pydantic_model_class_name = InputField( + desc="The class name of the pydantic model to receive the kwargs" + ) + pydantic_model_definitions = InputField( + desc="Pydantic model class definitions as a string" + ) + prompt = InputField(desc="The prompt to be synthesized into data") + root_model_kwargs_dict = OutputField( + prefix="kwargs_dict = ", + desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.", + ) + + +T = TypeVar("T", bound=BaseModel) + + +class GenPydanticInstance(dspy.Module): + """A module for generating and validating Pydantic model instances based on prompts. + + Usage: + To use this module, instantiate the GenPydanticInstance class with the desired + root Pydantic model and optional child models. Then, call the `forward` method + with a prompt to generate Pydantic model instances based on the provided prompt. + """ + + def __init__( + self, + root_model: type[T], + child_models: Optional[list[type[BaseModel]]] = None, + generate_sig=PromptToPydanticInstanceSignature, + correct_generate_sig=PromptToPydanticInstanceErrorSignature, + ): + super().__init__() + + self.models = [root_model] # Always include root_model in models list + + if child_models: + self.models.extend(child_models) + + self.output_key = "root_model_kwargs_dict" + self.root_model = root_model + + # Concatenate source code of models for use in generation/correction logic + self.model_sources = "\n".join( + [inspect.getsource(model) for model in self.models] + ) + + # Initialize DSPy ChainOfThought modules for generation and correction + self.generate = ChainOfThought(generate_sig) + self.correct_generate = ChainOfThought(correct_generate_sig) + self.validation_error = None + + def validate_root_model(self, output: str) -> bool: + """Validates whether the generated output conforms to the root Pydantic model.""" + try: + model_inst = self.root_model.model_validate(eval_dict_str(output)) + return isinstance(model_inst, self.root_model) + except (ValidationError, ValueError, TypeError, SyntaxError) as error: + self.validation_error = error + logger.debug(f"Validation error: {error}") + return False + + def validate_output(self, output) -> T: + """Validates the generated output and returns an instance of the root Pydantic model if successful.""" + Assert( + self.validate_root_model(output), + f"""You need to create a kwargs dict for {self.root_model.__name__}\n + Validation error:\n{self.validation_error}""", + ) + + return self.root_model.model_validate(eval_dict_str(output)) + + def forward(self, prompt) -> T: + """Takes a prompt as input and generates a Python dictionary that represents an instance of the + root Pydantic model. It also handles error correction and validation. + """ + output = self.generate( + prompt=prompt, + root_pydantic_model_class_name=self.root_model.__name__, + pydantic_model_definitions=self.model_sources, + ) + + output = output[self.output_key] + + try: + return self.validate_output(output) + except (AssertionError, ValueError, TypeError) as error: + logger.error(f"Error {error!s}\nOutput:\n{output}") + + # Correction attempt + corrected_output = self.correct_generate( + prompt=prompt, + root_pydantic_model_class_name=self.root_model.__name__, + pydantic_model_definitions=self.model_sources, + error=f"str(error){self.validation_error}", + )[self.output_key] + + return self.validate_output(corrected_output) + + def __call__(self, prompt): + return self.forward(prompt=prompt) + + +def gen_pydantic_instance_call(prompt, root_model, child_models=None): + model_module = GenPydanticInstance(root_model=root_model, child_models=child_models) + model_inst = model_module(prompt) + return model_inst diff --git a/src/dspygen/utils/fsm_mixin.py b/src/dspygen/utils/fsm_mixin.py new file mode 100644 index 0000000..c6456c4 --- /dev/null +++ b/src/dspygen/utils/fsm_mixin.py @@ -0,0 +1,88 @@ +import inspect + +from transitions import Machine +import functools +from enum import Enum, auto + +from transitions.core import State + + +from transitions import Machine +import functools +from enum import Enum, auto + +import inspect +import functools +from transitions import Machine + +import inspect +import functools +from transitions import Machine + + +# A decorator that adds transition metadata to methods +def trigger(source, dest, conditions=None, unless=None, before=None, after=None, prepare=None): + def decorator(func): + if not hasattr(func, '_transitions'): + func._transitions = [] + func._transitions.append({ + 'trigger': func.__name__, + 'source': source, + 'dest': dest, + 'conditions': conditions or [], + 'unless': unless or [], + 'before': before, + 'after': after, + 'prepare': prepare + }) + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + # Execute any 'prepare' callbacks + if prepare: + [getattr(self, p)() for p in (prepare if isinstance(prepare, list) else [prepare])] + + # Check 'unless' conditions to prevent transition + if unless and any([getattr(self, u)() for u in (unless if isinstance(unless, list) else [unless])]): + return func(self, *args, **kwargs) + + # Check 'conditions' to allow transition + if conditions is None or all( + [getattr(self, c)() for c in (conditions if isinstance(conditions, list) else [conditions])]): + if before: + [getattr(self, b)() for b in (before if isinstance(before, list) else [before])] + + # Correctly trigger the transition through the state machine + event_trigger = getattr(self, 'trigger') + event_trigger(func.__name__) + + result = func(self, *args, **kwargs) # Execute the actual function logic + + if after: + [getattr(self, a)() for a in (after if isinstance(after, list) else [after])] + return result + + return func(self, *args, **kwargs) # Conditions not met, no transition + + return wrapper + + return decorator + + +class FSMMixin: + def setup_fsm(self, state_enum, initial): + self.states = [State(state.name) for state in state_enum] + self.machine = Machine(model=self, states=self.states, initial=initial, auto_transitions=False) + self.initialize_transitions() + + def initialize_transitions(self): + for name, method in inspect.getmembers(self, predicate=inspect.ismethod): + if hasattr(method, '_transitions'): + for trans in method._transitions: + self.machine.add_transition(**trans) + + +def state_transition_possibilities(fsm): + current_state = fsm.state + transitions = fsm.machine.get_transitions() + return [transition.dest for transition in transitions if transition.source == current_state] diff --git a/tests/utils/test_fsm_mixin.py b/tests/utils/test_fsm_mixin.py new file mode 100644 index 0000000..8a9763d --- /dev/null +++ b/tests/utils/test_fsm_mixin.py @@ -0,0 +1,78 @@ +from enum import Enum, auto + +from transitions import MachineError + +from dspygen.utils.fsm_mixin import FSMMixin, trigger, state_transition_possibilities + + +class LightState(Enum): + """ Enum for the states of a traffic light. """ + GREEN = auto() + YELLOW = auto() + RED = auto() + + +class TrafficLight(FSMMixin): + def __init__(self): + super().setup_fsm(LightState, LightState.GREEN) + + @trigger(source=LightState.GREEN, dest=LightState.YELLOW, before='log_transition') + def slow_down(self): + print("Light turned yellow!") + + @trigger(source=LightState.YELLOW, dest=LightState.RED, after='celebrate_red') + def stop(self): + print("Light turned red!") + + @trigger(source=LightState.RED, dest=LightState.GREEN) + def go(self): + print("Light turned green!") + + def log_transition(self): + print(f"Transition from {self.state} initiated.") + + def celebrate_red(self): + print("Red light celebration!") + + +def test_fsm(): + """Main function""" + tl = TrafficLight() + tl.slow_down() # Transition from GREEN to YELLOW + tl.stop() # Transition from YELLOW to RED + + assert tl.state == LightState.RED + + +def test_possibilities(): + """Main function""" + tl = TrafficLight() + tl.slow_down() # Transition from GREEN to YELLOW + + poss = state_transition_possibilities(tl) + + assert poss == ['RED'] + + +def test_fsm(): + """Main function""" + tl = TrafficLight() + try: + tl.slow_down() # Transition from GREEN to YELLOW + print("Successfully transitioned from GREEN to YELLOW.") + except MachineError as e: + print(f"Failed to transition from GREEN to YELLOW: {e}") + + try: + tl.stop() # Transition from YELLOW to RED + print("Successfully transitioned from YELLOW to RED.") + except MachineError as e: + print(f"Failed to transition from YELLOW to RED: {e}") + + try: + tl.slow_down() # Attempt transition from RED to YELLOW (not defined, should raise error) + print("Successfully transitioned from RED to YELLOW.") + except MachineError as e: + print(f"Failed to transition from RED to YELLOW: {e}") + + print("Final state:", tl.state) diff --git a/tests/utils/test_fsm_superhero.py b/tests/utils/test_fsm_superhero.py new file mode 100644 index 0000000..7578c47 --- /dev/null +++ b/tests/utils/test_fsm_superhero.py @@ -0,0 +1,76 @@ +from enum import Enum, auto +import random +from dspygen.utils.fsm_mixin import FSMMixin, trigger + + +class SuperheroState(Enum): + ASLEEP = auto() + HANGING_OUT = auto() + HUNGRY = auto() + SWEATY = auto() + SAVING_THE_WORLD = auto() + + +class NarcolepticSuperhero(FSMMixin): + def __init__(self, name): + self.name = name + self.kittens_rescued = 0 + super().setup_fsm(SuperheroState, SuperheroState.ASLEEP) + + @trigger(source=SuperheroState.ASLEEP, dest=SuperheroState.HANGING_OUT) + def wake_up(self): + print(f"{self.name} woke up and is ready to hang out.") + + @trigger(source=SuperheroState.HANGING_OUT, dest=SuperheroState.HUNGRY) + def work_out(self): + print(f"{self.name} is now hungry after working out.") + + @trigger(source=SuperheroState.HUNGRY, dest=SuperheroState.HANGING_OUT) + def eat(self): + print(f"{self.name} is hanging out after eating.") + + @trigger(source='*', dest=SuperheroState.SAVING_THE_WORLD, before='change_into_super_secret_costume') + def distress_call(self): + print(f"{self.name} is off to save the world.") + + @trigger(source=SuperheroState.SAVING_THE_WORLD, dest=SuperheroState.SWEATY, after='update_journal') + def complete_mission(self): + print(f"{self.name} has completed the mission and is now sweaty.") + + @trigger(source=SuperheroState.SWEATY, dest=SuperheroState.ASLEEP, conditions=['is_exhausted']) + def clean_up_exhausted(self): + print(f"{self.name} is too exhausted and going back to sleep.") + + @trigger(source=SuperheroState.SWEATY, dest=SuperheroState.HANGING_OUT) + def clean_up(self): + print(f"{self.name} cleaned up and is hanging out again.") + + @trigger(source='*', dest=SuperheroState.ASLEEP) + def nap(self): + print(f"{self.name} has taken a nap.") + + def update_journal(self): + self.kittens_rescued += 1 + print(f"Updated journal: {self.kittens_rescued} kittens rescued.") + + def is_exhausted(self): + return random.random() < 0.5 + + def change_into_super_secret_costume(self): + print(f"{self.name} is changing into their super-secret costume.") + + +def test_narcoleptic_superhero(): + hero = NarcolepticSuperhero("SleepyMan") + hero.wake_up() + hero.work_out() + hero.eat() + hero.distress_call() + hero.complete_mission() + if hero.is_exhausted(): + hero.clean_up_exhausted() + else: + hero.clean_up() + hero.nap() + + assert hero.state == SuperheroState.ASLEEP.name diff --git a/tests/utils/test_fsm_teacher.py b/tests/utils/test_fsm_teacher.py new file mode 100644 index 0000000..df90c12 --- /dev/null +++ b/tests/utils/test_fsm_teacher.py @@ -0,0 +1,50 @@ +import pytest +from transitions import MachineError + +from enum import Enum, auto +from dspygen.utils.fsm_mixin import FSMMixin, trigger + + +class InteractionState(Enum): + """ Enum for states of the chatbot interaction. """ + ASKING_QUESTION = auto() + AWAITING_ANSWER = auto() + EVALUATING_ANSWER = auto() + PROVIDING_FEEDBACK = auto() + + +class TeacherChatbot(FSMMixin): + def __init__(self): + super().setup_fsm(InteractionState, InteractionState.ASKING_QUESTION) + + @trigger(source=InteractionState.ASKING_QUESTION, dest=InteractionState.AWAITING_ANSWER) + def question_asked(self): + print("Question asked: What is the capital of France?") + + @trigger(source=InteractionState.AWAITING_ANSWER, dest=InteractionState.EVALUATING_ANSWER) + def answer_received(self, answer): + self.answer = answer + print(f"Answer received: {answer}") + + @trigger(source=InteractionState.EVALUATING_ANSWER, dest=InteractionState.PROVIDING_FEEDBACK) + def answer_evaluated(self, is_correct): + self.is_correct = is_correct + feedback = "Correct!" if is_correct else "That's not right, try again!" + print(feedback) + + @trigger(source=InteractionState.PROVIDING_FEEDBACK, dest=InteractionState.ASKING_QUESTION) + def feedback_given(self): + print("Ready for the next question!") + + +def test_chatbot(): + """ Test function to simulate the chatbot interaction. """ + chatbot = TeacherChatbot() + chatbot.question_asked() # Transition to AWAITING_ANSWER + chatbot.answer_received("Paris") # Transition to EVALUATING_ANSWER + chatbot.answer_evaluated(True) # Transition to PROVIDING_FEEDBACK + chatbot.feedback_given() # Transition back to ASKING_QUESTION + + assert chatbot.state == InteractionState.ASKING_QUESTION.name + + print("Final state:", chatbot.state)