diff --git a/examples/example.py b/examples/example.py index e9497e35..a80ca67f 100644 --- a/examples/example.py +++ b/examples/example.py @@ -4,7 +4,7 @@ from pydantic import BaseModel -from .base import SupervisedTool, Supervisor +from .base import SupervisedTool, Supervisor, AsyncSupervisor class AlternatingSupervisor(Supervisor[Any]): @@ -23,6 +23,22 @@ def supervise(self, invocation: Any) -> tuple[bool, str]: return True, "" +class AsyncAlternatingSupervisor(AsyncSupervisor[Any]): + """ + An async supervisor that alternates between allowing and denying tool invocations. + """ + + def __init__(self) -> None: + self.counter = 0 + + async def supervise_async(self, invocation: Any) -> tuple[bool, str]: + if self.counter % 2 == 0: + self.counter += 1 + return False, "Blocked by AsyncAlternatingSupervisor" + self.counter += 1 + return True, "" + + class ExampleTool(SupervisedTool): """ An example tool that can be supervised. @@ -34,3 +50,6 @@ class ExampleTool(SupervisedTool): def _run_unsupervised(self, *args: Any, **kwargs: Any) -> Any: return "Hello, world!" + + async def _arun_unsupervised(self, *args: Any, **kwargs: Any) -> Any: + return "Hello, async world!" diff --git a/pyproject.toml b/pyproject.toml index 518cc4c3..f41eca76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dev = [ "mypy==1.17.0", "pytest==8.4.1", "pytest-mock==3.14.1", + "pytest-asyncio==0.24.0", "isort==5.12.0" ] @@ -52,3 +53,6 @@ line-length = 120 [tool.mypy] strict = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/src/talos/hypervisor/supervisor.py b/src/talos/hypervisor/supervisor.py index 2ec42232..9d7b7cc1 100644 --- a/src/talos/hypervisor/supervisor.py +++ b/src/talos/hypervisor/supervisor.py @@ -29,13 +29,26 @@ class Supervisor(BaseModel, ABC): """ @abstractmethod - def approve(self, action: str, args: dict) -> tuple[bool, str | None]: + def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: """ Approves or denies an action. """ pass +class AsyncSupervisor(BaseModel, ABC): + """ + An abstract base class for async supervisors. + """ + + @abstractmethod + async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: + """ + Approves or denies an action asynchronously. + """ + pass + + class RuleBasedSupervisor(Supervisor): """ A supervisor that uses a set of rules to approve or deny actions. @@ -43,7 +56,7 @@ class RuleBasedSupervisor(Supervisor): rules: list[Rule] - def approve(self, action: str, args: dict) -> tuple[bool, str | None]: + def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: """ Approves or denies an action based on the rules. """ @@ -55,3 +68,58 @@ def approve(self, action: str, args: dict) -> tuple[bool, str | None]: if not approved: return False, error_message return True, None + + +class AsyncRuleBasedSupervisor(AsyncSupervisor): + """ + An async supervisor that uses a set of rules to approve or deny actions. + """ + + rules: list[Rule] + + async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: + """ + Approves or denies an action based on the rules asynchronously. + """ + for rule in self.rules: + if rule.tool_name == action: + for arg_name, validation_fn in rule.validations.items(): + if arg_name in args: + approved, error_message = validation_fn(args[arg_name]) + if not approved: + return False, error_message + return True, None + + +class SimpleSupervisor(Supervisor): + """ + A simple supervisor that approves every other tool call. + """ + + counter: int = 0 + + def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: + """ + Approves or denies an action. + """ + self.counter += 1 + if self.counter % 2 == 0: + return True, None + return False, "Denied by SimpleSupervisor" + + +class AsyncSimpleSupervisor(AsyncSupervisor): + """ + A simple async supervisor that approves every other tool call. + """ + + counter: int = 0 + + async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: + """ + Approves or denies an action asynchronously. + """ + self.counter += 1 + if self.counter % 2 == 0: + return True, None + return False, "Denied by AsyncSimpleSupervisor" diff --git a/src/talos/tools/base.py b/src/talos/tools/base.py index d005130d..18d1b54c 100644 --- a/src/talos/tools/base.py +++ b/src/talos/tools/base.py @@ -32,6 +32,28 @@ def supervise(self, invocation: T) -> tuple[bool, str]: raise NotImplementedError +class AsyncSupervisor(ABC, Generic[T]): + """ + An async supervisor can be used to analyze a tool invocation and determine if it is + malicious or not. This is the async version of the Supervisor interface. + """ + + @abstractmethod + async def supervise_async(self, invocation: T) -> tuple[bool, str]: + """ + Analyze the tool invocation and determine if it is malicious or not. + + Args: + invocation: The tool invocation to analyze. + + Returns: + A tuple of a boolean and a string. If the invocation is malicious, + the boolean is False and the string is an error message. Otherwise, + the boolean is True and the string is empty. + """ + raise NotImplementedError + + class SupervisedTool(BaseTool): """ A tool that has an optional supervisor. When a tool call is submitted, it @@ -41,6 +63,7 @@ class SupervisedTool(BaseTool): """ supervisor: Supervisor[Any] | None = Field(default=None) + async_supervisor: AsyncSupervisor[Any] | None = Field(default=None) def _run(self, *args: Any, **kwargs: Any) -> Any: if self.supervisor: @@ -50,8 +73,12 @@ def _run(self, *args: Any, **kwargs: Any) -> Any: return self._run_unsupervised(*args, **kwargs) async def _arun(self, *args: Any, **kwargs: Any) -> Any: - if self.supervisor: - # TODO: Add support for async supervisors. + if self.async_supervisor: + ok, message = await self.async_supervisor.supervise_async({"args": args, "kwargs": kwargs}) + if not ok: + return message + elif self.supervisor: + # Fallback to sync supervisor for backward compatibility ok, message = self.supervisor.supervise({"args": args, "kwargs": kwargs}) if not ok: return message diff --git a/src/talos/tools/supervised_tool.py b/src/talos/tools/supervised_tool.py index bf7d66f6..16765fdc 100644 --- a/src/talos/tools/supervised_tool.py +++ b/src/talos/tools/supervised_tool.py @@ -4,7 +4,7 @@ from langchain_core.tools import BaseTool -from talos.hypervisor.supervisor import Supervisor +from talos.hypervisor.supervisor import Supervisor, AsyncSupervisor class SupervisedTool(BaseTool): @@ -14,6 +14,7 @@ class SupervisedTool(BaseTool): tool: BaseTool supervisor: Supervisor | None = None + async_supervisor: AsyncSupervisor | None = None messages: list def set_supervisor(self, supervisor: Supervisor | None): @@ -22,6 +23,12 @@ def set_supervisor(self, supervisor: Supervisor | None): """ self.supervisor = supervisor + def set_async_supervisor(self, async_supervisor: AsyncSupervisor | None): + """ + Sets the async supervisor for the tool. + """ + self.async_supervisor = async_supervisor + def _run(self, *args: Any, **kwargs: Any) -> Any: """ Runs the tool. @@ -34,3 +41,23 @@ def _run(self, *args: Any, **kwargs: Any) -> Any: else: return error_message or f"Tool call to '{self.name}' denied by supervisor." return self.tool.run(tool_input, **kwargs) + + async def _arun(self, *args: Any, **kwargs: Any) -> Any: + """ + Runs the tool asynchronously. + """ + tool_input = args[0] if args else kwargs + if self.async_supervisor: + approved, error_message = await self.async_supervisor.approve_async(self.name, tool_input) + if approved: + return await self.tool.arun(tool_input, **kwargs) + else: + return error_message or f"Tool call to '{self.name}' denied by async supervisor." + elif self.supervisor: + # Fallback to sync supervisor for backward compatibility + approved, error_message = self.supervisor.approve(self.name, tool_input) + if approved: + return await self.tool.arun(tool_input, **kwargs) + else: + return error_message or f"Tool call to '{self.name}' denied by supervisor." + return await self.tool.arun(tool_input, **kwargs) diff --git a/tests/simple_supervisor.py b/tests/simple_supervisor.py index 7ce53091..e11d9c76 100644 --- a/tests/simple_supervisor.py +++ b/tests/simple_supervisor.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from talos.hypervisor.supervisor import Supervisor +from talos.hypervisor.supervisor import Supervisor, AsyncSupervisor if TYPE_CHECKING: from talos.core.agent import Agent @@ -29,3 +29,26 @@ def approve(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: if self.counter % 2 == 0: return True, None return False, "Denied by SimpleSupervisor" + + +class AsyncSimpleSupervisor(AsyncSupervisor): + """ + A simple async supervisor that approves every other tool call. + """ + + counter: int = 0 + + def set_agent(self, agent: "Agent"): + """ + Sets the agent to be supervised. + """ + pass + + async def approve_async(self, action: str, args: dict[str, Any]) -> tuple[bool, str | None]: + """ + Approves or denies an action asynchronously. + """ + self.counter += 1 + if self.counter % 2 == 0: + return True, None + return False, "Denied by AsyncSimpleSupervisor" diff --git a/tests/test_async_supervisors.py b/tests/test_async_supervisors.py new file mode 100644 index 00000000..ecdcdbe1 --- /dev/null +++ b/tests/test_async_supervisors.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +from pydantic import BaseModel + +from talos.tools.base import SupervisedTool, Supervisor, AsyncSupervisor + + +class SimpleSupervisor(Supervisor[Any]): + """ + A simple supervisor that approves every other tool call. + """ + + counter: int = 0 + + def supervise(self, invocation: Any) -> tuple[bool, str]: + self.counter += 1 + if self.counter % 2 == 0: + return True, "" + return False, "Denied by SimpleSupervisor" + + +class AsyncSimpleSupervisor(AsyncSupervisor[Any]): + """ + A simple async supervisor that approves every other tool call. + """ + + counter: int = 0 + + async def supervise_async(self, invocation: Any) -> tuple[bool, str]: + self.counter += 1 + if self.counter % 2 == 0: + return True, "" + return False, "Denied by AsyncSimpleSupervisor" + + +class DummyTool(SupervisedTool): + """ + A dummy tool for testing. + """ + + name: str = "dummy_tool" + description: str = "A dummy tool for testing" + args_schema: type[BaseModel] = BaseModel + + def _run_unsupervised(self, *args: Any, **kwargs: Any) -> Any: + return "dummy result" + + async def _arun_unsupervised(self, *args: Any, **kwargs: Any) -> Any: + return "async dummy result" + + +def test_supervised_tool_supervised() -> None: + supervisor = SimpleSupervisor() + supervised_tool = DummyTool() + supervised_tool.supervisor = supervisor + + # First call should be denied + result = supervised_tool._run({"x": 1}) + assert result == "Denied by SimpleSupervisor" + + # Second call should be approved + result = supervised_tool._run({"x": 1}) + assert result == "dummy result" + + +def test_supervised_tool_unsupervised() -> None: + supervised_tool = DummyTool() + supervised_tool.supervisor = None + + result = supervised_tool._run({"x": 1}) + assert result == "dummy result" + + +async def test_async_supervised_tool_supervised() -> None: + async_supervisor = AsyncSimpleSupervisor() + supervised_tool = DummyTool() + supervised_tool.async_supervisor = async_supervisor + + # First call should be denied + result = await supervised_tool._arun({"x": 1}) + assert result == "Denied by AsyncSimpleSupervisor" + + # Second call should be approved + result = await supervised_tool._arun({"x": 1}) + assert result == "async dummy result" + + +async def test_async_supervised_tool_unsupervised() -> None: + supervised_tool = DummyTool() + supervised_tool.async_supervisor = None + + result = await supervised_tool._arun({"x": 1}) + assert result == "async dummy result" + + +async def test_async_supervised_tool_sync_fallback() -> None: + """Test that async tools fall back to sync supervisor when no async supervisor is set.""" + supervisor = SimpleSupervisor() + supervised_tool = DummyTool() + supervised_tool.supervisor = supervisor + supervised_tool.async_supervisor = None + + # This should use the sync supervisor + result = await supervised_tool._arun({"x": 1}) + assert result == "Denied by SimpleSupervisor" + + +def test_async_supervisor_direct() -> None: + """Test the async supervisor directly.""" + async_supervisor = AsyncSimpleSupervisor() + + # Test the first call (should be denied) + result = asyncio.run(async_supervisor.supervise_async({"test": "data"})) + assert result == (False, "Denied by AsyncSimpleSupervisor") + + # Test the second call (should be approved) + result = asyncio.run(async_supervisor.supervise_async({"test": "data"})) + assert result == (True, "") \ No newline at end of file diff --git a/tests/test_supervised_tool.py b/tests/test_supervised_tool.py index 65d95962..af5f6e48 100644 --- a/tests/test_supervised_tool.py +++ b/tests/test_supervised_tool.py @@ -1,9 +1,10 @@ from __future__ import annotations +import asyncio from langchain_core.tools import tool from talos.tools.supervised_tool import SupervisedTool -from tests.simple_supervisor import SimpleSupervisor +from tests.simple_supervisor import SimpleSupervisor, AsyncSimpleSupervisor @tool @@ -12,6 +13,11 @@ def dummy_tool(x: int) -> int: return x * 2 +async def dummy_async_tool(x: int) -> int: + """A dummy async tool.""" + return x * 2 + + def test_supervised_tool_unsupervised() -> None: supervised_tool = SupervisedTool( tool=dummy_tool, @@ -38,3 +44,74 @@ def test_supervised_tool_supervised() -> None: assert supervised_tool.run({"x": 1}) == 2 assert supervised_tool.run({"x": 1}) == "Denied by SimpleSupervisor" assert supervised_tool.run({"x": 1}) == 2 + + +async def test_async_supervised_tool_unsupervised() -> None: + supervised_tool = SupervisedTool( + tool=dummy_tool, + supervisor=None, + async_supervisor=None, + messages=[], + name=dummy_tool.name, + description=dummy_tool.description, + args_schema=dummy_tool.args_schema, + ) + result = await supervised_tool._arun({"x": 1}) + assert result == 2 + + +async def test_async_supervised_tool_with_async_supervisor() -> None: + async_supervisor = AsyncSimpleSupervisor() + supervised_tool = SupervisedTool( + tool=dummy_tool, + supervisor=None, + async_supervisor=async_supervisor, + messages=[], + name=dummy_tool.name, + description=dummy_tool.description, + args_schema=dummy_tool.args_schema, + ) + + # First call should be denied + result = await supervised_tool._arun({"x": 1}) + assert result == "Denied by AsyncSimpleSupervisor" + + # Second call should be approved + result = await supervised_tool._arun({"x": 1}) + assert result == 2 + + +async def test_async_supervised_tool_sync_fallback() -> None: + """Test that async tools fall back to sync supervisor when no async supervisor is set.""" + supervisor = SimpleSupervisor() + supervised_tool = SupervisedTool( + tool=dummy_tool, + supervisor=supervisor, + async_supervisor=None, + messages=[], + name=dummy_tool.name, + description=dummy_tool.description, + args_schema=dummy_tool.args_schema, + ) + + # This should use the sync supervisor + result = await supervised_tool._arun({"x": 1}) + assert result == "Denied by SimpleSupervisor" + + +def test_supervised_tool_set_async_supervisor() -> None: + """Test setting async supervisor on supervised tool.""" + supervised_tool = SupervisedTool( + tool=dummy_tool, + supervisor=None, + messages=[], + name=dummy_tool.name, + description=dummy_tool.description, + args_schema=dummy_tool.args_schema, + ) + + async_supervisor = AsyncSimpleSupervisor() + supervised_tool.set_async_supervisor(async_supervisor) + + assert supervised_tool.async_supervisor is not None + assert supervised_tool.async_supervisor == async_supervisor diff --git a/tests/test_supervisor.py b/tests/test_supervisor.py index 27ec767d..481a2814 100644 --- a/tests/test_supervisor.py +++ b/tests/test_supervisor.py @@ -1,8 +1,9 @@ from __future__ import annotations +import asyncio from langchain_core.tools import tool -from talos.hypervisor.supervisor import Rule, RuleBasedSupervisor +from talos.hypervisor.supervisor import Rule, RuleBasedSupervisor, AsyncRuleBasedSupervisor from talos.tools.supervised_tool import SupervisedTool @@ -38,3 +39,67 @@ def dummy_tool(x: int) -> int: # Test that the supervisor denies an invalid action. result = supervised_tool.run({"x": -1}) assert result == "x must be greater than 0" + + +async def test_async_rule_based_supervisor(): + """ + Tests that the async rule-based supervisor correctly approves or denies actions. + """ + + @tool + def dummy_tool(x: int) -> int: + """A dummy tool.""" + return x * 2 + + rules = [ + Rule( + tool_name="dummy_tool", + validations={"x": lambda x: (x > 0, "x must be greater than 0") if x <= 0 else (True, None)}, + ) + ] + async_supervisor = AsyncRuleBasedSupervisor(rules=rules) + supervised_tool = SupervisedTool( + tool=dummy_tool, + supervisor=None, + async_supervisor=async_supervisor, + messages=[], + name=dummy_tool.name, + description=dummy_tool.description, + args_schema=dummy_tool.args_schema, + ) + + # Test that the async supervisor approves a valid action. + result = await supervised_tool._arun({"x": 1}) + assert result == 2 + + # Test that the async supervisor denies an invalid action. + result = await supervised_tool._arun({"x": -1}) + assert result == "x must be greater than 0" + + +async def test_async_rule_based_supervisor_direct(): + """ + Tests the async rule-based supervisor directly without the tool wrapper. + """ + rules = [ + Rule( + tool_name="test_tool", + validations={"x": lambda x: (x > 0, "x must be greater than 0") if x <= 0 else (True, None)}, + ) + ] + async_supervisor = AsyncRuleBasedSupervisor(rules=rules) + + # Test approval + approved, message = await async_supervisor.approve_async("test_tool", {"x": 1}) + assert approved is True + assert message is None + + # Test denial + approved, message = await async_supervisor.approve_async("test_tool", {"x": -1}) + assert approved is False + assert message == "x must be greater than 0" + + # Test unknown tool + approved, message = await async_supervisor.approve_async("unknown_tool", {"x": 1}) + assert approved is True + assert message is None