From 0806b452cd57cdaebb85245cb9c8c8857c35be31 Mon Sep 17 00:00:00 2001
From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
Date: Sat, 6 Apr 2024 14:29:45 -0400
Subject: [PATCH] Use new marvin endrun functionality

---
 src/control_flow/agent.py    | 111 ++++++++++++++++++++++++-----------
 src/control_flow/flow.py     |  34 ++++-------
 src/control_flow/settings.py |   4 +-
 src/control_flow/task.py     |  36 ++++++------
 4 files changed, 108 insertions(+), 77 deletions(-)

diff --git a/src/control_flow/agent.py b/src/control_flow/agent.py
index 9c1aedd1..e8c9195e 100644
--- a/src/control_flow/agent.py
+++ b/src/control_flow/agent.py
@@ -7,10 +7,11 @@
 import marvin
 import marvin.utilities.tools
 import prefect
+from marvin.beta.assistants import Thread
 from marvin.beta.assistants.assistants import Assistant
 from marvin.beta.assistants.handlers import PrintHandler
 from marvin.beta.assistants.runs import Run
-from marvin.tools.assistants import AssistantTool, CancelRun
+from marvin.tools.assistants import AssistantTool, EndRun
 from marvin.types import FunctionTool
 from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
 from marvin.utilities.jinja import Environment
@@ -30,18 +31,19 @@
 T = TypeVar("T")
 logger = logging.getLogger(__name__)
 
