diff --git a/src/control_flow/agent.py b/src/control_flow/agent.py index 97a84d30..f8d5c1a8 100644 --- a/src/control_flow/agent.py +++ b/src/control_flow/agent.py @@ -35,6 +35,7 @@ T = TypeVar("T") logger = logging.getLogger(__name__) +NOT_PROVIDED = object() TEMP_THREADS = {} @@ -116,7 +117,9 @@ {% endfor %} {% else %} -You have no explicit tasks to complete. Follow your instructions as best as you can. +You have no explicit tasks to complete. Follow your instructions as best as you +can. If it is not possible to comply with the instructions in any way, use the +`end_run` tool to manually stop the run. {% endif %} ## Communication @@ -321,15 +324,15 @@ def _get_tools(self) -> list[AssistantTool]: # if there is only one task, and the agent can't send a response to the # system, then we can quit as soon as it is marked finished if not self.system_access and len(self.tasks) == 1: - end_run = True + early_end_run = True else: - end_run = False + early_end_run = False for i, task in self.numbered_tasks(): tools.extend( [ - task._create_complete_tool(task_id=i, end_run=end_run), - task._create_fail_tool(task_id=i, end_run=end_run), + task._create_complete_tool(task_id=i, end_run=early_end_run), + task._create_fail_tool(task_id=i, end_run=early_end_run), ] ) @@ -386,13 +389,14 @@ def _get_openai_run_task(self): This needs to be regenerated each time in case the instructions change. """ - @prefect_task(task_run_name="Run OpenAI assistant") + @prefect_task(task_run_name=f"Run OpenAI assistant ({self.assistant.name})") async def execute_openai_run( context: dict = None, run_kwargs: dict = None ) -> Run: run_kwargs = run_kwargs or {} model = run_kwargs.pop( - "model", self.assistant.model or settings.assistant_model + "model", + self.assistant.model or self.flow.model or settings.assistant_model, ) thread = run_kwargs.pop("thread", self.flow.thread) @@ -408,12 +412,14 @@ async def execute_openai_run( await run.run_async() create_json_artifact( key="messages", - data=run.messages, + # dump explicilty because of odd OAI serialization issue + data=[m.model_dump() for m in run.messages], description="All messages sent and received during the run.", ) create_json_artifact( key="actions", - data=run.steps, + # dump explicilty because of odd OAI serialization issue + data=[s.model_dump() for s in run.steps], description="All actions taken by the assistant during the run.", ) return run @@ -426,8 +432,8 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]: openai_run(context=context, run_kwargs=run_kwargs) - # if this is not an interactive run, continue to run the AI - # until all tasks are no longer pending + # if this AI can't post messages to the system, then continue to invoke + # it until all tasks are finished if not self.system_access: counter = 0 while ( @@ -513,16 +519,19 @@ def _name_from_objective(): """Helper function for naming task runs""" from prefect.runtime import task_run - objective = task_run.parameters["task"] - if len(objective) > 50: - return f"Task: {objective[:50]}..." + objective = task_run.parameters.get("task") + + if not objective: + objective = "Follow general instructions" + if len(objective) > 75: + return f"Task: {objective[:75]}..." return f"Task: {objective}" @prefect_task(task_run_name=_name_from_objective) def run_agent( - task: str, - cast: T = str, + task: str = None, + cast: T = NOT_PROVIDED, context: dict = None, user_access: bool = None, model: str = None, @@ -533,18 +542,27 @@ def run_agent( response will be cast to the given result type. """ + if cast is NOT_PROVIDED: + if not task: + cast = None + else: + cast = str + # load flow flow = ctx.get("flow", None) - # create task - ai_task = AITask[cast](objective=task, context=context) + # create tasks + if task: + ai_tasks = [AITask[cast](objective=task, context=context)] + else: + ai_tasks = [] # run agent - agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs) + agent = Agent(tasks=ai_tasks, flow=flow, user_access=user_access, **agent_kwargs) agent.run(model=model) - # return - if ai_task.status == TaskStatus.COMPLETED: - return ai_task.result - elif ai_task.status == TaskStatus.FAILED: - raise ValueError(ai_task.error) + if ai_tasks: + if ai_tasks[0].status == TaskStatus.COMPLETED: + return ai_tasks[0].result + elif ai_tasks[0].status == TaskStatus.FAILED: + raise ValueError(ai_tasks[0].error) diff --git a/src/control_flow/flow.py b/src/control_flow/flow.py index 1f6bbb9a..24dfcb81 100644 --- a/src/control_flow/flow.py +++ b/src/control_flow/flow.py @@ -4,6 +4,7 @@ from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool from marvin.utilities.logging import get_logger +from openai.types.beta.threads import Message from prefect import flow as prefect_flow from prefect import task as prefect_task from pydantic import BaseModel, Field, field_validator @@ -19,6 +20,7 @@ class AIFlow(BaseModel): assistant: Optional[Assistant] = Field(None, validate_default=True) tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True) instructions: Optional[str] = None + model: Optional[str] = None model_config: dict = dict(validate_assignment=True, extra="forbid") @@ -56,6 +58,7 @@ def ai_flow( thread: Thread = None, tools: list[Union[AssistantTool, Callable]] = None, instructions: str = None, + model: str = None, ): """ Prepare a function to be executed as a Control Flow flow. @@ -68,6 +71,7 @@ def ai_flow( thread=thread, tools=tools, instructions=instructions, + model=model, ) @functools.wraps(fn) @@ -77,6 +81,7 @@ def wrapper( _thread: Thread = None, _tools: list[Union[AssistantTool, Callable]] = None, _instructions: str = None, + _model: str = None, **kwargs, ): p_fn = prefect_flow(fn) @@ -89,11 +94,13 @@ def wrapper( ) flow_instructions = _instructions or instructions flow_tools = _tools or tools + flow_model = _model or model flow_obj = AIFlow( thread=flow_thread, assistant=flow_assistant, tools=flow_tools, instructions=flow_instructions, + model=flow_model, ) logger.info( @@ -104,3 +111,15 @@ def wrapper( return p_fn(*args, **kwargs) return wrapper + + +def get_messages(limit: int = None) -> list[Message]: + """ + Loads messages from the flow's thread. + + Will error if no flow is found in the context. + """ + flow: Optional[AIFlow] = ctx.get("flow") + if not flow: + raise ValueError("No flow found in context") + return flow.thread.get_messages(limit=limit)