-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a Gymnasium env for Game of 24 and Gymnasium action executor
- Loading branch information
Showing
6 changed files
with
638 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from collections import defaultdict | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import gymnasium as gym | ||
from gymnasium.core import SupportsFloat | ||
from langchain_core.agents import AgentAction | ||
from langchain_core.tools import BaseTool | ||
|
||
from .simple_tools import AddTool, DivideTool, MultiplyTool, SubtractTool | ||
|
||
|
||
class GameOf24(gym.Env[str, AgentAction]):
    """Gymnasium environment for Game of 24 played through calculator tool calls.

    The state is a multiset of available numbers, stored as a mapping from a
    number to how many copies of it are currently available. Every successful
    step consumes two available numbers and adds the result of the chosen tool.

    NOTE(review): ``step`` always returns a reward of 0 with ``terminated`` and
    ``truncated`` set to False; deciding when the game is finished is
    presumably left to the caller -- confirm.
    """

    AVAILABLE_ACTIONS: Dict[str, BaseTool] = {
        "add": AddTool(),  # type: ignore[call-arg]
        "multiply": MultiplyTool(),  # type: ignore[call-arg]
        "subtract": SubtractTool(),  # type: ignore[call-arg]
        "divide": DivideTool(),  # type: ignore[call-arg]
    }

    def __init__(self, numbers: Optional[List[int]] = None):
        """Create the environment, optionally starting from ``numbers``."""
        self.numbers: Dict[float, int] = self._count_numbers(numbers or [])

    @staticmethod
    def _count_numbers(numbers: List[Any]) -> Dict[float, int]:
        """Build the number -> multiplicity mapping used as environment state."""
        # A defaultdict is required: _add_number increments keys that may be absent.
        counts: Dict[float, int] = defaultdict(int)
        for number in numbers:
            counts[number] += 1
        return counts

    def __str__(self):
        """Return the available numbers separated by spaces, repeating duplicates."""
        return " ".join(str(number) for number, count in self.numbers.items() for _ in range(count))

    def _add_number(self, number: float) -> None:
        """Make one more copy of ``number`` available."""
        self.numbers[number] += 1

    def _remove_number(self, number: float) -> None:
        """Consume one copy of ``number``; do nothing if it is not available."""
        if number not in self.numbers:
            return

        self.numbers[number] -= 1
        # Drop exhausted entries so membership tests and __str__ stay accurate.
        if self.numbers[number] == 0:
            del self.numbers[number]

    def _verify_arguments(self, number1: float, number2: float) -> bool:
        """Return True if both arguments can be taken from the available numbers."""
        # Using the same value twice requires two copies of it.
        if number1 == number2:
            return number1 in self.numbers and self.numbers[number1] >= 2

        return (
            number1 in self.numbers
            and self.numbers[number1] >= 1
            and number2 in self.numbers
            and self.numbers[number2] >= 1
        )

    def step(self, action: AgentAction) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Apply one tool call to two of the available numbers.

        Args:
            action: Tool call whose ``tool`` names an entry of
                ``AVAILABLE_ACTIONS`` and whose ``tool_input`` is a dict with
                numeric ``number1`` and ``number2`` entries.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation describes the outcome of the call; ``info["numbers"]``
            is the string form of the available numbers. Invalid calls
            (malformed arguments, unavailable numbers, unknown tool, division
            by zero) are reported through the observation and leave the state
            untouched.
        """
        reward, terminated, truncated = 0, False, False
        info: Dict[str, Any] = {"numbers": str(self)}

        # Validate explicitly instead of with `assert`, which is stripped under -O.
        tool_input = action.tool_input
        try:
            if not isinstance(tool_input, dict):
                raise TypeError("tool_input must be a dict")
            number1, number2 = float(tool_input["number1"]), float(tool_input["number2"])
        except (KeyError, TypeError, ValueError):
            observation = "Wrong arguments: a tool call expects numeric arguments 'number1' and 'number2'."
            return observation, reward, terminated, truncated, info

        if not self._verify_arguments(number1=number1, number2=number2):
            observation = "Wrong arguments: not all numbers given as arguments to a tool call are available."
            return observation, reward, terminated, truncated, info

        if action.tool not in self.AVAILABLE_ACTIONS:
            observation = f"Unknown tool. Currently available tools: {list(GameOf24.AVAILABLE_ACTIONS.keys())}."
            return observation, reward, terminated, truncated, info

        try:
            result = self.AVAILABLE_ACTIONS[action.tool]._run(number1, number2)
        except ZeroDivisionError:
            # Report the failure to the agent; the numbers are not consumed.
            observation = f"Calling {action.tool} with {number1} and {number2} failed: division by zero."
            return observation, reward, terminated, truncated, info

        self._remove_number(number1)
        self._remove_number(number2)
        self._add_number(result)

        observation = f"Calling {action.tool} with {number1} and {number2} leads to {result}."
        info = {"numbers": str(self)}

        return observation, reward, terminated, truncated, info

    def reset(
        self,
        *,
        seed: int | None = None,
        options: Dict[str, Any] | None = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Reset the state, optionally replaying a recorded trajectory.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Optional dict. ``options["numbers"]`` gives the initial
                numbers (none when absent); ``options["trajectory"]`` is a list
                of ``(action, observation)`` pairs whose actions are re-applied
                in order after the numbers are set.

        Returns:
            The observation ``"Reset environment."`` and an info dict whose
            ``"numbers"`` entry is the string form of the available numbers.
        """
        super().reset(seed=seed)

        numbers = options.get("numbers", []) if options else []
        self.numbers = self._count_numbers(numbers)

        if options is not None and "trajectory" in options:
            trajectory: List[Tuple[AgentAction, str]] = options["trajectory"]
            # Only the actions are replayed; the recorded observations are ignored.
            for action, _ in trajectory:
                self.step(action)

        return "Reset environment.", {"numbers": str(self)}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,47 @@ | ||
