Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new marvin endrun functionality #2

Merged
merged 1 commit into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 78 additions & 33 deletions src/control_flow/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import marvin
import marvin.utilities.tools
import prefect
from marvin.beta.assistants import Thread
from marvin.beta.assistants.assistants import Assistant
from marvin.beta.assistants.handlers import PrintHandler
from marvin.beta.assistants.runs import Run
from marvin.tools.assistants import AssistantTool, CancelRun
from marvin.tools.assistants import AssistantTool, EndRun
from marvin.types import FunctionTool
from marvin.utilities.asyncio import ExposeSyncMethodsMixin, expose_sync_method
from marvin.utilities.jinja import Environment
Expand All @@ -30,18 +31,19 @@
T = TypeVar("T")
logger = logging.getLogger(__name__)

TEMP_THREADS = {}

TOOL_CALL_CODE_INTERPRETER_TEMPLATE = inspect.cleandoc(
"""
# Tool call: code interpreter
## Tool call: code interpreter

## Code
### Code

```python
{code}
```

## Result
### Result

```json
{result}
Expand All @@ -51,9 +53,9 @@

TOOL_CALL_FUNCTION_ARGS_TEMPLATE = inspect.cleandoc(
"""
# Tool call: {name}
## Tool call: {name}

## Arguments
### Arguments

```json
{args}
Expand All @@ -62,7 +64,7 @@
)
TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc(
"""
# Tool call: {name}
## Tool call: {name}

**Description:** {description}

Expand All @@ -72,7 +74,7 @@
{args}
```

## Result
### Result

```json
{result}
Expand All @@ -90,10 +92,7 @@

## Instructions

In addition to completing your tasks, these are your current instructions. You
must follow them at all times, even when using a tool to talk to a user. Note
that instructions can change at any time and the thread history may reflect
different instructions than these:
Follow these instructions at all times:

{% if assistant.instructions -%}
- {{ assistant.instructions }}
Expand Down Expand Up @@ -123,8 +122,8 @@
especially when working with a human user.


{% for task in agent.tasks %}
### Task {{ task.id }}
{% for task_id, task in agent.numbered_tasks() %}
### Task {{ task_id }}
- Status: {{ task.status.value }}
- Objective: {{ task.objective }}
{% if task.instructions %}
Expand Down Expand Up @@ -166,6 +165,7 @@
system works to them. They can only see messages you send them via tool, not the
rest of the thread. When dealing with humans, you may not always get a clear or
correct response. You may need to ask multiple times or rephrase your questions.
You should also interpret human responses broadly and not be too literal.
{% else %}
You can not communicate with a human user at this time.
{% endif %}
Expand Down Expand Up @@ -273,14 +273,14 @@ def talk_to_human(message: str, get_response: bool = True) -> str:

def end_run():
"""Use this tool to end the run."""
raise CancelRun()
return EndRun()


class Agent(BaseModel, Generic[T], ExposeSyncMethodsMixin):
tasks: list[AITask] = []
flow: AIFlow = Field(None, validate_default=True)
assistant: Assistant = Field(None, validate_default=True)
tools: list[Union[AssistantTool, Callable]] = []
tools: list[Union[AssistantTool, Assistant, Callable]] = []
context: dict = Field(None, validate_default=True)
user_access: bool = Field(
None,
Expand Down Expand Up @@ -321,17 +321,14 @@ def _default_assistant(cls, v):
v = Assistant()
return v

@field_validator("user_access", mode="before")
def _default_user_access(cls, v):
@field_validator("user_access", "system_access", mode="before")
def _default_access(cls, v):
if v is None:
v = False
return v

@field_validator("system_access", mode="before")
def _default_system_access(cls, v):
if v is None:
v = False
return v
def numbered_tasks(self) -> list[tuple[int, AITask]]:
return [(i + 1, task) for i, task in enumerate(self.tasks)]

def _get_instructions(self, context: dict = None):
instructions = Environment.render(
Expand All @@ -351,16 +348,31 @@ def _get_tools(self) -> list[AssistantTool]:
if not self.tasks:
tools.append(end_run)

for task in self.tasks:
tools.extend([task._create_complete_tool(), task._create_fail_tool()])
# if there is only one task, and the agent can't send a response to the
# system, then we can quit as soon as it is marked finished
if not self.system_access and len(self.tasks) == 1:
end_run = True
else:
end_run = False

for i, task in self.numbered_tasks():
tools.extend(
[
task._create_complete_tool(task_id=i, end_run=end_run),
task._create_fail_tool(task_id=i, end_run=end_run),
]
)

if self.user_access:
tools.append(talk_to_human)

final_tools = []
for tool in tools:
if not isinstance(tool, AssistantTool):
if isinstance(tool, marvin.beta.assistants.Assistant):
tool = self.model_copy(update={"assistant": tool}).as_tool()
elif not isinstance(tool, AssistantTool):
tool = marvin.utilities.tools.tool_from_function(tool)

if isinstance(tool, FunctionTool):

async def modified_fn(
Expand Down Expand Up @@ -405,17 +417,22 @@ def _get_openai_run_task(self):
"""

@prefect_task(name="Execute OpenAI assistant run")
async def execute_openai_run(context: dict = None, run_kwargs: dict = None):
async def execute_openai_run(
context: dict = None, run_kwargs: dict = None
) -> Run:
run_kwargs = run_kwargs or {}
if "model" not in run_kwargs:
run_kwargs["model"] = self.assistant.model or settings.assistant_model
model = run_kwargs.pop(
"model", self.assistant.model or settings.assistant_model
)
thread = run_kwargs.pop("thread", self.flow.thread)

run = Run(
assistant=self.assistant,
thread=self.flow.thread,
thread=thread,
instructions=self._get_instructions(context=context),
tools=self._get_tools(),
event_handler_class=AgentHandler,
model=model,
**run_kwargs,
)
await run.run_async()
Expand Down Expand Up @@ -452,6 +469,7 @@ async def execute_openai_run(context: dict = None, run_kwargs: dict = None):
key="steps",
description="All steps taken during the run.",
)
return run

return execute_openai_run

Expand All @@ -469,13 +487,43 @@ async def run_async(self, context: dict = None, **run_kwargs) -> list[AITask]:
any(t.status == TaskStatus.PENDING for t in self.tasks)
and counter < settings.max_agent_iterations
):
breakpoint()
openai_run(context=context, run_kwargs=run_kwargs)
counter += 1

result = [t.result for t in self.tasks if t.status == TaskStatus.COMPLETED]

return result

def as_tool(self):
thread = TEMP_THREADS.setdefault(self.assistant.model_dump_json(), Thread())

def _run(message: str, context: dict = None) -> list[str]:
task = self._get_openai_run_task()
run: Run = task(context=context, run_kwargs=dict(thread=thread))
return [m.model_dump_json() for m in run.messages]

return marvin.utilities.tools.tool_from_function(
_run,
name=f"call_ai_{self.assistant.name}",
description=inspect.cleandoc("""
Use this tool to talk to a sub-AI that can operate independently of
you. The sub-AI may have a different skillset or be able to access
different tools than you. The sub-AI will run one iteration and
respond to you. You may continue to invoke it multiple times in sequence, as
needed.

Note: you can only talk to one sub-AI at a time. Do not call in parallel or you will get an error about thread conflicts.

## Sub-AI Details

- Name: {name}
- Instructions: {instructions}
""").format(
name=self.assistant.name, instructions=self.assistant.instructions
),
)


def ai_task(
fn=None, *, objective: str = None, user_access: bool = None, **agent_kwargs: dict
Expand Down Expand Up @@ -541,12 +589,9 @@ def run_ai(

# load flow
flow = ctx.get("flow", None)
if flow is None:
flow = AIFlow()

# create task
ai_task = AITask[cast](objective=task, context=context)
flow.add_task(ai_task)

# run agent
agent = Agent(tasks=[ai_task], flow=flow, user_access=user_access, **agent_kwargs)
Expand Down
34 changes: 10 additions & 24 deletions src/control_flow/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Callable, List, Optional, Union
from typing import Callable, Optional, Union

from marvin.beta.assistants import Assistant, Thread
from marvin.beta.assistants.assistants import AssistantTool
Expand All @@ -10,13 +10,10 @@

from control_flow.context import ctx

from .task import AITask

logger = get_logger(__name__)


class AIFlow(BaseModel):
tasks: List[AITask] = []
thread: Thread = Field(None, validate_default=True)
assistant: Optional[Assistant] = Field(None, validate_default=True)
tools: list[Union[AssistantTool, Callable]] = Field(None, validate_default=True)
Expand Down Expand Up @@ -47,24 +44,6 @@ def _default_tools(cls, v):
v = []
return v

def add_task(self, task: AITask):
if task.id is None:
task.id = len(self.tasks) + 1
elif task.id in {t.id for t in self.tasks}:
raise ValueError(f"Task with id {task.id} already exists.")
self.tasks.append(task)

def get_task_by_id(self, task_id: int) -> Optional[AITask]:
for task in self.tasks:
if task.id == task_id:
return task
return None

def update_task(self, task_id: int, status: str, result: str = None):
task = self.get_task_by_id(task_id)
if task:
task.update(status=status, result=result)

def add_message(self, message: str):
prefect_task(self.thread.add)(message)

Expand Down Expand Up @@ -101,7 +80,12 @@ def wrapper(
):
p_fn = prefect_flow(fn)
flow_assistant = _assistant or assistant
flow_thread = _thread or thread or flow_assistant.default_thread
flow_thread = (
_thread
or thread
or (flow_assistant.default_thread if flow_assistant else None)
or Thread()
)
flow_instructions = _instructions or instructions
flow_tools = _tools or tools
flow_obj = AIFlow(
Expand All @@ -111,7 +95,9 @@ def wrapper(
instructions=flow_instructions,
)

logger.info(f'Executing AI flow "{fn.__name__}" on thread "{flow_thread.id}"')
logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj):
return p_fn(*args, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion src/control_flow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

class ControlFlowSettings(BaseSettings):
model_config: SettingsConfigDict = SettingsConfigDict(
env_prefix="CONTROLFLOW_",
env_file=(
""
if os.getenv("CONTROL_FLOW_TEST_MODE")
if os.getenv("CONTROLFLOW_TEST_MODE")
else ("~/.control_flow/.env", ".env")
),
extra="allow",
Expand All @@ -19,6 +20,7 @@ class ControlFlowSettings(BaseSettings):
class Settings(ControlFlowSettings):
assistant_model: str = "gpt-4-1106-preview"
max_agent_iterations: int = 10
use_prefect: bool = True


settings = Settings()
Loading
Loading