Skip to content

Commit

Permalink
FSM Mixin with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seanchatmangpt committed Apr 26, 2024
1 parent 9da2529 commit 085ca43
Show file tree
Hide file tree
Showing 7 changed files with 512 additions and 2 deletions.
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry] # https://python-poetry.org/docs/pyproject/
name = "dspygen"
version = "2024.4.17"
version = "2024.4.26"
description = "A Ruby on Rails style framework for the DSPy (Demonstrate, Search, Predict) project for Language Models like GPT, BERT, and LLama."
authors = ["Sean Chatman <info@chatmangpt.com>"]
readme = "README.md"
Expand Down Expand Up @@ -59,6 +59,7 @@ chromadb = "^0.4.24"
anyio = "^4.3.0"
docutils = "0.21"
transitions = "^0.9.0"
pygame = "^2.5.2"

[tool.poetry.group.test.dependencies] # https://python-poetry.org/docs/master/managing-dependencies/
coverage = { extras = ["toml"], version = ">=7.2.5" }
Expand Down
151 changes: 151 additions & 0 deletions src/dspygen/modules/gen_pydantic_instance_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import ast
import dspy
import inspect
import logging
from typing import Optional, TypeVar, cast

from pydantic import BaseModel, ValidationError

from dspy import Assert, ChainOfThought, InputField, OutputField, Signature

logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)


def eval_dict_str(dict_str: str) -> dict:
"""Safely convert str to dict"""
return ast.literal_eval(dict_str)


class PromptToPydanticInstanceSignature(Signature):
"""Synthesize the prompt into the kwargs to fit the model.
Do not duplicate the field descriptions
"""

root_pydantic_model_class_name = InputField(
desc="The class name of the pydantic model to receive the kwargs"
)
pydantic_model_definitions = InputField(
desc="Pydantic model class definitions as a string"
)
prompt = InputField(
desc="The prompt to be synthesized into data. Do not duplicate descriptions"
)
root_model_kwargs_dict = OutputField(
prefix="kwargs_dict: dict = ",
desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
)


class PromptToPydanticInstanceErrorSignature(Signature):
"""Synthesize the prompt into the kwargs fit the model"""

error = InputField(desc="Error message to fix the kwargs")

root_pydantic_model_class_name = InputField(
desc="The class name of the pydantic model to receive the kwargs"
)
pydantic_model_definitions = InputField(
desc="Pydantic model class definitions as a string"
)
prompt = InputField(desc="The prompt to be synthesized into data")
root_model_kwargs_dict = OutputField(
prefix="kwargs_dict = ",
desc="Generate a Python dictionary as a string with minimized whitespace that only contains json valid values.",
)


T = TypeVar("T", bound=BaseModel)


class GenPydanticInstance(dspy.Module):
"""A module for generating and validating Pydantic model instances based on prompts.
Usage:
To use this module, instantiate the GenPydanticInstance class with the desired
root Pydantic model and optional child models. Then, call the `forward` method
with a prompt to generate Pydantic model instances based on the provided prompt.
"""

def __init__(
self,
root_model: type[T],
child_models: Optional[list[type[BaseModel]]] = None,
generate_sig=PromptToPydanticInstanceSignature,
correct_generate_sig=PromptToPydanticInstanceErrorSignature,
):
super().__init__()

self.models = [root_model] # Always include root_model in models list

if child_models:
self.models.extend(child_models)

self.output_key = "root_model_kwargs_dict"
self.root_model = root_model

# Concatenate source code of models for use in generation/correction logic
self.model_sources = "\n".join(
[inspect.getsource(model) for model in self.models]
)

# Initialize DSPy ChainOfThought modules for generation and correction
self.generate = ChainOfThought(generate_sig)
self.correct_generate = ChainOfThought(correct_generate_sig)
self.validation_error = None

