Skip to content

Commit

Permalink
Add a Gymnasium env for Game of 24 and Gymnasium action executor
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Mar 12, 2024
1 parent e558541 commit 5dd040a
Show file tree
Hide file tree
Showing 6 changed files with 638 additions and 138 deletions.
97 changes: 97 additions & 0 deletions environments/game_of_24/common/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import gymnasium as gym
from gymnasium.core import SupportsFloat
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool

from .simple_tools import AddTool, DivideTool, MultiplyTool, SubtractTool


class GameOf24(gym.Env[str, AgentAction]):
AVAILABLE_ACTIONS: Dict[str, BaseTool] = {
"add": AddTool(), # type: ignore[call-arg]
"multiply": MultiplyTool(), # type: ignore[call-arg]
"subtract": SubtractTool(), # type: ignore[call-arg]
"divide": DivideTool(), # type: ignore[call-arg]
}

def __init__(self, numbers: Optional[List[int]] = None):
self.numbers: Dict[float, int] = defaultdict(int)
if numbers:
for number in numbers:
self.numbers[number] += 1

def __str__(self):
return " ".join([str(key) for key, value in self.numbers.items() for _ in range(value)])

def _add_number(self, number: float) -> None:
self.numbers[number] += 1

def _remove_number(self, number: float) -> None:
if number not in self.numbers:
return

self.numbers[number] -= 1
if self.numbers[number] == 0:
del self.numbers[number]

def _verify_arguments(self, number1: float, number2: float) -> bool:
if number1 == number2:
return number1 in self.numbers and self.numbers[number1] >= 2

return (
number1 in self.numbers
and self.numbers[number1] >= 1
and number2 in self.numbers
and self.numbers[number2] >= 1
)

def step(self, action: AgentAction) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
observation, reward, terminated, truncated, info = None, 0, False, False, {"numbers": str(self)}

assert isinstance(action.tool_input, dict)
number1, number2 = float(action.tool_input["number1"]), float(action.tool_input["number2"])

if not self._verify_arguments(number1=number1, number2=number2):
observation = "Wrong arguments: not all numbers given as arguments to a tool call are available."
return observation, reward, terminated, truncated, info

if action.tool not in self.AVAILABLE_ACTIONS:
observation = f"Unknown tool. Currently available tools: {list(GameOf24.AVAILABLE_ACTIONS.keys())}."
return observation, reward, terminated, truncated, info

result = self.AVAILABLE_ACTIONS[action.tool]._run(number1, number2)

self._remove_number(number1)
self._remove_number(number2)
self._add_number(result)

observation, info = (
f"Calling {action.tool} with {number1} and {number2} leads to {result}.",
{"numbers": str(self)},
)

return observation, reward, terminated, truncated, info

def reset(
self,
*,
seed: int | None = None,
options: Dict[str, Any] | None = None,
) -> Tuple[str, Dict[str, Any]]:
super().reset(seed=seed)

numbers = options.get("numbers", []) if options else []
self.numbers = defaultdict(int)
for number in numbers:
self.numbers[number] += 1

if options is None or "trajectory" not in options:
return "Reset environment.", {"numbers": str(self)}

trajectory: List[Tuple[AgentAction, str]] = options["trajectory"]
for action, observation in trajectory:
self.step(action)
return "Reset environment.", {"numbers": str(self)}
56 changes: 39 additions & 17 deletions environments/game_of_24/common/simple_tools.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,47 @@
from langchain_core.tools import tool
from typing import Optional, Type

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool

@tool
def add(number1: float, number2: float) -> float:
"""Adds number1 with number2."""
return number1 + number2

class CalculatorInput(BaseModel):
number1: float = Field(description="first number")
number2: float = Field(description="second number")

@tool
def subtract(number1: float, number2: float) -> float:
"""Subtracts number2 from number 1."""
return number1 - number2

class AddTool(BaseTool):
name = "add"
description = "Adds two numbers."
args_schema: Type[BaseModel] = CalculatorInput

@tool
def multiply(number1: float, number2: float) -> float:
"""Multiplies number1 by number2."""
return number1 * number2
def _run(self, number1: float, number2: float, run_manager: Optional[CallbackManagerForToolRun] = None) -> float:
"""Use the tool."""
return number1 + number2


@tool
def divide(number1: float, number2: float) -> float:
"""Divides number1 by number2."""
return number1 / number2
class SubtractTool(BaseTool):
name = "subtract"
description = "Subtracts number1 from number2."
args_schema: Type[BaseModel] = CalculatorInput

def _run(self, number1: float, number2: float, run_manager: Optional[CallbackManagerForToolRun] = None) -> float:
return number1 - number2


class MultiplyTool(BaseTool):
name = "multiply"
description = "Multiplies two numbers."
args_schema: Type[BaseModel] = CalculatorInput

def _run(self, number1: float, number2: float, run_manager: Optional[CallbackManagerForToolRun] = None) -> float:
return number1 * number2


class DivideTool(BaseTool):
name = "divide"
description = "Divides number1 by number2."
args_schema: Type[BaseModel] = CalculatorInput

def _run(self, number1: float, number2: float, run_manager: Optional[CallbackManagerForToolRun] = None) -> float:
return number1 / number2
Loading

0 comments on commit 5dd040a

Please sign in to comment.