Skip to content

Commit

Permalink
add support for gym environments (#31)
Browse files Browse the repository at this point in the history
* add support for gym environments, fixes #5

* revert
  • Loading branch information
maxpumperla authored Nov 3, 2021
1 parent a41a4d9 commit f7d123c
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 4 deletions.
69 changes: 66 additions & 3 deletions pathmind/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import numpy as np
import prettytable
import yaml
from gym import Env
from gym.spaces import Box as GymContinuous
from gym.spaces import Discrete as GymDiscrete
from or_gym import Env as OrEnv
from prettytable import PrettyTable

__all__ = ["Discrete", "Continuous", "Simulation"]
Expand Down Expand Up @@ -125,8 +129,6 @@ def run(
agents = range(self.number_of_agents())
table, summary = _define_tables(self, agents)

reward_terms = [0 for n in summary.field_names if not n == "Episode"]

for episode in range(num_episodes):

step = 0
Expand Down Expand Up @@ -157,7 +159,6 @@ def run(
terms = [
v for agent_id in agents for v in self.get_reward(agent_id).values()
]

summary.add_row([episode] + terms)

if debug_mode:
Expand Down Expand Up @@ -272,3 +273,65 @@ def _define_tables(simulation, agents):
]

return table, summary


def from_gym(gym_instance: Union[Env, OrEnv]) -> Simulation:
    """Wrap a gym or OR-gym environment in a single-agent Pathmind simulation.

    :param gym_instance: gym or OR-gym environment
    :return: A pathmind environment
    """

    class GymSimulation(Simulation):
        """Single-agent ``Simulation`` that delegates to the wrapped gym env."""

        def __init__(self, gym_instance: Union[Env, OrEnv], *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.env = gym_instance
            self.observations: Dict[str, float] = {}
            self.rewards: Dict[str, float] = {"reward": 0}
            self.done: bool = False

        @staticmethod
        def _as_observations(obs) -> Dict[str, float]:
            """Key each element yielded by iterating ``obs`` as ``obs_<index>``."""
            return {f"obs_{i}": o for i, o in enumerate(obs)}

        def number_of_agents(self) -> int:
            """Return 1: the wrapped environment is driven by a single agent."""
            return 1

        def action_space(self, agent_id: int) -> Union[Continuous, Discrete]:
            """Translate the gym action space into its Pathmind equivalent.

            :param agent_id: ignored, there is only one agent
            :return: ``Discrete`` for ``gym.spaces.Discrete``, ``Continuous``
                for ``gym.spaces.Box``
            :raises ValueError: for any other gym space type
            """
            gym_space = self.env.action_space
            if isinstance(gym_space, GymDiscrete):
                # TODO take care of MultiDiscrete.
                return Discrete(choices=gym_space.n)
            if isinstance(gym_space, GymContinuous):
                return Continuous(
                    shape=gym_space.shape, low=gym_space.low, high=gym_space.high
                )
            # The two literals are concatenated, so the first one must end
            # with a space to keep the message readable.
            raise ValueError(
                f"Unsupported gym.spaces type {type(gym_space)}. "
                "Pathmind currently only allows "
                "gym.spaces.Discrete and gym.spaces.Box as valid action spaces."
            )

        def step(self) -> None:
            """Apply the current action to the env and store the outcome."""
            # This assumes "choices=1"
            action = self.action[0][0]
            obs, rew, done, _ = self.env.step(action)
            self.observations = self._as_observations(obs)
            self.rewards = {"reward": rew}
            self.done = done

        def reset(self) -> None:
            """Reset the env and clear all per-episode state."""
            obs = self.env.reset()
            self.observations = self._as_observations(obs)
            # Drop the reward left over from the previous episode so that
            # get_reward() does not report a stale value before the first step.
            self.rewards = {"reward": 0}
            self.done = False

        def get_reward(self, agent_id: int) -> Dict[str, float]:
            """Return the reward terms recorded by the latest step."""
            return self.rewards

        def get_observation(
            self, agent_id: int
        ) -> Dict[str, Union[float, List[float]]]:
            """Return the observations recorded by the latest step or reset."""
            return self.observations

        def is_done(self, agent_id: int) -> bool:
            """Return whether the env reported the episode as finished."""
            return self.done

    return GymSimulation(gym_instance=gym_instance)
9 changes: 8 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
download_url="https://github.com/PathmindAPI/pathmind_api/tarball/0.3",
author="Max Pumperla",
author_email="max@pathmind.com",
install_requires=["pyyaml", "tensorflow", "requests", "prettytable"],
install_requires=[
"pyyaml",
"tensorflow",
"requests",
"prettytable",
"gym",
"or-gym",
],
extras_require={
"tests": ["pytest", "flake8", "flake8-debugger", "pre-commit", "pandas"]
},
Expand Down
15 changes: 15 additions & 0 deletions tests/test_run_simulation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import os
import pathlib

import gym
import numpy as np
import or_gym
import pandas as pd
import pytest
from examples.mouse.mouse_env_pathmind import MouseAndCheese
from examples.mouse.multi_mouse_env_pathmind import MultiMouseAndCheese

from pathmind.policy import Local, Random, Server
from pathmind.simulation import from_gym

PATH = pathlib.Path(__file__).parent.resolve()

Expand Down Expand Up @@ -114,3 +117,15 @@ def test_policy_predictions():
action = server.get_actions(simulation)
simulation.set_action(action)
simulation.step()


def test_from_gym():
    """Run a gym CartPole environment wrapped by ``from_gym`` with a random policy."""
    simulation = from_gym(gym.make("CartPole-v0"))
    simulation.run(Random())


def test_from_or_gym():
    """Run an OR-gym Knapsack environment wrapped by ``from_gym`` with a random policy."""
    simulation = from_gym(or_gym.make("Knapsack-v0"))
    simulation.run(Random())
16 changes: 16 additions & 0 deletions tests/test_train_simulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import gym
import or_gym
import pytest
from examples.mouse.mouse_env_pathmind import MouseAndCheese
from examples.mouse.multi_mouse_env_pathmind import MultiMouseAndCheese

from pathmind.simulation import from_gym


def test_training():
simulation = MouseAndCheese()
Expand All @@ -12,3 +16,15 @@ def test_multi_training():
pytest.skip("Needs multi-agent training to work on web app")
simulation = MultiMouseAndCheese()
simulation.train()


def test_from_gym():
    """Train on a gym CartPole environment wrapped by ``from_gym``."""
    simulation = from_gym(gym.make("CartPole-v0"))
    simulation.train()


def test_from_or_gym():
    """Train on an OR-gym Knapsack environment wrapped by ``from_gym``."""
    simulation = from_gym(or_gym.make("Knapsack-v0"))
    simulation.train()

0 comments on commit f7d123c

Please sign in to comment.