Skip to content

Commit

Permalink
Merge pull request #14 from jlowin/user-access
Browse files Browse the repository at this point in the history
Improve user access
  • Loading branch information
jlowin authored Apr 17, 2024
2 parents d1f72f0 + 01685d6 commit 8a061ac
Show file tree
Hide file tree
Showing 14 changed files with 247 additions and 129 deletions.
98 changes: 5 additions & 93 deletions src/control_flow/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,24 @@
import inspect
import json
import logging
from enum import Enum
from typing import Callable

from marvin.types import FunctionTool
from marvin.utilities.tools import tool_from_function
from prefect import task as prefect_task
from pydantic import Field

from control_flow.utilities.prefect import (
create_markdown_artifact,
wrap_prefect_tool,
)
from control_flow.utilities.types import Assistant, AssistantTool, ControlFlowModel
from control_flow.utilities.user_access import talk_to_human

logger = logging.getLogger(__name__)

TOOL_CALL_FUNCTION_RESULT_TEMPLATE = inspect.cleandoc(
"""
## Tool call: {name}
**Description:** {description}
## Arguments
```json
{args}
```
### Result
```json
{result}
```
"""
)


class AgentStatus(Enum):
INCOMPLETE = "incomplete"
COMPLETE = "complete"


def talk_to_human(message: str, get_response: bool = True) -> str:
"""
Send a message to the human user and optionally wait for a response.
If `get_response` is True, the function will return the user's response,
otherwise it will return a simple confirmation.
"""
print(message)
if get_response:
response = input("> ")
return response
return "Message sent to user"


class Agent(Assistant, ControlFlowModel):
user_access: bool = Field(
False,
Expand All @@ -65,61 +29,9 @@ class Agent(Assistant, ControlFlowModel):
description="If True, the agent will communicate with the controller via messages.",
)

def get_tools(self, user_access: bool = None) -> list[AssistantTool | Callable]:
if user_access is None:
user_access = self.user_access
def get_tools(self) -> list[AssistantTool | Callable]:
tools = super().get_tools()
if user_access:
if self.user_access:
tools.append(tool_from_function(talk_to_human))

wrapped_tools = []
for tool in tools:
wrapped_tools.append(self._wrap_prefect_tool(tool))
return tools

def _wrap_prefect_tool(self, tool: AssistantTool | Callable) -> AssistantTool:
if not isinstance(tool, AssistantTool):
tool = tool_from_function(tool)

if isinstance(tool, FunctionTool):
# for functions, we modify the function to become a Prefect task and
# publish an artifact that contains details about the function call

async def modified_fn(
*args,
# provide default args to avoid a late-binding issue
original_fn: Callable = tool.function._python_fn,
tool: FunctionTool = tool,
**kwargs,
):
# call fn
result = original_fn(*args, **kwargs)

# prepare artifact
passed_args = (
inspect.signature(original_fn).bind(*args, **kwargs).arguments
)
try:
passed_args = json.dumps(passed_args, indent=2)
except Exception:
pass
create_markdown_artifact(
markdown=TOOL_CALL_FUNCTION_RESULT_TEMPLATE.format(
name=tool.function.name,
description=tool.function.description or "(none provided)",
args=passed_args,
result=result,
),
key="result",
)

# return result
return result

# replace the function with the modified version
tool.function._python_fn = prefect_task(
modified_fn,
task_run_name=f"Tool call: {tool.function.name}",
)

return tool
return [wrap_prefect_tool(tool) for tool in tools]
22 changes: 14 additions & 8 deletions src/control_flow/core/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from control_flow.utilities.prefect import (
create_json_artifact,
create_python_artifact,
wrap_prefect_tool,
)
from control_flow.utilities.types import Thread
from control_flow.utilities.types import FunctionTool, Thread

logger = logging.getLogger(__name__)

Expand All @@ -43,11 +44,6 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
# termination_strategy: TerminationStrategy
context: dict = {}
instructions: str = None
user_access: bool | None = Field(
None,
description="If True or False, overrides the user_access of the "
"agents. If None, the user_access setting of each agents is used.",
)
model_config: dict = dict(extra="forbid")

@field_validator("agents", mode="before")
Expand Down Expand Up @@ -118,17 +114,27 @@ async def run_agent(self, agent: Agent, thread: Thread = None) -> Run:

instructions = instructions_template.render()

tools = self.flow.tools + agent.get_tools(user_access=self.user_access)
tools = self.flow.tools + agent.get_tools()

for task in self.tasks:
task_id = self.flow.get_task_id(task)
tools = tools + task.get_tools(task_id=task_id)

# filter tools because duplicate names are not allowed
final_tools = []
final_tool_names = set()
for tool in tools:
if isinstance(tool, FunctionTool):
if tool.function.name in final_tool_names:
continue
final_tool_names.add(tool.function.name)
final_tools.append(wrap_prefect_tool(tool))

run = Run(
assistant=agent,
thread=thread or self.flow.thread,
instructions=instructions,
tools=tools,
tools=final_tools,
event_handler_class=AgentHandler,
)

