From 737fffcd7aacb7ca2a65a81fc28a08c809625dd6 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 7 Apr 2024 16:16:38 -0400 Subject: [PATCH] allow passing model to flow --- src/control_flow/flow.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/control_flow/flow.py b/src/control_flow/flow.py index 1f6bbb9a..24dfcb81 100644 --- a/src/control_flow/flow.py +++ b/src/control_flow/flow.py @@ -4,6 +4,7 @@ from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool from marvin.utilities.logging import get_logger +from openai.types.beta.threads import Message from prefect import flow as prefect_flow from prefect import task as prefect_task from pydantic import BaseModel, Field, field_validator @@ -19,6 +20,7 @@ class AIFlow(BaseModel): assistant: Optional[Assistant] = Field(None, validate_default=True) tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True) instructions: Optional[str] = None + model: Optional[str] = None model_config: dict = dict(validate_assignment=True, extra="forbid") @@ -56,6 +58,7 @@ def ai_flow( thread: Thread = None, tools: list[Union[AssistantTool, Callable]] = None, instructions: str = None, + model: str = None, ): """ Prepare a function to be executed as a Control Flow flow. @@ -68,6 +71,7 @@ def ai_flow( thread=thread, tools=tools, instructions=instructions, + model=model, ) @functools.wraps(fn) @@ -77,6 +81,7 @@ def wrapper( _thread: Thread = None, _tools: list[Union[AssistantTool, Callable]] = None, _instructions: str = None, + _model: str = None, **kwargs, ): p_fn = prefect_flow(fn) @@ -89,11 +94,13 @@ def wrapper( ) flow_instructions = _instructions or instructions flow_tools = _tools or tools + flow_model = _model or model flow_obj = AIFlow( thread=flow_thread, assistant=flow_assistant, tools=flow_tools, instructions=flow_instructions, + model=flow_model, ) logger.info( @@ -104,3 +111,15 @@ def wrapper( return p_fn(*args, **kwargs) return wrapper + + +def get_messages(limit: int = None) -> list[Message]: + """ + Loads messages from the flow's thread. + + Will error if no flow is found in the context. + """ + flow: Optional[AIFlow] = ctx.get("flow") + if not flow: + raise ValueError("No flow found in context") + return flow.thread.get_messages(limit=limit)