+TEMP_THREADS = {}
 
 TOOL_CALL_CODE_INTERPRETER_TEMPLATE = inspect.cleandoc(
     """
-    # Tool call: code interpreter
+    ## Tool call: code interpreter
         
-    ## Code
+    ### Code
     
     ```python
     {code}
     ```
     
-    ## Result
+    ### Result
     
     ```json
     {result}
@@ -51,9 +53,9 @@
 
 TOOL_CALL_FUNCTION_ARGS_TEMPLATE = inspect.cleandoc(
     """
-    # Tool call: {name}
+    ## Tool call: {name}
     
-    ## Arguments
+    ### Arguments
     
     ```json
     {args}
@@ -62,7 +64,7 @@
 )
 TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc(
     """
-    # Tool call: {name}
+    ## Tool call: {name}
     
     **Description:** {description}
     
@@ -72,7 +74,7 @@
     {args}
     ```
     
-    ## Result
+    ### Result
     
     ```json
     {result}
@@ -90,10 +92,7 @@
 
 ## Instructions
 
-In addition to completing your tasks, these are your current instructions. You
-must follow them at all times, even when using a tool to talk to a user. Note
-that instructions can change at any time and the thread history may reflect
-different instructions than these:
+Follow these instructions at all times:
 
 {% if assistant.instructions -%}
 - {{ assistant.instructions }}
@@ -123,8 +122,8 @@
 especially when working with a human user.
 
 
-{% for task in agent.tasks %}
-### Task {{ task.id }}
+{% for task_id, task in agent.numbered_tasks() %}
+### Task {{ task_id }}
 - Status: {{ task.status.value }}
 - Objective: {{ task.objective }}
 {% if task.instructions %}
@@ -166,6 +165,7 @@
 system works to them. They can only see messages you send them via tool, not the
 rest of the thread. When dealing with humans, you may not always get a clear or
 correct response. You may need to ask multiple times or rephrase your questions.
+You should also interpret human responses broadly and not be too literal.
 {% else %}
 You can not communicate with a human user at this time.
 {% endif %}
@@ -273,14 +273,14 @@ def talk_to_human(message: str, get_response: bool = True) -> str:
 
 def end_run():
     """Use this tool to end the run."""
-    raise CancelRun()
+    return EndRun()
 
 
 class Agent(BaseModel, Generic[T], ExposeSyncMethodsMixin):
     tasks: list[AITask] = []
     flow: AIFlow = Field(None, validate_default=True)
     assistant: Assistant = Field(None, validate_default=True)
-    tools: list[Union[AssistantTool, Callable]] = []
+    tools: list[Union[AssistantTool, Assistant, Callable]] = []
     context: dict = Field(None, validate_default=True)
     user_access: bool = Field(
         None,
@@ -321,17 +321,14 @@ def _default_assistant(cls, v):
                 v = Assistant()
         return v
 
-    @field_validator("user_access", mode="before")
-    def _default_user_access(cls, v):
+    @field_validator("user_access", "system_access", mode="before")
+    def _default_access(cls, v):
         if v is None:
             v = False
         return v
 
-    @field_validator("system_access", mode="before")
-    def _default_system_access(cls, v):
-        if v is None:
-            v = False
-        return v
+    def numbered_tasks(self) -> list[tuple[int, AITask]]:
+        return [(i + 1, task) for i, task in enumerate(self.tasks)]
 
     def _get_instructions(self, context: dict = None):
         instructions = Environment.render(
@@ -351,16 +348,31 @@ def _get_tools(self) -> list[AssistantTool]:
         if not self.tasks:
             tools.append(end_run)
 
-        for task in self.tasks:
-            tools.extend([task._create_complete_tool(), task._create_fail_tool()])
+        # if there is only one task, and the agent can't send a response to the
+        # system, then we can quit as soon as it is marked finished
+        if not self.system_access and len(self.tasks) == 1:
+            end_run = True
+        else:
+            end_run = False
+
+        for i, task in self.numbered_tasks():
+            tools.extend(
+                [
+                    task._create_complete_tool(task_id=i, end_run=end_run),
+                    task._create_fail_tool(task_id=i, end_run=end_run),
+                ]
+            )
 
         if self.user_access:
             tools.append(talk_to_human)
 
         final_tools = []
         for tool in tools:
-            if not isinstance(tool, AssistantTool):
+            if isinstance(tool, marvin.beta.assistants.Assistant):
+                tool = self.model_copy(update={"assistant": tool}).as_tool()
+            elif not isinstance(tool, AssistantTool):
                 tool = marvin.utilities.tools.tool_from_function(tool)
+
             if isinstance(tool, FunctionTool):
 
                 async def modified_fn(
@@ -405,17 +417,22 @@ def _get_openai_run_task(self):
         """
 
         @prefect_task(name="Execute OpenAI assistant run")
-        async def execute_openai_run(context: dict = None, run_kwargs: dict = None):
+        async def execute_openai_run(
+            context: dict = None, run_kwargs: dict = None
+        ) -> Run:
             run_kwargs = run_kwargs or {}
-            if "model" not in run_kwargs:
-                run_kwargs["model"] = self.assistant.model or settings.assistant_model
+            model = run_kwargs.pop(
+                "model", self.assistant.model or settings.assistant_model
+            )
+            thread = run_kwargs.pop("thread", self.flow.thread)
 
             run = Run(
                 assistant=self.assistant,
-                thread=self.flow.thread,
+                thread=thread,
                 instructions=self._get_instructions(context=context),
                 tools=self._get_tools(),
                 event_handler_class=AgentHandler,
+                model=model,
                 **run_kwargs,
             )
             await run.run_async()
@@ -452,6 +469,7 @@ async def execute_openai_run(context: dict = None, run_kwargs: dict = None):
                 key="steps",
                 description="All steps taken during the run.",
             )
+            return run
 
         return execute_openai_run
 
@@ -469,6 +487,7 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:
                 any(t.status == TaskStatus.PENDING for t in self.tasks)
                 and counter < settings.max_agent_iterations
             ):
+                breakpoint()
                 openai_run(context=context, run_kwargs=run_kwargs)
                 counter += 1
 
@@ -476,6 +495,35 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:
 
         return result
 
+    def as_tool(self):
+        thread = TEMP_THREADS.setdefault(self.assistant.model_dump_json(), Thread())
+
+        def _run(message: str, context: dict = None) -> list[str]:
+            task = self._get_openai_run_task()
+            run: Run = task(context=context, run_kwargs=dict(thread=thread))
+            return [m.model_dump_json() for m in run.messages]
+
+        return marvin.utilities.tools.tool_from_function(
+            _run,
+            name=f"call_ai_{self.assistant.name}",
+            description=inspect.cleandoc("""
+            Use this tool to talk to a sub-AI that can operate independently of
+            you. The sub-AI may have a different skillset or be able to access
+            different tools than you. The sub-AI will run one iteration and
+            respond to you. You may continue to invoke it multiple times in sequence, as
+            needed. 
+            
+            Note: you can only talk to one sub-AI at a time. Do not call in parallel or you will get an error about thread conflicts.
+            
+            ## Sub-AI Details
+            
+            - Name: {name}
+            - Instructions: {instructions}
+            """).format(
+                name=self.assistant.name, instructions=self.assistant.instructions
+            ),
+        )
+
 
 def ai_task(
     fn=None, *, objective: str = None, user_access: bool = None, **agent_kwargs: dict
@@ -541,12 +589,9 @@ def run_ai(
 
     # load flow
     flow = ctx.get("flow", None)
-    if flow is None:
-        flow = AIFlow()
 
     # create task
     ai_task = AITask[cast](objective=task, context=context)
-    flow.add_task(ai_task)
 
     # run agent
     agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs)
diff --git a/src/control_flow/flow.py b/src/control_flow/flow.py
index eae5c931..4eb99565 100644
--- a/src/control_flow/flow.py
+++ b/src/control_flow/flow.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Callable, List, Optional, Union
+from typing import Callable, Optional, Union
 
 from marvin.beta.assistants import Assistant, Thread
 from marvin.beta.assistants.assistants import AssistantTool
@@ -10,13 +10,10 @@
 
 from control_flow.context import ctx
 
-from .task import AITask
-
 logger = get_logger(__name__)
 
 
 class AIFlow(BaseModel):
-    tasks: List[AITask] = []
     thread: Thread = Field(None, validate_default=True)
     assistant: Optional[Assistant] = Field(None, validate_default=True)
     tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True)
@@ -47,24 +44,6 @@ def _default_tools(cls, v):
             v = []
         return v
 
-    def add_task(self, task: AITask):
-        if task.id is None:
-            task.id = len(self.tasks) + 1
-        elif task.id in {t.id for t in self.tasks}:
-            raise ValueError(f"Task with id {task.id} already exists.")
-        self.tasks.append(task)
-
-    def get_task_by_id(self, task_id: int) -> Optional[AITask]:
-        for task in self.tasks:
-            if task.id == task_id:
-                return task
-        return None
-
-    def update_task(self, task_id: int, status: str, result: str = None):
-        task = self.get_task_by_id(task_id)
-        if task:
-            task.update(status=status, result=result)
-
     def add_message(self, message: str):
         prefect_task(self.thread.add)(message)
 
@@ -101,7 +80,12 @@ def wrapper(
     ):
         p_fn = prefect_flow(fn)
         flow_assistant = _assistant or assistant
-        flow_thread = _thread or thread or flow_assistant.default_thread
+        flow_thread = (
+            _thread
+            or thread
+            or (flow_assistant.default_thread if flow_assistant else None)
+            or Thread()
+        )
         flow_instructions = _instructions or instructions
         flow_tools = _tools or tools
         flow_obj = AIFlow(
@@ -111,7 +95,9 @@ def wrapper(
             instructions=flow_instructions,
         )
 
-        logger.info(f'Executing AI flow "{fn.__name__}" on thread "{flow_thread.id}"')
+        logger.info(
+            f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
+        )
 
         with ctx(flow=flow_obj):
             return p_fn(*args, **kwargs)
diff --git a/src/control_flow/settings.py b/src/control_flow/settings.py
index 8bf87444..9e46653b 100644
--- a/src/control_flow/settings.py
+++ b/src/control_flow/settings.py
@@ -5,9 +5,10 @@
 
 class ControlFlowSettings(BaseSettings):
     model_config: SettingsConfigDict = SettingsConfigDict(
+        env_prefix="CONTROLFLOW_",
         env_file=(
             ""
-            if os.getenv("CONTROL_FLOW_TEST_MODE")
+            if os.getenv("CONTROLFLOW_TEST_MODE")
             else ("~/.control_flow/.env", ".env")
         ),
         extra="allow",
@@ -19,6 +20,7 @@ class ControlFlowSettings(BaseSettings):
 class Settings(ControlFlowSettings):
     assistant_model: str = "gpt-4-1106-preview"
     max_agent_iterations: int = 10
+    use_prefect: bool = True
 
 
 settings = Settings()
diff --git a/src/control_flow/task.py b/src/control_flow/task.py
index 719cdbc3..438aa546 100644
--- a/src/control_flow/task.py
+++ b/src/control_flow/task.py
@@ -3,12 +3,11 @@
 
 import marvin
 import marvin.utilities.tools
+from marvin.beta.assistants.runs import EndRun
 from marvin.utilities.logging import get_logger
 from marvin.utilities.tools import FunctionTool
 from pydantic import BaseModel, Field, field_validator
 
-from control_flow.context import ctx
-
 T = TypeVar("T")
 logger = get_logger(__name__)
 
@@ -28,7 +27,6 @@ class AITask(BaseModel, Generic[T]):
     iterate until all tasks are completed.
     """
 
-    id: int = Field(None, validate_default=True)
     objective: str
     instructions: Optional[str] = None
     context: dict = Field(None, validate_default=True)
@@ -39,21 +37,15 @@ class AITask(BaseModel, Generic[T]):
     # internal
     model_config: dict = dict(validate_assignment=True, extra="forbid")
 
-    @field_validator("id", mode="before")
-    def _default_id(cls, v):
-        if v is None:
-            flow = ctx.get("flow")
-            if flow is not None:
-                v = len(flow.tasks) + 1
-        return v
-
     @field_validator("context", mode="before")
     def _default_context(cls, v):
         if v is None:
             v = {}
         return v
 
-    def _create_complete_tool(self) -> FunctionTool:
+    def _create_complete_tool(
+        self, task_id: int, end_run: bool = False
+    ) -> FunctionTool:
         """
         Create an agent-compatible tool for completing this task.
         """
@@ -65,26 +57,30 @@ def _create_complete_tool(self) -> FunctionTool:
             def complete(result: result_type):
                 self.result = result
                 self.status = TaskStatus.COMPLETED
+                if end_run:
+                    return EndRun()
 
             tool = marvin.utilities.tools.tool_from_function(
                 complete,
-                name=f"complete_task_{self.id}",
-                description=f"Mark task {self.id} completed",
+                name=f"complete_task_{task_id}",
+                description=f"Mark task {task_id} completed",
             )
         else:
 
             def complete():
                 self.status = TaskStatus.COMPLETED
+                if end_run:
+                    return EndRun()
 
             tool = marvin.utilities.tools.tool_from_function(
                 complete,
-                name=f"complete_task_{self.id}",
-                description=f"Mark task {self.id} completed",
+                name=f"complete_task_{task_id}",
+                description=f"Mark task {task_id} completed",
             )
 
         return tool
 
-    def _create_fail_tool(self) -> FunctionTool:
+    def _create_fail_tool(self, task_id: int, end_run: bool = False) -> FunctionTool:
         """
         Create an agent-compatible tool for failing this task.
         """
@@ -92,11 +88,13 @@ def _create_fail_tool(self) -> FunctionTool:
         def fail(message: Optional[str] = None):
             self.error = message
             self.status = TaskStatus.FAILED
+            if end_run:
+                return EndRun()
 
         tool = marvin.utilities.tools.tool_from_function(
             fail,
-            name=f"fail_task_{self.id}",
-            description=f"Mark task {self.id} failed",
+            name=f"fail_task_{task_id}",
+            description=f"Mark task {task_id} failed",
         )
         return tool