Skip to content

Commit

Permalink
allow passing model to flow
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Apr 7, 2024
1 parent 2045c92 commit 737fffc
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/control_flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from marvin.beta.assistants import Assistant, Thread
from marvin.beta.assistants.assistants import AssistantTool
from marvin.utilities.logging import get_logger
from openai.types.beta.threads import Message
from prefect import flow as prefect_flow
from prefect import task as prefect_task
from pydantic import BaseModel, Field, field_validator
Expand All @@ -19,6 +20,7 @@ class AIFlow(BaseModel):
assistant: Optional[Assistant] = Field(None, validate_default=True)
tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True)
instructions: Optional[str] = None
model: Optional[str] = None

model_config: dict = dict(validate_assignment=True, extra="forbid")

Expand Down Expand Up @@ -56,6 +58,7 @@ def ai_flow(
thread: Thread = None,
tools: list[Union[AssistantTool, Callable]] = None,
instructions: str = None,
model: str = None,
):
"""
Prepare a function to be executed as a Control Flow flow.
Expand All @@ -68,6 +71,7 @@ def ai_flow(
thread=thread,
tools=tools,
instructions=instructions,
model=model,
)

@functools.wraps(fn)
Expand All @@ -77,6 +81,7 @@ def wrapper(
_thread: Thread = None,
_tools: list[Union[AssistantTool, Callable]] = None,
_instructions: str = None,
_model: str = None,
**kwargs,
):
p_fn = prefect_flow(fn)
Expand All @@ -89,11 +94,13 @@ def wrapper(
)
flow_instructions = _instructions or instructions
flow_tools = _tools or tools
flow_model = _model or model
flow_obj = AIFlow(
thread=flow_thread,
assistant=flow_assistant,
tools=flow_tools,
instructions=flow_instructions,
model=flow_model,
)

logger.info(
Expand All @@ -104,3 +111,15 @@ def wrapper(
return p_fn(*args, **kwargs)

return wrapper


def get_messages(limit: int = None) -> list[Message]:
"""
Loads messages from the flow's thread.
Will error if no flow is found in the context.
"""
flow: Optional[AIFlow] = ctx.get("flow")
if not flow:
raise ValueError("No flow found in context")
return flow.thread.get_messages(limit=limit)

0 comments on commit 737fffc

Please sign in to comment.