from langchain_core.tools import tool | ||
from typing import Optional, Type | ||
|
||
from langchain.callbacks.manager import CallbackManagerForToolRun | ||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.tools import BaseTool, StructuredTool, tool | ||
|
||
@tool
def add(number1: float, number2: float) -> float:
    """Adds number1 with number2."""
    # NOTE(review): docstring kept verbatim; the @tool decorator presumably
    # exposes it as the tool description -- confirm.
    total = number1 + number2
    return total
|
||
# Argument schema shared by the calculator tools: two float operands.
class CalculatorInput(BaseModel):
    # Left-hand operand of the operation.
    number1: float = Field(
        description="first number",
    )
    # Right-hand operand of the operation.
    number2: float = Field(
        description="second number",
    )
|
||
@tool
def subtract(number1: float, number2: float) -> float:
    """Subtracts number2 from number1."""
    # Argument order matters: the second argument is taken away from the first.
    return number1 - number2
|
||
# Calculator tool that adds its two arguments.
class AddTool(BaseTool):
    name = "add"
    description = "Adds two numbers."
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(
        self,
        number1: float,
        number2: float,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> float:
        """Return the sum of ``number1`` and ``number2``; ``run_manager`` is not used."""
        total = number1 + number2
        return total


@tool
def multiply(number1: float, number2: float) -> float:
    """Multiplies number1 by number2."""
    # NOTE(review): docstring kept verbatim; the @tool decorator presumably
    # exposes it as the tool description -- confirm.
    product = number1 * number2
    return product
|
||
|
||
@tool
def divide(number1: float, number2: float) -> float:
    """Divides number1 by number2."""
    # NOTE(review): docstring kept verbatim; the @tool decorator presumably
    # exposes it as the tool description -- confirm.
    quotient = number1 / number2
    return quotient
# Calculator tool that subtracts its second argument from its first.
class SubtractTool(BaseTool):
    name = "subtract"
    # The description must match _run (number1 - number2); the previous text
    # stated the operands the other way round.
    description = "Subtracts number2 from number1."
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(self, number1: float, number2: float, run_manager: Optional[CallbackManagerForToolRun] = None) -> float:
        """Return ``number1 - number2``; ``run_manager`` is not used."""
        return number1 - number2
|
||
|
||
# Calculator tool that multiplies its two arguments.
class MultiplyTool(BaseTool):
    name = "multiply"
    description = "Multiplies two numbers."
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(
        self,
        number1: float,
        number2: float,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> float:
        """Return the product of ``number1`` and ``number2``; ``run_manager`` is not used."""
        product = number1 * number2
        return product
|
||
|
||
# Calculator tool that divides its first argument by its second.
class DivideTool(BaseTool):
    name = "divide"
    description = "Divides number1 by number2."
    args_schema: Type[BaseModel] = CalculatorInput

    def _run(
        self,
        number1: float,
        number2: float,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> float:
        """Return ``number1 / number2``; ``run_manager`` is not used."""
        quotient = number1 / number2
        return quotient
Oops, something went wrong.