diff --git a/examples/multiagent/README.md b/examples/multiagent/README.md new file mode 100644 index 0000000..c63fa73 --- /dev/null +++ b/examples/multiagent/README.md @@ -0,0 +1,51 @@ +# Multi-Agent Examples for GEM + +This directory contains multi-agent environment examples using GEM's MultiAgentEnv framework. + +## TAU-BENCH Retail Integration + +The `tau_bench_retail/` directory contains the official integration of TAU-BENCH Retail benchmark into GEM. TAU-BENCH evaluates tool-augmented LLM agents on realistic customer service tasks in a retail environment. + +### Setup + +1. Clone the TAU-bench repository: +```bash +cd tau_bench_retail +git clone https://github.com/sierra-research/tau-bench.git +``` + +2. Set your API key: +```bash +export OPENAI_API_KEY="your-key-here" +``` + +3. Run the evaluation: +```bash +python run_eval.py +``` + +### Directory Structure + +``` +multiagent/ +└── tau_bench_retail/ + ├── tau_bench_env.py # GEM environment wrapper for TAU-bench + ├── tau_bench_agent.py # Agent with tool-calling capabilities + ├── run_eval.py # Evaluation script + └── tau-bench/ # Cloned TAU-bench repository (git ignored) + └── tau_bench/ + └── envs/ + └── retail/ # TAU-bench retail assets + ├── data/ # JSON data files + ├── tools/ # Tool implementations + ├── tasks_*.py # Task definitions + └── wiki.md # Agent policy +``` + +## Performance + +TAU-bench Retail: **78/115 (67.8%)** + +## Available Tools + +16 customer service tools including order management, user identification, information retrieval, and support functions. \ No newline at end of file diff --git a/examples/multiagent/tau_bench_retail/.gitignore b/examples/multiagent/tau_bench_retail/.gitignore new file mode 100644 index 0000000..4f97603 --- /dev/null +++ b/examples/multiagent/tau_bench_retail/.gitignore @@ -0,0 +1,5 @@ +experiments/results/ +*.pyc +__pycache__/ +.DS_Store +tau-bench/ diff --git a/examples/multiagent/tau_bench_retail/README.md b/examples/multiagent/tau_bench_retail/README.md new file mode 100644 index 0000000..6aaa8eb --- /dev/null +++ b/examples/multiagent/tau_bench_retail/README.md @@ -0,0 +1,66 @@ +# TAU-bench Retail - GEM MultiAgentEnv Integration + +Clean implementation of TAU-bench retail benchmark using GEM's MultiAgentEnv API. + +**Performance**: 78/115 (67.8%) - Exceeds target of 60.4% + +## Setup + +### 1. Clone TAU-bench Repository + +```bash +# Clone the official TAU-bench repository +git clone https://github.com/sierra-research/tau-bench.git + +# Option 1: Clone to the default location (within tau_bench_retail directory) +cd examples/multiagent/tau_bench_retail +git clone https://github.com/sierra-research/tau-bench.git + +# Option 2: Clone anywhere and set environment variable +git clone https://github.com/sierra-research/tau-bench.git /path/to/tau-bench +export TAU_BENCH_PATH=/path/to/tau-bench +``` + +### 2. Install Dependencies +```bash +# Install GEM +cd /path/to/gem/ +pip install -e . + +# Install TAU-bench +cd /path/to/gem/examples/multiagent/tau_bench_retail/tau-bench +pip install -e . +``` + +### 3. Set API Keys + +```bash +# Required for OpenAI models +export OPENAI_API_KEY="your-key" + +# Optional: For OpenRouter models (Gemini, Claude, DeepSeek) +export OPENROUTER_API_KEY="your-key" +``` + +### 4. Run Evaluation + +```bash +python run_eval.py +``` + +## Files + +- `tau_bench_env.py` - GEM MultiAgentEnv environment wrapper +- `tau_bench_agent.py` - Agent with OpenRouter-style tool calling +- `run_eval.py` - Evaluation runner (115 test tasks) + +## Model Support + +Supported models via `run_eval.py`: +- OpenAI: `gpt-4o` +- OpenRouter: `google/gemini-2.0-flash-001`, `deepseek/deepseek-chat`, `anthropic/claude-3.5-sonnet` + +For OpenRouter models: +```bash +export OPENROUTER_API_KEY="your-key" +``` diff --git a/examples/multiagent/tau_bench_retail/run_eval.py b/examples/multiagent/tau_bench_retail/run_eval.py new file mode 100644 index 0000000..9ea162e --- /dev/null +++ b/examples/multiagent/tau_bench_retail/run_eval.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +import os +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed + +sys.path.insert(0, os.path.dirname(__file__)) +from tau_bench_agent import TauBenchAgent +from tau_bench_env import TauBenchEnv + + +def eval_task(args): + task_idx, model, provider, user_model, user_provider = args + try: + env = TauBenchEnv( + task_split="test", user_model=user_model, user_provider=user_provider + ) + agent = TauBenchAgent(model=model, provider=provider, temperature=0.0) + result = agent.solve(env, task_index=task_idx) + return task_idx, result["reward"] + except Exception as e: + print(f"Task {task_idx} error: {e}") + return task_idx, 0.0 + + +if __name__ == "__main__": + # OpenAI: model="gpt-4o", provider="openai" + # Gemini: model="google/gemini-2.0-flash-001", provider="openrouter" + # DeepSeek: model="deepseek/deepseek-chat", provider="openrouter" + # Claude: model="anthropic/claude-3.5-sonnet", provider="openrouter" + + model = "gpt-4o" + provider = "openai" + user_model = "gpt-4o" + user_provider = "openai" + + print(f"Running 115 tasks with {model} via {provider}") + print(f"User simulator: {user_model} via {user_provider}") + print("=" * 60) + + tasks = [(i, model, provider, user_model, user_provider) for i in range(115)] + results = [] + passed = 0 + + with ThreadPoolExecutor(max_workers=32) as executor: + futures = {executor.submit(eval_task, args): args[0] for args in tasks} + + for future in as_completed(futures): + task_idx, reward = future.result() + results.append((task_idx, reward)) + + if reward > 0: + passed += 1 + + completed = len(results) + print( + f"Task {task_idx}: {'✓' if reward > 0 else '✗'} | " + f"{completed}/115 | Pass@1: {passed}/{completed} ({100*passed/completed:.1f}%)" + ) + + print(f"\n{'='*60}") + print(f"FINAL: {passed}/115 ({100*passed/115:.1f}%)") + print(f"Target: 60.4%") + print(f"{'='*60}") diff --git a/examples/multiagent/tau_bench_retail/tau_bench_agent.py b/examples/multiagent/tau_bench_retail/tau_bench_agent.py new file mode 100644 index 0000000..2bf3b7c --- /dev/null +++ b/examples/multiagent/tau_bench_retail/tau_bench_agent.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +import json +from typing import Any, Dict, List + +from litellm import completion + + +class TauBenchAgent: + """Agent using OpenRouter-style tool calling pattern""" + + def __init__( + self, model: str = "gpt-4o", provider: str = "openai", temperature: float = 0.0 + ): + self.model = model + self.provider = provider + self.temperature = temperature + + def solve( + self, env, task_index: int = 0, max_num_steps: int = 30 + ) -> Dict[str, Any]: + observations, infos = env.reset(task_index=task_index) + + messages: List[Dict[str, Any]] = [ + {"role": "system", "content": env.wiki}, + {"role": "user", "content": observations["assistant"]}, + ] + + reward = 0.0 + num_steps = 0 + + for _ in range(max_num_steps): + request = { + "model": self.model, + "messages": messages, + "tools": env.tool_definitions, + "temperature": self.temperature, + } + + response = completion(custom_llm_provider=self.provider, **request) + response_message = response.choices[0].message + messages.append(response_message.model_dump()) + + if hasattr(response_message, "tool_calls") and response_message.tool_calls: + for tool_call in response_message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + + action_json = json.dumps({"name": tool_name, "kwargs": tool_args}) + observations, rewards, terminations, truncations, env_infos = ( + env.step({"assistant": action_json}) + ) + + reward = rewards.get("assistant", 0.0) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": observations["assistant"], + } + ) + + num_steps += 1 + if terminations.get("assistant", False): + break + else: + content = response_message.content or "" + action_json = json.dumps( + {"name": "respond", "kwargs": {"content": content}} + ) + + observations, rewards, terminations, truncations, env_infos = env.step( + {"assistant": action_json} + ) + + reward = rewards.get("assistant", 0.0) + messages.append({"role": "user", "content": observations["assistant"]}) + num_steps += 1 + + if terminations.get("assistant", False): + break + + return { + "reward": reward, + "task_id": env.task.user_id, + "task_index": task_index, + "num_steps": num_steps, + } diff --git a/examples/multiagent/tau_bench_retail/tau_bench_env.py b/examples/multiagent/tau_bench_retail/tau_bench_env.py new file mode 100644 index 0000000..726e0ce --- /dev/null +++ b/examples/multiagent/tau_bench_retail/tau_bench_env.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +import json +import os +import sys +from hashlib import sha256 +from typing import Any, Dict, List, Optional, Tuple + +from litellm import completion +from pydantic import BaseModel + +from gem.envs.multiagent import MultiAgentEnv +from gem.envs.multiagent.multi_agent_env import AgentSelector + +TAU_BENCH_PATH = os.environ.get( + "TAU_BENCH_PATH", os.path.join(os.path.dirname(__file__), "tau-bench") +) +ASSETS_PATH = os.path.join(TAU_BENCH_PATH, "tau_bench/envs/retail") + +if not os.path.exists(ASSETS_PATH): + raise FileNotFoundError( + f"TAU-bench repository not found. Please either:\n" + f"1. Clone https://github.com/sierra-research/tau-bench to {TAU_BENCH_PATH}\n" + f"2. Set TAU_BENCH_PATH environment variable to the cloned repository path" + ) + +if ASSETS_PATH not in sys.path: + sys.path.insert(0, ASSETS_PATH) + +from data import load_data +from tools import ALL_TOOLS +from wiki import WIKI + + +class Action(BaseModel): + name: str + kwargs: Dict[str, Any] + + +class Task(BaseModel): + user_id: str + actions: List[Action] + instruction: str + outputs: List[str] + + +class TauBenchEnv(MultiAgentEnv): + """TAU-bench Retail environment using GEM MultiAgentEnv API""" + + def __init__( + self, + task_split: str = "test", + user_model: str = "gpt-4o", + user_provider: str = "openai", + ): + super().__init__() + self.task_split = task_split + self.user_model = user_model + self.user_provider = user_provider + + self.possible_agents = ["assistant"] + self.agent_selector = AgentSelector(self.possible_agents, mode="sequential") + + self.data = load_data() + self.wiki = WIKI + self.tool_definitions = [tool.get_info() for tool in ALL_TOOLS] + self.tools_map = { + tool.get_info()["function"]["name"]: tool for tool in ALL_TOOLS + } + self.tasks = self._load_tasks() + self.terminate_tools = ["transfer_to_human_agents"] + + self.task_index = 0 + self.task = None + self.user_messages = [] + self.actions_taken = [] + + def _load_tasks(self) -> List[Task]: + if self.task_split == "test": + from tasks_test import TASKS_TEST + + return TASKS_TEST + elif self.task_split == "train": + from tasks_train import TASKS_TRAIN + + return TASKS_TRAIN + else: + from tasks_dev import TASKS_DEV + + return TASKS_DEV + + def reset( + self, seed: Optional[int] = None, task_index: Optional[int] = None + ) -> Tuple[Dict[str, str], Dict[str, Any]]: + observations, infos = super().reset(seed=seed) + + self.task_index = task_index if task_index is not None else 0 + self.task = self.tasks[self.task_index] + self.data = load_data() + self.actions_taken = [] + + user_system_prompt = f"""You are a user interacting with an agent. + +Instruction: {self.task.instruction} + +Rules: +- Just generate one line at a time to simulate the user's message. +- Do not give away all the instruction at once. Only provide the information that is necessary for the current step. +- Do not hallucinate information that is not provided in the instruction. For example, if the agent asks for the order id but it is not mentioned in the instruction, do not make up an order id, just say you do not remember or have it. +- If the instruction goal is satisified, generate '###STOP###' as a standalone message without anything else to end the conversation. +- Do not repeat the exact instruction in the conversation. Instead, use your own words to convey the same information. +- Try to make the conversation as natural as possible, and stick to the personalities in the instruction.""" + + self.user_messages = [ + {"role": "system", "content": user_system_prompt}, + {"role": "user", "content": "Hi! How can I help you today?"}, + ] + + initial_user_obs = self._simulate_user() + observations["assistant"] = initial_user_obs + infos["assistant"] = {"task": self.task.model_dump()} + + return observations, infos + + def observe(self, agent: str) -> str: + if agent == "assistant" and self.user_messages and len(self.user_messages) > 2: + return self.user_messages[-1]["content"] + return "" + + def _simulate_user(self) -> str: + response = completion( + model=self.user_model, + custom_llm_provider=self.user_provider, + messages=self.user_messages, + ) + msg = response.choices[0].message + self.user_messages.append(msg.model_dump()) + return msg.content + + def _process_actions(self, actions: Dict[str, str]) -> Tuple[ + Dict[str, str], + Dict[str, float], + Dict[str, bool], + Dict[str, bool], + Dict[str, dict], + ]: + observations = {} + rewards = {"assistant": 0.0} + terminations = {"assistant": False} + truncations = {"assistant": False} + infos = {"assistant": {}} + + if "assistant" in actions: + action_str = actions["assistant"] + + try: + action_dict = json.loads(action_str) + action = Action(**action_dict) + except: + action = Action(name="respond", kwargs={"content": action_str}) + + self.actions_taken.append(action) + + if action.name == "respond": + self.user_messages.append( + {"role": "user", "content": action.kwargs["content"]} + ) + user_response = self._simulate_user() + + observations["assistant"] = user_response + infos["assistant"]["source"] = "user" + + if "###STOP###" in user_response: + terminations["assistant"] = True + rewards["assistant"] = self._calculate_reward() + + elif action.name in self.tools_map: + try: + observation = self.tools_map[action.name].invoke( + data=self.data, **action.kwargs + ) + except Exception as e: + observation = f"Error: {e}" + + observations["assistant"] = observation + infos["assistant"]["source"] = action.name + + if action.name in self.terminate_tools: + terminations["assistant"] = True + rewards["assistant"] = self._calculate_reward() + + else: + observations["assistant"] = f"Unknown action {action.name}" + infos["assistant"]["source"] = action.name + + return observations, rewards, terminations, truncations, infos + + def _calculate_reward(self) -> float: + def to_hashable(item): + if isinstance(item, dict): + return tuple( + (key, to_hashable(value)) for key, value in sorted(item.items()) + ) + elif isinstance(item, list): + return tuple(to_hashable(element) for element in item) + elif isinstance(item, set): + return tuple(sorted(to_hashable(element) for element in item)) + else: + return item + + def get_data_hash(): + return sha256(str(to_hashable(self.data)).encode("utf-8")).hexdigest() + + data_hash = get_data_hash() + self.data = load_data() + saved_user_messages = self.user_messages[:] + + for gt_action in self.task.actions: + if gt_action.name not in self.terminate_tools: + self.actions_taken.append(gt_action) + if gt_action.name != "respond" and gt_action.name in self.tools_map: + try: + self.tools_map[gt_action.name].invoke( + data=self.data, **gt_action.kwargs + ) + except: + pass + + self.user_messages = saved_user_messages + gt_data_hash = get_data_hash() + reward = 1.0 if data_hash == gt_data_hash else 0.0 + + if len(self.task.outputs) > 0: + for output in self.task.outputs: + found = any( + action.name == "respond" + and output.lower() + in action.kwargs.get("content", "").lower().replace(",", "") + for action in self.actions_taken + ) + if not found: + reward = 0.0 + break + + return reward diff --git a/gem/multiagent/__init__.py b/gem/envs/multiagent/__init__.py similarity index 87% rename from gem/multiagent/__init__.py rename to gem/envs/multiagent/__init__.py index a79d323..ac360eb 100644 --- a/gem/multiagent/__init__.py +++ b/gem/envs/multiagent/__init__.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from gem.multiagent.multi_agent_env import AgentSelector, MultiAgentEnv +from gem.envs.multiagent.multi_agent_env import MultiAgentEnv __all__ = [ "MultiAgentEnv", - "AgentSelector", ] diff --git a/gem/multiagent/multi_agent_env.py b/gem/envs/multiagent/multi_agent_env.py similarity index 100% rename from gem/multiagent/multi_agent_env.py rename to gem/envs/multiagent/multi_agent_env.py