Skip to content

Commit

Permalink
Fix Game of 24 environment
Browse files Browse the repository at this point in the history
  • Loading branch information
saridormi committed Mar 19, 2024
1 parent 8fe84ae commit 37b8b1e
Show file tree
Hide file tree
Showing 6 changed files with 369 additions and 789 deletions.
129 changes: 58 additions & 71 deletions environments/game_of_24/common/environment.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,81 @@
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Sequence

import gymnasium as gym
from gymnasium.core import SupportsFloat
from langchain_core.agents import AgentAction
from langchain_core.tools import BaseTool

from .simple_tools import AddTool, DivideTool, MultiplyTool, SubtractTool
from .tools import AddTool, MultiplyTool, SubtractTool, DivideTool
from planning_library.action_executors import DefaultActionExecutor


class GameOf24(gym.Env[str, AgentAction]):
AVAILABLE_ACTIONS: Dict[str, BaseTool] = {
"add": AddTool(), # type: ignore[call-arg]
"multiply": MultiplyTool(), # type: ignore[call-arg]
"subtract": SubtractTool(), # type: ignore[call-arg]
"divide": DivideTool(), # type: ignore[call-arg]
}
class GameOf24Env(gym.Env[str, AgentAction]):
def __init__(self, numbers: Optional[List[float | int]] = None):
self._action_executor = DefaultActionExecutor(
tools=[
AddTool(env=self), # type: ignore[call-arg]
MultiplyTool(env=self), # type: ignore[call-arg]
SubtractTool(env=self), # type: ignore[call-arg]
DivideTool(env=self), # type: ignore[call-arg]
]
)

def __init__(self, numbers: Optional[List[int]] = None):
self.numbers: Dict[float, int] = defaultdict(int)
self._numbers: Dict[float, int] = defaultdict(int)
if numbers:
for number in numbers:
self.numbers[number] += 1
self._numbers[float(number)] += 1

def __str__(self):
@property
def numbers(self) -> str:
return " ".join(
[str(key) for key, value in self.numbers.items() for _ in range(value)]
[str(key) for key, value in self._numbers.items() for _ in range(value)]
)

def _add_number(self, number: float) -> None:
self.numbers[number] += 1
@numbers.setter
def numbers(self, numbers: List[float | int]):
    """Replace the stored numbers with the given collection.

    Every value is converted to ``float`` and counted, so duplicates are
    kept as a multiplicity. A falsy argument (``None`` or an empty list)
    leaves the environment with no numbers.
    """
    counts: Dict[float, int] = defaultdict(int)
    # Bind the fresh counter first so the previous state is always dropped.
    self._numbers = counts
    for value in numbers or ():
        counts[float(value)] += 1

@property
def tools(self) -> Sequence[BaseTool]:
    """Tools held by this environment's action executor."""
    executor = self._action_executor
    return executor.tools

def is_success(self) -> bool:
    """Tell whether the game has been won.

    The game is won when exactly one number remains and it equals 24.
    The comparison is tolerant to floating-point rounding, because
    legitimate solutions that go through division (for example
    ``8 / (3 - 8 / 3)``) do not produce exactly ``24.0``.

    Returns:
        True if a single number is left and it is (numerically) 24.
    """
    import math

    # Expand multiplicities so that two copies of 24 do not count as a win.
    remaining = [
        number for number, count in self._numbers.items() for _ in range(count)
    ]
    return len(remaining) == 1 and math.isclose(remaining[0], 24.0, rel_tol=1e-9)

def _remove_number(self, number: float) -> None:
if number not in self.numbers:
def is_terminated(self) -> bool:
    """Tell whether no further operation is possible.

    The game ends when exactly one number is left. Multiplicities are
    summed rather than counting distinct values, so two equal numbers
    (e.g. two 4s) still allow another operation.

    Returns:
        True if exactly one number remains.
    """
    return sum(self._numbers.values()) == 1

def add_number(self, number: float) -> None:
    """Add one occurrence of ``number`` to the available numbers.

    The value is stored as a ``float`` key, matching how the constructor
    and the ``numbers`` setter store their input, so every number is
    rendered the same way (``4.0``, never ``4``).
    """
    self._numbers[float(number)] += 1

def remove_number(self, number: float) -> None:
if number not in self._numbers:
return

self.numbers[number] -= 1
if self.numbers[number] == 0:
del self.numbers[number]
self._numbers[number] -= 1
if self._numbers[number] == 0:
del self._numbers[number]

