Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel

from .base import SupervisedTool, Supervisor
from .base import SupervisedTool, Supervisor, AsyncSupervisor


class AlternatingSupervisor(Supervisor[Any]):
Expand All @@ -23,6 +23,22 @@ def supervise(self, invocation: Any) -> tuple[bool, str]:
return True, ""


class AsyncAlternatingSupervisor(AsyncSupervisor[Any]):
"""
An async supervisor that alternates between allowing and denying tool invocations.
"""

def __init__(self) -> None:
self.counter = 0

async def supervise_async(self, invocation: Any) -> tuple[bool, str]:
if self.counter % 2 == 0:
self.counter += 1
return False, "Blocked by AsyncAlternatingSupervisor"
self.counter += 1
return True, ""


class ExampleTool(SupervisedTool):
"""
An example tool that can be supervised.
Expand All @@ -34,3 +50,6 @@ class ExampleTool(SupervisedTool):

def _run_unsupervised(self, *args: Any, **kwargs: Any) -> Any:
return "Hello, world!"

async def _arun_unsupervised(self, *args: Any, **kwargs: Any) -> Any:
return "Hello, async world!"
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dev = [
"mypy==1.17.0",
"pytest==8.4.1",
"pytest-mock==3.14.1",
"pytest-asyncio==0.24.0",
"isort==5.12.0"
]

Expand All @@ -52,3 +53,6 @@ line-length = 120

[tool.mypy]
strict = true

[tool.pytest.ini_options]
asyncio_mode = "auto"
72 changes: 70 additions & 2 deletions src/talos/hypervisor/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,34 @@ class Supervisor(BaseModel, ABC):
"""

@abstractmethod
def approve(self, action: str, args: dict) -> tuple[bool, str | None]:
def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action.
"""
pass


class AsyncSupervisor(BaseModel, ABC):
"""
An abstract base class for async supervisors.
"""

@abstractmethod
async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action asynchronously.
"""
pass


class RuleBasedSupervisor(Supervisor):
"""
A supervisor that uses a set of rules to approve or deny actions.
"""

rules: list[Rule]

def approve(self, action: str, args: dict) -> tuple[bool, str | None]:
def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action based on the rules.
"""
Expand All @@ -55,3 +68,58 @@ def approve(self, action: str, args: dict) -> tuple[bool, str | None]:
if not approved:
return False, error_message
return True, None


class AsyncRuleBasedSupervisor(AsyncSupervisor):
"""
An async supervisor that uses a set of rules to approve or deny actions.
"""

rules: list[Rule]

async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action based on the rules asynchronously.
"""
for rule in self.rules:
if rule.tool_name == action:
for arg_name, validation_fn in rule.validations.items():
if arg_name in args:
approved, error_message = validation_fn(args[arg_name])
if not approved:
return False, error_message
return True, None


class SimpleSupervisor(Supervisor):
"""
A simple supervisor that approves every other tool call.
"""

counter: int = 0

def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action.
"""
self.counter += 1
if self.counter % 2 == 0:
return True, None
return False, "Denied by SimpleSupervisor"


class AsyncSimpleSupervisor(AsyncSupervisor):
"""
A simple async supervisor that approves every other tool call.
"""

counter: int = 0

async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action asynchronously.
"""
self.counter += 1
if self.counter % 2 == 0:
return True, None
return False, "Denied by AsyncSimpleSupervisor"
31 changes: 29 additions & 2 deletions src/talos/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ def supervise(self, invocation: T) -> tuple[bool, str]:
raise NotImplementedError


class AsyncSupervisor(ABC, Generic[T]):
"""
An async supervisor can be used to analyze a tool invocation and determine if it is
malicious or not. This is the async version of the Supervisor interface.
"""

@abstractmethod
async def supervise_async(self, invocation: T) -> tuple[bool, str]:
"""
Analyze the tool invocation and determine if it is malicious or not.

Args:
invocation: The tool invocation to analyze.

Returns:
A tuple of a boolean and a string. If the invocation is malicious,
the boolean is False and the string is an error message. Otherwise,
the boolean is True and the string is empty.
"""
raise NotImplementedError


