diff --git a/src/control_flow/core/agent.py b/src/control_flow/core/agent.py index c13e823d..dc7a6860 100644 --- a/src/control_flow/core/agent.py +++ b/src/control_flow/core/agent.py @@ -1,60 +1,24 @@ -import inspect -import json import logging from enum import Enum from typing import Callable -from marvin.types import FunctionTool from marvin.utilities.tools import tool_from_function -from prefect import task as prefect_task from pydantic import Field from control_flow.utilities.prefect import ( - create_markdown_artifact, + wrap_prefect_tool, ) from control_flow.utilities.types import Assistant, AssistantTool, ControlFlowModel +from control_flow.utilities.user_access import talk_to_human logger = logging.getLogger(__name__) -TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc( - """ - ## Tool call: {name} - - **Description:** {description} - - ## Arguments - - ```json - {args} - ``` - - ### Result - - ```json - {result} - ``` - """ -) - class AgentStatus(Enum): INCOMPLETE = "incomplete" COMPLETE = "complete" -def talk_to_human(message: str, get_response: bool = True) -> str: - """ - Send a message to the human user and optionally wait for a response. - If `get_response` is True, the function will return the user's response, - otherwise it will return a simple confirmation. - """ - print(message) - if get_response: - response = input("> ") - return response - return "Message sent to user" - - class Agent(Assistant, ControlFlowModel): user_access: bool = Field( False, @@ -65,61 +29,9 @@ class Agent(Assistant, ControlFlowModel): description="If True, the agent will communicate with the controller via messages.", ) - def get_tools(self, user_access: bool = None) -> list[AssistantTool | Callable]: - if user_access is None: - user_access = self.user_access + def get_tools(self) -> list[AssistantTool | Callable]: tools = super().get_tools() - if user_access: + if self.user_access: tools.append(tool_from_function(talk_to_human)) - wrapped_tools = [] - for tool in tools: - wrapped_tools.append(self._wrap_prefect_tool(tool)) - return tools - - def _wrap_prefect_tool(self, tool: AssistantTool | Callable) -> AssistantTool: - if not isinstance(tool, AssistantTool): - tool = tool_from_function(tool) - - if isinstance(tool, FunctionTool): - # for functions, we modify the function to become a Prefect task and - # publish an artifact that contains details about the function call - - async def modified_fn( - *args, - # provide default args to avoid a late-binding issue - original_fn: Callable = tool.function._python_fn, - tool: FunctionTool = tool, - **kwargs, - ): - # call fn - result = original_fn(*args, **kwargs) - - # prepare artifact - passed_args = ( - inspect.signature(original_fn).bind(*args, **kwargs).arguments - ) - try: - passed_args = json.dumps(passed_args, indent=2) - except Exception: - pass - create_markdown_artifact( - markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format( - name=tool.function.name, - description=tool.function.description or "(none provided)", - args=passed_args, - result=result, - ), - key="result", - ) - - # return result - return result - - # replace the function with the modified version - tool.function._python_fn = prefect_task( - modified_fn, - task_run_name=f"Tool call: {tool.function.name}", - ) - - return tool + return [wrap_prefect_tool(tool) for tool in tools] diff --git a/src/control_flow/core/controller/controller.py b/src/control_flow/core/controller/controller.py index cee12cea..5c1b3a9a 100644 --- a/src/control_flow/core/controller/controller.py +++ b/src/control_flow/core/controller/controller.py @@ -22,8 +22,9 @@ from control_flow.utilities.prefect import ( create_json_artifact, create_python_artifact, + wrap_prefect_tool, ) -from control_flow.utilities.types import Thread +from control_flow.utilities.types import FunctionTool, Thread logger = logging.getLogger(__name__) @@ -43,11 +44,6 @@ class Controller(BaseModel, ExposeSyncMethodsMixin): # termination_strategy: TerminationStrategy context: dict = {} instructions: str = None - user_access: bool | None = Field( - None, - description="If True or False, overrides the user_access of the " - "agents. If None, the user_access setting of each agents is used.", - ) model_config: dict = dict(extra="forbid") @field_validator("agents", mode="before") @@ -118,17 +114,27 @@ async def run_agent(self, agent: Agent, thread: Thread = None) -> Run: instructions = instructions_template.render() - tools = self.flow.tools + agent.get_tools(user_access=self.user_access) + tools = self.flow.tools + agent.get_tools() for task in self.tasks: task_id = self.flow.get_task_id(task) tools = tools + task.get_tools(task_id=task_id) + # filter tools because duplicate names are not allowed + final_tools = [] + final_tool_names = set() + for tool in tools: + if isinstance(tool, FunctionTool): + if tool.function.name in final_tool_names: + continue + final_tool_names.add(tool.function.name) + final_tools.append(wrap_prefect_tool(tool)) + run = Run( assistant=agent, thread=thread or self.flow.thread, instructions=instructions, - tools=tools, + tools=final_tools, event_handler_class=AgentHandler, ) diff --git a/src/control_flow/core/controller/instruction_template.py b/src/control_flow/core/controller/instruction_template.py index c597167e..5e8f260b 100644 --- a/src/control_flow/core/controller/instruction_template.py +++ b/src/control_flow/core/controller/instruction_template.py @@ -75,7 +75,8 @@ class CommunicationTemplate(Template): make up answers (or put empty answers) for the others. Ask again and only fail the task if you truly can not make progress. {% else %} - You can not interact with a human at this time. + You can not interact with a human at this time. If your task requires human + contact and no agent has user access, you should fail the task. {% endif %} """ @@ -95,10 +96,14 @@ class CollaborationTemplate(Template): ### Agents {% for agent in other_agents %} - {{loop.index}}. "{{agent.name}}": {{agent.description}} + + #### "{{agent.name}}" + Can talk to humans: {{agent.user_access}} + Description: {% if agent.description %}{{agent.description}}{% endif %} + {% endfor %} {% if not other_agents %} - (There are no other agents currently participating in this workflow) + (No other agents are currently participating in this workflow) {% endif %} """ other_agents: list[Agent] @@ -108,7 +113,7 @@ class InstructionsTemplate(Template): template: str = """ ## Instructions - {% if flow_instructions %} + {% if flow_instructions -%} ### Workflow instructions These instructions apply to the entire workflow: @@ -116,7 +121,7 @@ class InstructionsTemplate(Template): {{ flow_instructions }} {% endif %} - {% if controller_instructions %} + {% if controller_instructions -%} ### Controller instructions These instructions apply to these tasks: @@ -124,7 +129,7 @@ class InstructionsTemplate(Template): {{ controller_instructions }} {% endif %} - {% if agent_instructions %} + {% if agent_instructions -%} ### Agent instructions These instructions apply only to you: @@ -132,7 +137,7 @@ class InstructionsTemplate(Template): {{ agent_instructions }} {% endif %} - {% if additional_instructions %} + {% if additional_instructions -%} ### Additional instructions These instructions were additionally provided for this part of the workflow: @@ -176,6 +181,7 @@ class TasksTemplate(Template): #### Task {{ controller.flow.get_task_id(task) }} - Status: {{ task.status.value }} - Objective: {{ task.objective }} + - User access: {{ task.user_access }} {% if task.instructions %} - Instructions: {{ task.instructions }} {% endif %} @@ -191,6 +197,7 @@ class TasksTemplate(Template): {% endif %} {% endfor %} + {% if controller.flow.completed_tasks(reverse=True, limit=20) %} ### Completed tasks The following tasks were recently completed: @@ -208,6 +215,7 @@ class TasksTemplate(Template): {% endif %} {% endfor %} + {% endif %} """ controller: Controller diff --git a/src/control_flow/core/task.py b/src/control_flow/core/task.py index 09672333..f8882639 100644 --- a/src/control_flow/core/task.py +++ b/src/control_flow/core/task.py @@ -8,7 +8,9 @@ from pydantic import Field from control_flow.utilities.logging import get_logger +from control_flow.utilities.prefect import wrap_prefect_tool from control_flow.utilities.types import AssistantTool, ControlFlowModel +from control_flow.utilities.user_access import talk_to_human T = TypeVar("T") logger = get_logger(__name__) @@ -30,6 +32,7 @@ class Task(ControlFlowModel, Generic[T]): tools: list[AssistantTool | Callable] = [] created_at: datetime.datetime = Field(default_factory=datetime.datetime.now) completed_at: datetime.datetime | None = None + user_access: bool = False def __hash__(self): return id(self) @@ -64,10 +67,13 @@ def _create_fail_tool(self, task_id: int) -> FunctionTool: return tool def get_tools(self, task_id: int) -> list[AssistantTool | Callable]: - return [ + tools = self.tools + [ self._create_complete_tool(task_id), self._create_fail_tool(task_id), - ] + self.tools + ] + if self.user_access: + tools.append(marvin.utilities.tools.tool_from_function(talk_to_human)) + return [wrap_prefect_tool(t) for t in tools] def complete(self, result: T): self.result = result diff --git a/src/control_flow/dx.py b/src/control_flow/dx.py index fbf41b84..5b576768 100644 --- a/src/control_flow/dx.py +++ b/src/control_flow/dx.py @@ -68,7 +68,12 @@ def wrapper( def ai_task( - fn=None, *, objective: str = None, user_access: bool = None, **agent_kwargs: dict + fn=None, + *, + objective: str = None, + agents: list[Agent] = None, + tools: list[AssistantTool | Callable] = None, + user_access: bool = None, ): """ Use a Python function to create an AI task. When the function is called, an @@ -77,7 +82,11 @@ def ai_task( if fn is None: return functools.partial( - ai_task, objective=objective, user_access=user_access, **agent_kwargs + ai_task, + objective=objective, + agents=agents, + tools=tools, + user_access=user_access, ) sig = inspect.signature(fn) @@ -89,18 +98,18 @@ def ai_task( objective = fn.__name__ @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args, _agents: list[Agent] = None, **kwargs): # first process callargs bound = sig.bind(*args, **kwargs) bound.apply_defaults() - # return run_ai.with_options(name=f"Task: {fn.__name__}")( - return run_ai( + return run_ai.with_options(name=f"Task: {fn.__name__}")( tasks=objective, + agents=_agents or agents, cast=fn.__annotations__.get("return"), context=bound.arguments, + tools=tools, user_access=user_access, - **agent_kwargs, ) return wrapper @@ -125,6 +134,7 @@ def run_ai( agents: list[Agent] = None, cast: T = NOT_PROVIDED, context: dict = None, + tools: list[AssistantTool | Callable] = None, user_access: bool = False, ) -> T | list[T]: """ @@ -155,20 +165,26 @@ def run_ai( # create tasks if tasks: - ai_tasks = [Task[cast](objective=t, context=context or {}) for t in tasks] + ai_tasks = [ + Task[cast]( + objective=t, + context=context or {}, + user_access=user_access, + tools=tools or [], + ) + for t in tasks + ] else: ai_tasks = [] # create agent if agents is None: - agents = [Agent()] + agents = [Agent(user_access=user_access)] # create Controller from control_flow.core.controller.controller import Controller - controller = Controller( - tasks=ai_tasks, agents=agents, flow=flow, user_access=user_access - ) + controller = Controller(tasks=ai_tasks, agents=agents, flow=flow) controller.run() if ai_tasks: diff --git a/src/control_flow/utilities/marvin.py b/src/control_flow/utilities/marvin.py index 709f8686..5fad7e80 100644 --- a/src/control_flow/utilities/marvin.py +++ b/src/control_flow/utilities/marvin.py @@ -8,7 +8,9 @@ from openai.types.chat import ChatCompletion from prefect import task as prefect_task -from control_flow.utilities.prefect import create_json_artifact +from control_flow.utilities.prefect import ( + create_json_artifact, +) original_classify_async = marvin.classify_async original_cast_async = marvin.cast_async diff --git a/src/control_flow/utilities/prefect.py b/src/control_flow/utilities/prefect.py index 49f5d73e..cf0025f4 100644 --- a/src/control_flow/utilities/prefect.py +++ b/src/control_flow/utilities/prefect.py @@ -1,12 +1,20 @@ -from typing import Any +import inspect +import json +from typing import Any, Callable from uuid import UUID +import prefect +from marvin.types import FunctionTool from marvin.utilities.asyncio import run_sync +from marvin.utilities.tools import tool_from_function from prefect import get_client as get_prefect_client +from prefect import task as prefect_task from prefect.artifacts import ArtifactRequest from prefect.context import FlowRunContext, TaskRunContext from pydantic import TypeAdapter +from control_flow.utilities.types import AssistantTool + def create_markdown_artifact( key: str, @@ -82,3 +90,76 @@ def create_python_artifact( task_run_id=task_run_id, flow_run_id=flow_run_id, ) + + +TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc( + """ + ## Tool call: {name} + + **Description:** {description} + + ## Arguments + + ```json + {args} + ``` + + ### Result + + ```json + {result} + ``` + """ +) + + +def wrap_prefect_tool(tool: AssistantTool | Callable) -> AssistantTool: + """ + Wraps a Marvin tool in a prefect task + """ + if not isinstance(tool, AssistantTool): + tool = tool_from_function(tool) + + if isinstance(tool, FunctionTool): + # for functions, we modify the function to become a Prefect task and + # publish an artifact that contains details about the function call + + if isinstance(tool.function._python_fn, prefect.tasks.Task): + return tool + + async def modified_fn( + *args, + # provide default args to avoid a late-binding issue + original_fn: Callable = tool.function._python_fn, + tool: FunctionTool = tool, + **kwargs, + ): + # call fn + result = original_fn(*args, **kwargs) + + # prepare artifact + passed_args = inspect.signature(original_fn).bind(*args, **kwargs).arguments + try: + passed_args = json.dumps(passed_args, indent=2) + except Exception: + pass + create_markdown_artifact( + markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format( + name=tool.function.name, + description=tool.function.description or "(none provided)", + args=passed_args, + result=result, + ), + key="result", + ) + + # return result + return result + + # replace the function with the modified version + tool.function._python_fn = prefect_task( + modified_fn, + task_run_name=f"Tool call: {tool.function.name}", + ) + + return tool diff --git a/src/control_flow/utilities/types.py b/src/control_flow/utilities/types.py index 5b4c45d2..ab398c82 100644 --- a/src/control_flow/utilities/types.py +++ b/src/control_flow/utilities/types.py @@ -1,5 +1,6 @@ from marvin.beta.assistants import Assistant, Thread from marvin.beta.assistants.assistants import AssistantTool +from marvin.types import FunctionTool from marvin.utilities.asyncio import ExposeSyncMethodsMixin from pydantic import BaseModel diff --git a/src/control_flow/utilities/user_access.py b/src/control_flow/utilities/user_access.py new file mode 100644 index 00000000..2012ae2c --- /dev/null +++ b/src/control_flow/utilities/user_access.py @@ -0,0 +1,11 @@ +def talk_to_human(message: str, get_response: bool = True) -> str: + """ + Send a message to the human user and optionally wait for a response. + If `get_response` is True, the function will return the user's response, + otherwise it will return a simple confirmation. + """ + print(message) + if get_response: + response = input("> ") + return response + return "Message sent to user." diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..d4bb2090 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1 @@ +from .fixtures import * diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..7aa102be --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1 @@ +from .mocks import * diff --git a/tests/fixtures/mocks.py b/tests/fixtures/mocks.py new file mode 100644 index 00000000..ff85345a --- /dev/null +++ b/tests/fixtures/mocks.py @@ -0,0 +1,20 @@ +from unittest.mock import Mock, patch + +import pytest +from control_flow.utilities.user_access import talk_to_human + + +@pytest.fixture(autouse=True) +def mock_talk_to_human(): + """Return an empty default handler instead of a print handler to avoid + printing assistant output during tests""" + + def mock_talk_to_human(message: str, get_response: bool) -> str: + print(dict(message=message, get_response=get_response)) + return "Message sent to user." + + mock_talk_to_human.__doc__ = talk_to_human.__doc__ + with patch( + "control_flow.utilities.user_access.mock_talk_to_human", new=talk_to_human + ): + yield diff --git a/tests/flows/test_sign_guestbook.py b/tests/flows/test_sign_guestbook.py index 66b7b2e8..b3ac7479 100644 --- a/tests/flows/test_sign_guestbook.py +++ b/tests/flows/test_sign_guestbook.py @@ -1,11 +1,10 @@ -from control_flow import Assistant, run_ai -from control_flow.core.flow import ai_flow +from control_flow import Agent, ai_flow, run_ai # define assistants -a = Assistant(name="a") -b = Assistant(name="b") -c = Assistant(name="c") +a = Agent(name="a") +b = Agent(name="b") +c = Agent(name="c") # define tools @@ -34,7 +33,7 @@ def guestbook_flow(): sign their names for the task to be complete. You can read the sign to see if that has happened yet. You can not sign for another assistant. """, - assistants=[a, b, c], + agents=[a, b, c], tools=[sign, view_guestbook], ) diff --git a/tests/flows/test_user_access.py b/tests/flows/test_user_access.py new file mode 100644 index 00000000..107de649 --- /dev/null +++ b/tests/flows/test_user_access.py @@ -0,0 +1,54 @@ +import pytest +from control_flow import Agent, ai_flow, run_ai + +# define assistants + +user_agent = Agent(name="user-agent", user_access=True) +non_user_agent = Agent(name="non-user-agent", user_access=False) + + +def test_no_user_access_fails(): + @ai_flow + def user_access_flow(): + run_ai( + "This task requires human user access. Inform the user that today is a good day.", + agents=[non_user_agent], + ) + + with pytest.raises(ValueError): + user_access_flow() + + +def test_user_access_agent_succeeds(): + @ai_flow + def user_access_flow(): + run_ai( + "This task requires human user access. Inform the user that today is a good day.", + agents=[user_agent], + ) + + assert user_access_flow() + + +def test_user_access_task_succeeds(): + @ai_flow + def user_access_flow(): + run_ai( + "This task requires human user access. Inform the user that today is a good day.", + agents=[non_user_agent], + user_access=True, + ) + + assert user_access_flow() + + +def test_user_access_agent_and_task_succeeds(): + @ai_flow + def user_access_flow(): + run_ai( + "This task requires human user access. Inform the user that today is a good day.", + agents=[user_agent], + user_access=True, + ) + + assert user_access_flow()