def _verify_arguments(self, number1: float, number2: float) -> bool:
def verify_arguments(self, number1: float, number2: float) -> bool:
if number1 == number2:
return number1 in self.numbers and self.numbers[number1] >= 2
return number1 in self._numbers and self._numbers[number1] >= 2

return (
number1 in self.numbers
and self.numbers[number1] >= 1
and number2 in self.numbers
and self.numbers[number2] >= 1
number1 in self._numbers
and self._numbers[number1] >= 1
and number2 in self._numbers
and self._numbers[number2] >= 1
)

def step(
self, action: AgentAction
) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
observation, reward, terminated, truncated, info = (
None,
0,
False,
False,
{"numbers": str(self)},
)

assert isinstance(action.tool_input, dict)
number1, number2 = (
float(action.tool_input["number1"]),
float(action.tool_input["number2"]),
)

if not self._verify_arguments(number1=number1, number2=number2):
observation = "Wrong arguments: not all numbers given as arguments to a tool call are available."
return observation, reward, terminated, truncated, info

if action.tool not in self.AVAILABLE_ACTIONS:
observation = f"Unknown tool. Currently available tools: {list(GameOf24.AVAILABLE_ACTIONS.keys())}."
return observation, reward, terminated, truncated, info

result = self.AVAILABLE_ACTIONS[action.tool]._run(number1, number2)

self._remove_number(number1)
self._remove_number(number2)
self._add_number(result)

observation, info = (
f"Calling {action.tool} with {number1} and {number2} leads to {result}.",
{"numbers": str(self)},
)

return observation, reward, terminated, truncated, info
result = self._action_executor.execute(action)
return result.observation