class SupervisedTool(BaseTool):
"""
A tool that has an optional supervisor. When a tool call is submitted, it
Expand All @@ -41,6 +63,7 @@ class SupervisedTool(BaseTool):
"""

supervisor: Supervisor[Any] | None = Field(default=None)
async_supervisor: AsyncSupervisor[Any] | None = Field(default=None)

def _run(self, *args: Any, **kwargs: Any) -> Any:
if self.supervisor:
Expand All @@ -50,8 +73,12 @@ def _run(self, *args: Any, **kwargs: Any) -> Any:
return self._run_unsupervised(*args, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> Any:
if self.supervisor:
# TODO: Add support for async supervisors.
if self.async_supervisor:
ok, message = await self.async_supervisor.supervise_async({"args": args, "kwargs": kwargs})
if not ok:
return message
elif self.supervisor:
# Fallback to sync supervisor for backward compatibility
ok, message = self.supervisor.supervise({"args": args, "kwargs": kwargs})
if not ok:
return message
Expand Down
29 changes: 28 additions & 1 deletion src/talos/tools/supervised_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from langchain_core.tools import BaseTool

from talos.hypervisor.supervisor import Supervisor
from talos.hypervisor.supervisor import Supervisor, AsyncSupervisor


class SupervisedTool(BaseTool):
Expand All @@ -14,6 +14,7 @@ class SupervisedTool(BaseTool):

tool: BaseTool
supervisor: Supervisor | None = None
async_supervisor: AsyncSupervisor | None = None
messages: list

def set_supervisor(self, supervisor: Supervisor | None):
Expand All @@ -22,6 +23,12 @@ def set_supervisor(self, supervisor: Supervisor | None):
"""
self.supervisor = supervisor

def set_async_supervisor(self, async_supervisor: AsyncSupervisor | None):
"""
Sets the async supervisor for the tool.
"""
self.async_supervisor = async_supervisor

def _run(self, *args: Any, **kwargs: Any) -> Any:
"""
Runs the tool.
Expand All @@ -34,3 +41,23 @@ def _run(self, *args: Any, **kwargs: Any) -> Any:
else:
return error_message or f"Tool call to '{self.name}' denied by supervisor."
return self.tool.run(tool_input, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""
Runs the tool asynchronously.
"""
tool_input = args[0] if args else kwargs
if self.async_supervisor:
approved, error_message = await self.async_supervisor.approve_async(self.name, tool_input)
if approved:
return await self.tool.arun(tool_input, **kwargs)
else:
return error_message or f"Tool call to '{self.name}' denied by async supervisor."
elif self.supervisor:
# Fallback to sync supervisor for backward compatibility
approved, error_message = self.supervisor.approve(self.name, tool_input)
if approved:
return await self.tool.arun(tool_input, **kwargs)
else:
return error_message or f"Tool call to '{self.name}' denied by supervisor."
return await self.tool.arun(tool_input, **kwargs)
25 changes: 24 additions & 1 deletion tests/simple_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Any

from talos.hypervisor.supervisor import Supervisor
from talos.hypervisor.supervisor import Supervisor, AsyncSupervisor

if TYPE_CHECKING:
from talos.core.agent import Agent
Expand All @@ -29,3 +29,26 @@ def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
if self.counter % 2 == 0:
return True, None
return False, "Denied by SimpleSupervisor"


class AsyncSimpleSupervisor(AsyncSupervisor):
"""
A simple async supervisor that approves every other tool call.
"""

counter: int = 0

def set_agent(self, agent: "Agent"):
"""
Sets the agent to be supervised.
"""
pass

async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]:
"""
Approves or denies an action asynchronously.
"""
self.counter += 1
if self.counter % 2 == 0:
return True, None
return False, "Denied by AsyncSimpleSupervisor"
Loading