Expand Down
22 changes: 15 additions & 7 deletions src/control_flow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ class CommunicationTemplate(Template):
make up answers (or put empty answers) for the others. Ask again and only
fail the task if you truly can not make progress.
{% else %}
You can not interact with a human at this time.
You can not interact with a human at this time. If your task requires human
contact and no agent has user access, you should fail the task.
{% endif %}
"""
Expand All @@ -95,10 +96,14 @@ class CollaborationTemplate(Template):
### Agents
{% for agent in other_agents %}
{{loop.index}}. "{{agent.name}}": {{agent.description}}
#### "{{agent.name}}"
Can talk to humans: {{agent.user_access}}
Description: {% if agent.description %}{{agent.description}}{% endif %}
{% endfor %}
{% if not other_agents %}
(There are no other agents currently participating in this workflow)
(No other agents are currently participating in this workflow)
{% endif %}
"""
other_agents: list[Agent]
Expand All @@ -108,31 +113,31 @@ class InstructionsTemplate(Template):
template: str = """
## Instructions
{% if flow_instructions %}
{% if flow_instructions -%}
### Workflow instructions
These instructions apply to the entire workflow:
{{ flow_instructions }}
{% endif %}
{% if controller_instructions %}
{% if controller_instructions -%}
### Controller instructions
These instructions apply to these tasks:
{{ controller_instructions }}
{% endif %}
{% if agent_instructions %}
{% if agent_instructions -%}
### Agent instructions
These instructions apply only to you:
{{ agent_instructions }}
{% endif %}
{% if additional_instructions %}
{% if additional_instructions -%}
### Additional instructions
These instructions were additionally provided for this part of the workflow:
Expand Down Expand Up @@ -176,6 +181,7 @@ class TasksTemplate(Template):
#### Task {{ controller.flow.get_task_id(task) }}
- Status: {{ task.status.value }}
- Objective: {{ task.objective }}
- User access: {{ task.user_access }}
{% if task.instructions %}
- Instructions: {{ task.instructions }}
{% endif %}
Expand All @@ -191,6 +197,7 @@ class TasksTemplate(Template):
{% endif %}
{% endfor %}
{% if controller.flow.completed_tasks(reverse=True, limit=20) %}
### Completed tasks
The following tasks were recently completed:
Expand All @@ -208,6 +215,7 @@ class TasksTemplate(Template):
{% endif %}
{% endfor %}
{% endif %}
"""
controller: Controller

Expand Down
10 changes: 8 additions & 2 deletions src/control_flow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from pydantic import Field

from control_flow.utilities.logging import get_logger
from control_flow.utilities.prefect import wrap_prefect_tool
from control_flow.utilities.types import AssistantTool, ControlFlowModel
from control_flow.utilities.user_access import talk_to_human

T = TypeVar("T")
logger = get_logger(__name__)
Expand All @@ -30,6 +32,7 @@ class Task(ControlFlowModel, Generic[T]):
tools: list[AssistantTool | Callable] = []
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
completed_at: datetime.datetime | None = None
user_access: bool = False

def __hash__(self):
return id(self)
Expand Down Expand Up @@ -64,10 +67,13 @@ def _create_fail_tool(self, task_id: int) -> FunctionTool:
return tool

def get_tools(self, task_id: int) -> list[AssistantTool | Callable]:
return [
tools = self.tools + [
self._create_complete_tool(task_id),
self._create_fail_tool(task_id),
] + self.tools
]
if self.user_access:
tools.append(marvin.utilities.tools.tool_from_function(talk_to_human))
return [wrap_prefect_tool(t) for t in tools]

def complete(self, result: T):
self.result = result
Expand Down
38 changes: 27 additions & 11 deletions src/control_flow/dx.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ def wrapper(


def ai_task(
fn=None, *, objective: str = None, user_access: bool = None, **agent_kwargs: dict
fn=None,
*,
objective: str = None,
agents: list[Agent] = None,
tools: list[AssistantTool | Callable] = None,
user_access: bool = None,
):
"""
Use a Python function to create an AI task. When the function is called, an
Expand All @@ -77,7 +82,11 @@ def ai_task(

if fn is None:
return functools.partial(
ai_task, objective=objective, user_access=user_access, **agent_kwargs
ai_task,
objective=objective,
agents=agents,
tools=tools,
user_access=user_access,
)

sig = inspect.signature(fn)
Expand All @@ -89,18 +98,18 @@ def ai_task(
objective = fn.__name__

@functools.wraps(fn)
def wrapper(*args, **kwargs):
def wrapper(*args, _agents: list[Agent] = None, **kwargs):
# first process callargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

# return run_ai.with_options(name=f"Task: {fn.__name__}")(
return run_ai(
return run_ai.with_options(name=f"Task: {fn.__name__}")(
tasks=objective,
agents=_agents or agents,
cast=fn.__annotations__.get("return"),
context=bound.arguments,
tools=tools,
user_access=user_access,
**agent_kwargs,
)

return wrapper
Expand All @@ -125,6 +134,7 @@ def run_ai(
agents: list[Agent] = None,
cast: T = NOT_PROVIDED,
context: dict = None,
tools: list[AssistantTool | Callable] = None,
user_access: bool = False,
) -> T | list[T]:
"""
Expand Down Expand Up @@ -155,20 +165,26 @@ def run_ai(

# create tasks
if tasks:
ai_tasks = [Task[cast](objective=t, context=context or {}) for t in tasks]
ai_tasks = [
Task[cast](
objective=t,
context=context or {},
user_access=user_access,
tools=tools or [],
)
for t in tasks
]
else:
ai_tasks = []

# create agent
if agents is None:
agents = [Agent()]
agents = [Agent(user_access=user_access)]

# create Controller
from control_flow.core.controller.controller import Controller

controller = Controller(
tasks=ai_tasks, agents=agents, flow=flow, user_access=user_access
)
controller = Controller(tasks=ai_tasks, agents=agents, flow=flow)
controller.run()

if ai_tasks:
Expand Down
Loading

0 comments on commit 8a061ac

Please sign in to comment.