def reset(
self,
Expand All @@ -96,15 +85,13 @@ def reset(
) -> Tuple[str, Dict[str, Any]]:
super().reset(seed=seed)

numbers = options.get("numbers", []) if options else []
self.numbers = defaultdict(int)
for number in numbers:
self.numbers[number] += 1
self.numbers = options.get("numbers", []) if options else []

observation, info = "", {"numbers": self.numbers}

if options is None or "trajectory" not in options:
return "Reset environment.", {"numbers": str(self)}
if options is not None and "trajectory" in options:
for action in options["trajectory"]:
assert isinstance(action, AgentAction)
observation, reward, terminated, truncated, info = self.step(action)

trajectory: List[Tuple[AgentAction, str]] = options["trajectory"]
for action, observation in trajectory:
self.step(action)
return "Reset environment.", {"numbers": str(self)}
return observation, info
67 changes: 0 additions & 67 deletions environments/game_of_24/common/simple_tools.py

This file was deleted.

125 changes: 125 additions & 0 deletions environments/game_of_24/common/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from textwrap import dedent
from typing import Any, Tuple, Type, Dict

from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from gymnasium.core import SupportsFloat
import gymnasium as gym
from abc import ABC, abstractmethod


class BaseGameof24Tool(BaseModel, ABC):
    """Base tool for a Game of 24 environment.

    Environment is present as a field, but it won't be shown to models.

    Subclasses provide the arithmetic through ``_operation``; ``_run``
    checks the call against the environment, applies the operation to the
    environment's numbers and reports the outcome as an
    ``(observation, reward, terminated, truncated, info)`` tuple.
    """

    env: gym.Env = Field(exclude=True)

    class Config(BaseTool.Config):
        pass

    @abstractmethod
    def _operation(self, number1: float, number2: float) -> float:
        """Compute the arithmetic result for the two operands."""
        ...

    def _run(
        self,
        number1: float,
        number2: float,
        *args: Any,
        **kwargs: Any,
    ) -> Tuple[str, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Apply this tool's operation to two numbers held by the environment.

        Args:
            number1: The first operand; must currently be available.
            number2: The second operand; must currently be available.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
            ``reward`` is 1 only when the operation makes the environment
            report success, ``truncated`` is always False here, and
            ``info["numbers"]`` holds the environment's remaining numbers.
            Rejected calls (terminated game, unavailable operands, division
            by zero) leave the environment untouched and return a reward
            of 0 with an explanatory observation.
        """
        # gym.Env does not declare the Game of 24 helpers, so access them
        # through an untyped alias instead of repeating type-ignore comments.
        env: Any = self.env

        reward: SupportsFloat = 0
        truncated = False
        terminated = env.is_terminated()
        info: Dict[str, Any] = {"numbers": env.numbers}

        if terminated:
            observation = "The environment has already been terminated."
            return observation, reward, terminated, truncated, info

        if not env.verify_arguments(number1=number1, number2=number2):
            observation = (
                "Wrong arguments: not all numbers given as arguments to a tool "
                f"call are available (arguments: {number1}, {number2}, "
                f"available numbers: {env.numbers})."
            )
            return observation, reward, terminated, truncated, info

        # Compute before mutating the environment so a failing operation
        # (dividing by an available zero) cannot leave it half-updated.
        try:
            result = self._operation(number1=number1, number2=number2)
        except ZeroDivisionError:
            observation = (
                "Wrong arguments: division by zero is not allowed "
                f"(arguments: {number1}, {number2})."
            )
            return observation, reward, terminated, truncated, info

        env.remove_number(number1)
        env.remove_number(number2)
        env.add_number(result)

        observation = f"result of current arithmetical operation on {number1} and {number2} is {result}"
        reward = int(env.is_success())
        terminated = env.is_terminated()
        info = {"numbers": env.numbers}

        return observation, reward, terminated, truncated, info


# Argument schema for a two-operand arithmetic tool call; the tool classes in
# this file use it as their `args_schema`.
class CalculatorInput(BaseModel):
    # Left-hand operand (e.g. the minuend for subtraction, the dividend for division).
    number1: float = Field(
        description="The first argument in an arithmetical operation."
    )
    # Right-hand operand (e.g. the subtrahend for subtraction, the divisor for division).
    number2: float = Field(
        description="The second argument in an arithmetical operation."
    )


class AddTool(BaseGameof24Tool, BaseTool):
    # Game of 24 tool whose arithmetic step is addition.
    name: str = "add"
    description: str = (
        "\n"
        "Adds two numbers. Returns the following:\n"
        "* observation: the result of the addition;\n"
        "* reward: 1 when the goal is reached (24 is obtained), 0 otherwise;\n"
        "* terminated: if True, the game has ended: there's no possible actions anymore;\n"
        "* truncated: if True, the time limit has been exceeded;\n"
        "* info: the remaining numbers"
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _operation(self, number1: float, number2: float) -> float:
        """Return the sum of the two operands."""
        total = number1 + number2
        return total


class SubtractTool(BaseGameof24Tool, BaseTool):
    # Game of 24 tool whose arithmetic step is subtraction (first minus second).
    name: str = "subtract"
    description: str = (
        "\n"
        "Subtracts the second number from the first one. Returns the following:\n"
        "* observation: the result of the subtraction;\n"
        "* reward: 1 when the goal is reached (24 is obtained), 0 otherwise;\n"
        "* terminated: if True, the game has ended: there's no possible actions anymore;\n"
        "* truncated: if True, the time limit has been exceeded;\n"
        "* info: the remaining numbers"
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _operation(self, number1: float, number2: float) -> float:
        """Return ``number1`` minus ``number2``."""
        difference = number1 - number2
        return difference


class MultiplyTool(BaseGameof24Tool, BaseTool):
    # Game of 24 tool whose arithmetic step is multiplication.
    name: str = "multiply"
    description: str = (
        "\n"
        "Multiplies two numbers. Returns the following:\n"
        "* observation: the result of the multiplication;\n"
        "* reward: 1 when the goal is reached (24 is obtained), 0 otherwise;\n"
        "* terminated: if True, the game has ended: there's no possible actions anymore;\n"
        "* truncated: if True, the time limit has been exceeded;\n"
        "* info: the remaining numbers"
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _operation(self, number1: float, number2: float) -> float:
        """Return the product of the two operands."""
        product = number1 * number2
        return product


class DivideTool(BaseGameof24Tool, BaseTool):
    # Game of 24 tool whose arithmetic step is true division (first / second).
    name: str = "divide"
    description: str = (
        "\n"
        "Divides the first number by the second one. Returns the following:\n"
        "* observation: the result of the division;\n"
        "* reward: 1 when the goal is reached (24 is obtained), 0 otherwise;\n"
        "* terminated: if True, the game has ended: there's no possible actions anymore;\n"
        "* truncated: if True, the time limit has been exceeded;\n"
        "* info: the remaining numbers"
    )
    args_schema: Type[BaseModel] = CalculatorInput

    def _operation(self, number1: float, number2: float) -> float:
        """Return ``number1`` divided by ``number2``.

        Raises:
            ZeroDivisionError: if ``number2`` is zero.
        """
        quotient = number1 / number2
        return quotient
Loading

0 comments on commit 37b8b1e

Please sign in to comment.