def validate_root_model(self, output: str) -> bool:
"""Validates whether the generated output conforms to the root Pydantic model."""
try:
model_inst = self.root_model.model_validate(eval_dict_str(output))
return isinstance(model_inst, self.root_model)
except (ValidationError, ValueError, TypeError, SyntaxError) as error:
self.validation_error = error
logger.debug(f"Validation error: {error}")
return False

def validate_output(self, output) -> T:
"""Validates the generated output and returns an instance of the root Pydantic model if successful."""
Assert(
self.validate_root_model(output),
f"""You need to create a kwargs dict for {self.root_model.__name__}\n
Validation error:\n{self.validation_error}""",
)

return self.root_model.model_validate(eval_dict_str(output))

def forward(self, prompt) -> T:
"""Takes a prompt as input and generates a Python dictionary that represents an instance of the
root Pydantic model. It also handles error correction and validation.
"""
output = self.generate(
prompt=prompt,
root_pydantic_model_class_name=self.root_model.__name__,
pydantic_model_definitions=self.model_sources,
)

output = output[self.output_key]

try:
return self.validate_output(output)
except (AssertionError, ValueError, TypeError) as error:
logger.error(f"Error {error!s}\nOutput:\n{output}")

# Correction attempt
corrected_output = self.correct_generate(
prompt=prompt,
root_pydantic_model_class_name=self.root_model.__name__,
pydantic_model_definitions=self.model_sources,
error=f"str(error){self.validation_error}",
)[self.output_key]

return self.validate_output(corrected_output)

def __call__(self, prompt):
return self.forward(prompt=prompt)


def gen_pydantic_instance_call(prompt, root_model, child_models=None):
model_module = GenPydanticInstance(root_model=root_model, child_models=child_models)
model_inst = model_module(prompt)
return model_inst
88 changes: 88 additions & 0 deletions src/dspygen/utils/fsm_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import inspect

from transitions import Machine
import functools
from enum import Enum, auto

from transitions.core import State


from transitions import Machine
import functools
from enum import Enum, auto

import inspect
import functools
from transitions import Machine

import inspect
import functools
from transitions import Machine


# A decorator that adds transition metadata to methods
def trigger(source, dest, conditions=None, unless=None, before=None, after=None, prepare=None):
def decorator(func):
if not hasattr(func, '_transitions'):
func._transitions = []
func._transitions.append({
'trigger': func.__name__,
'source': source,
'dest': dest,
'conditions': conditions or [],
'unless': unless or [],
'before': before,
'after': after,
'prepare': prepare
})

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# Execute any 'prepare' callbacks
if prepare:
[getattr(self, p)() for p in (prepare if isinstance(prepare, list) else [prepare])]

# Check 'unless' conditions to prevent transition
if unless and any([getattr(self, u)() for u in (unless if isinstance(unless, list) else [unless])]):
return func(self, *args, **kwargs)

# Check 'conditions' to allow transition
if conditions is None or all(
[getattr(self, c)() for c in (conditions if isinstance(conditions, list) else [conditions])]):
if before:
[getattr(self, b)() for b in (before if isinstance(before, list) else [before])]

# Correctly trigger the transition through the state machine
event_trigger = getattr(self, 'trigger')
event_trigger(func.__name__)

result = func(self, *args, **kwargs) # Execute the actual function logic

if after:
[getattr(self, a)() for a in (after if isinstance(after, list) else [after])]
return result

return func(self, *args, **kwargs) # Conditions not met, no transition

return wrapper

return decorator


class FSMMixin:
def setup_fsm(self, state_enum, initial):
self.states = [State(state.name) for state in state_enum]
self.machine = Machine(model=self, states=self.states, initial=initial, auto_transitions=False)
self.initialize_transitions()

def initialize_transitions(self):
for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(method, '_transitions'):
for trans in method._transitions:
self.machine.add_transition(**trans)


def state_transition_possibilities(fsm):
current_state = fsm.state
transitions = fsm.machine.get_transitions()
return [transition.dest for transition in transitions if transition.source == current_state]
Loading

0 comments on commit 085ca43

Please sign in to comment.