diff --git a/src/control_flow/agent.py b/src/control_flow/agent.py index 9c1aedd1..e8c9195e 100644 --- a/src/control_flow/agent.py +++ b/src/control_flow/agent.py @@ -7,10 +7,11 @@ import marvin import marvin.utilities.tools import prefect +from marvin.beta.assistants import Thread from marvin.beta.assistants.assistants import Assistant from marvin.beta.assistants.handlers import PrintHandler from marvin.beta.assistants.runs import Run -from marvin.tools.assistants import AssistantTool, CancelRun +from marvin.tools.assistants import AssistantTool, EndRun from marvin.types import FunctionTool from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method from marvin.utilities.jinja import Environment @@ -30,18 +31,19 @@ T = TypeVar("T") logger = logging.getLogger(__name__) +TEMP_THREADS = {} TOOL_CALL_CODE_INTERPRETER_TEMPLATE = inspect.cleandoc( """ - # Tool call: code interpreter + ## Tool call: code interpreter - ## Code + ### Code ```python {code} ``` - ## Result + ### Result ```json {result} @@ -51,9 +53,9 @@ TOOL_CALL_FUNCTION_ARGS_TEMPLATE = inspect.cleandoc( """ - # Tool call: {name} + ## Tool call: {name} - ## Arguments + ### Arguments ```json {args} @@ -62,7 +64,7 @@ ) TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc( """ - # Tool call: {name} + ## Tool call: {name} **Description:** {description} @@ -72,7 +74,7 @@ {args} ``` - ## Result + ### Result ```json {result} @@ -90,10 +92,7 @@ ## Instructions -In addition to completing your tasks, these are your current instructions. You -must follow them at all times, even when using a tool to talk to a user. Note -that instructions can change at any time and the thread history may reflect -different instructions than these: +Follow these instructions at all times: {% if assistant.instructions -%} - {{ assistant.instructions }} @@ -123,8 +122,8 @@ especially when working with a human user. -{% for task in agent.tasks %} -### Task {{ task.id }} +{% for task_id, task in agent.numbered_tasks() %} +### Task {{ task_id }} - Status: {{ task.status.value }} - Objective: {{ task.objective }} {% if task.instructions %} @@ -166,6 +165,7 @@ system works to them. They can only see messages you send them via tool, not the rest of the thread. When dealing with humans, you may not always get a clear or correct response. You may need to ask multiple times or rephrase your questions. +You should also interpret human responses broadly and not be too literal. {% else %} You can not communicate with a human user at this time. {% endif %} @@ -273,14 +273,14 @@ def talk_to_human(message: str, get_response: bool = True) -> str: def end_run(): """Use this tool to end the run.""" - raise CancelRun() + return EndRun() class Agent(BaseModel, Generic[T], ExposeSyncMethodsMixin): tasks: list[AITask] = [] flow: AIFlow = Field(None, validate_default=True) assistant: Assistant = Field(None, validate_default=True) - tools: list[Union[AssistantTool, Callable]] = [] + tools: list[Union[AssistantTool, Assistant, Callable]] = [] context: dict = Field(None, validate_default=True) user_access: bool = Field( None, @@ -321,17 +321,14 @@ def _default_assistant(cls, v): v = Assistant() return v - @field_validator("user_access", mode="before") - def _default_user_access(cls, v): + @field_validator("user_access", "system_access", mode="before") + def _default_access(cls, v): if v is None: v = False return v - @field_validator("system_access", mode="before") - def _default_system_access(cls, v): - if v is None: - v = False - return v + def numbered_tasks(self) -> list[tuple[int, AITask]]: + return [(i + 1, task) for i, task in enumerate(self.tasks)] def _get_instructions(self, context: dict = None): instructions = Environment.render( @@ -351,16 +348,31 @@ def _get_tools(self) -> list[AssistantTool]: if not self.tasks: tools.append(end_run) - for task in self.tasks: - tools.extend([task._create_complete_tool(), task._create_fail_tool()]) + # if there is only one task, and the agent can't send a response to the + # system, then we can quit as soon as it is marked finished + if not self.system_access and len(self.tasks) == 1: + end_run = True + else: + end_run = False + + for i, task in self.numbered_tasks(): + tools.extend( + [ + task._create_complete_tool(task_id=i, end_run=end_run), + task._create_fail_tool(task_id=i, end_run=end_run), + ] + ) if self.user_access: tools.append(talk_to_human) final_tools = [] for tool in tools: - if not isinstance(tool, AssistantTool): + if isinstance(tool, marvin.beta.assistants.Assistant): + tool = self.model_copy(update={"assistant": tool}).as_tool() + elif not isinstance(tool, AssistantTool): tool = marvin.utilities.tools.tool_from_function(tool) + if isinstance(tool, FunctionTool): async def modified_fn( @@ -405,17 +417,22 @@ def _get_openai_run_task(self): """ @prefect_task(name="Execute OpenAI assistant run") - async def execute_openai_run(context: dict = None, run_kwargs: dict = None): + async def execute_openai_run( + context: dict = None, run_kwargs: dict = None + ) -> Run: run_kwargs = run_kwargs or {} - if "model" not in run_kwargs: - run_kwargs["model"] = self.assistant.model or settings.assistant_model + model = run_kwargs.pop( + "model", self.assistant.model or settings.assistant_model + ) + thread = run_kwargs.pop("thread", self.flow.thread) run = Run( assistant=self.assistant, - thread=self.flow.thread, + thread=thread, instructions=self._get_instructions(context=context), tools=self._get_tools(), event_handler_class=AgentHandler, + model=model, **run_kwargs, ) await run.run_async() @@ -452,6 +469,7 @@ async def execute_openai_run(context: dict = None, run_kwargs: dict = None): key="steps", description="All steps taken during the run.", ) + return run return execute_openai_run @@ -469,6 +487,7 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]: any(t.status == TaskStatus.PENDING for t in self.tasks) and counter < settings.max_agent_iterations ): + breakpoint() openai_run(context=context, run_kwargs=run_kwargs) counter += 1 @@ -476,6 +495,35 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]: return result + def as_tool(self): + thread = TEMP_THREADS.setdefault(self.assistant.model_dump_json(), Thread()) + + def _run(message: str, context: dict = None) -> list[str]: + task = self._get_openai_run_task() + run: Run = task(context=context, run_kwargs=dict(thread=thread)) + return [m.model_dump_json() for m in run.messages] + + return marvin.utilities.tools.tool_from_function( + _run, + name=f"call_ai_{self.assistant.name}", + description=inspect.cleandoc(""" + Use this tool to talk to a sub-AI that can operate independently of + you. The sub-AI may have a different skillset or be able to access + different tools than you. The sub-AI will run one iteration and + respond to you. You may continue to invoke it multiple times in sequence, as + needed. + + Note: you can only talk to one sub-AI at a time. Do not call in parallel or you will get an error about thread conflicts. + + ## Sub-AI Details + + - Name: {name} + - Instructions: {instructions} + """).format( + name=self.assistant.name, instructions=self.assistant.instructions + ), + ) + def ai_task( fn=None, *, objective: str = None, user_access: bool = None, **agent_kwargs: dict @@ -541,12 +589,9 @@ def run_ai( # load flow flow = ctx.get("flow", None) - if flow is None: - flow = AIFlow() # create task ai_task = AITask[cast](objective=task, context=context) - flow.add_task(ai_task) # run agent agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs) diff --git a/src/control_flow/flow.py b/src/control_flow/flow.py index eae5c931..4eb99565 100644 --- a/src/control_flow/flow.py +++ b/src/control_flow/flow.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List, Optional, Union +from typing import Callable, Optional, Union from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool @@ -10,13 +10,10 @@ from control_flow.context import ctx -from .task import AITask - logger = get_logger(__name__) class AIFlow(BaseModel): - tasks: List[AITask] = [] thread: Thread = Field(None, validate_default=True) assistant: Optional[Assistant] = Field(None, validate_default=True) tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True) @@ -47,24 +44,6 @@ def _default_tools(cls, v): v = [] return v - def add_task(self, task: AITask): - if task.id is None: - task.id = len(self.tasks) + 1 - elif task.id in {t.id for t in self.tasks}: - raise ValueError(f"Task with id {task.id} already exists.") - self.tasks.append(task) - - def get_task_by_id(self, task_id: int) -> Optional[AITask]: - for task in self.tasks: - if task.id == task_id: - return task - return None - - def update_task(self, task_id: int, status: str, result: str = None): - task = self.get_task_by_id(task_id) - if task: - task.update(status=status, result=result) - def add_message(self, message: str): prefect_task(self.thread.add)(message) @@ -101,7 +80,12 @@ def wrapper( ): p_fn = prefect_flow(fn) flow_assistant = _assistant or assistant - flow_thread = _thread or thread or flow_assistant.default_thread + flow_thread = ( + _thread + or thread + or (flow_assistant.default_thread if flow_assistant else None) + or Thread() + ) flow_instructions = _instructions or instructions flow_tools = _tools or tools flow_obj = AIFlow( @@ -111,7 +95,9 @@ def wrapper( instructions=flow_instructions, ) - logger.info(f'Executing AI flow "{fn.__name__}" on thread "{flow_thread.id}"') + logger.info( + f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"' + ) with ctx(flow=flow_obj): return p_fn(*args, **kwargs) diff --git a/src/control_flow/settings.py b/src/control_flow/settings.py index 8bf87444..9e46653b 100644 --- a/src/control_flow/settings.py +++ b/src/control_flow/settings.py @@ -5,9 +5,10 @@ class ControlFlowSettings(BaseSettings): model_config: SettingsConfigDict = SettingsConfigDict( + env_prefix="CONTROLFLOW_", env_file=( "" - if os.getenv("CONTROL_FLOW_TEST_MODE") + if os.getenv("CONTROLFLOW_TEST_MODE") else ("~/.control_flow/.env", ".env") ), extra="allow", @@ -19,6 +20,7 @@ class ControlFlowSettings(BaseSettings): class Settings(ControlFlowSettings): assistant_model: str = "gpt-4-1106-preview" max_agent_iterations: int = 10 + use_prefect: bool = True settings = Settings() diff --git a/src/control_flow/task.py b/src/control_flow/task.py index 719cdbc3..438aa546 100644 --- a/src/control_flow/task.py +++ b/src/control_flow/task.py @@ -3,12 +3,11 @@ import marvin import marvin.utilities.tools +from marvin.beta.assistants.runs import EndRun from marvin.utilities.logging import get_logger from marvin.utilities.tools import FunctionTool from pydantic import BaseModel, Field, field_validator -from control_flow.context import ctx - T = TypeVar("T") logger = get_logger(__name__) @@ -28,7 +27,6 @@ class AITask(BaseModel, Generic[T]): iterate until all tasks are completed. """ - id: int = Field(None, validate_default=True) objective: str instructions: Optional[str] = None context: dict = Field(None, validate_default=True) @@ -39,21 +37,15 @@ class AITask(BaseModel, Generic[T]): # internal model_config: dict = dict(validate_assignment=True, extra="forbid") - @field_validator("id", mode="before") - def _default_id(cls, v): - if v is None: - flow = ctx.get("flow") - if flow is not None: - v = len(flow.tasks) + 1 - return v - @field_validator("context", mode="before") def _default_context(cls, v): if v is None: v = {} return v - def _create_complete_tool(self) -> FunctionTool: + def _create_complete_tool( + self, task_id: int, end_run: bool = False + ) -> FunctionTool: """ Create an agent-compatible tool for completing this task. """ @@ -65,26 +57,30 @@ def _create_complete_tool(self) -> FunctionTool: def complete(result: result_type): self.result = result self.status = TaskStatus.COMPLETED + if end_run: + return EndRun() tool = marvin.utilities.tools.tool_from_function( complete, - name=f"complete_task_{self.id}", - description=f"Mark task {self.id} completed", + name=f"complete_task_{task_id}", + description=f"Mark task {task_id} completed", ) else: def complete(): self.status = TaskStatus.COMPLETED + if end_run: + return EndRun() tool = marvin.utilities.tools.tool_from_function( complete, - name=f"complete_task_{self.id}", - description=f"Mark task {self.id} completed", + name=f"complete_task_{task_id}", + description=f"Mark task {task_id} completed", ) return tool - def _create_fail_tool(self) -> FunctionTool: + def _create_fail_tool(self, task_id: int, end_run: bool = False) -> FunctionTool: """ Create an agent-compatible tool for failing this task. """ @@ -92,11 +88,13 @@ def _create_fail_tool(self) -> FunctionTool: def fail(message: Optional[str] = None): self.error = message self.status = TaskStatus.FAILED + if end_run: + return EndRun() tool = marvin.utilities.tools.tool_from_function( fail, - name=f"fail_task_{self.id}", - description=f"Mark task {self.id} failed", + name=f"fail_task_{task_id}", + description=f"Mark task {task_id} failed", ) return tool