Commit

Merge pull request #101 from automl/dev_rebase

Enhancement: Environment Reconciliation

TheEimer authored Jul 17, 2023
2 parents 818c262 + 6a991fd commit d68bcda

Showing 26 changed files with 713 additions and 307 deletions.
15 changes: 11 additions & 4 deletions .gitignore
@@ -20,12 +20,19 @@ carl.egg-info
exp_sweep
multirun
outputs
testvenv
*.egg-info
runs
*.tex
*.png
*.pdf
*.csv
*.json
*.pickle
*.egg-info
*code-workspace
*.ipynb_checkpoints
*optgap*
*smac3*
*.json
generated
core
*.tex
build
target
3 changes: 0 additions & 3 deletions .gitmodules
@@ -1,6 +1,3 @@
[submodule "src/envs/rna/learna"]
path = src/envs/rna/learna
url = https://github.com/automl/learna.git
[submodule "src/envs/mario/TOAD-GUI"]
path = src/envs/mario/TOAD-GUI
url = https://github.com/Mawiszus/TOAD-GUI
22 changes: 14 additions & 8 deletions README.md
@@ -68,16 +68,22 @@ Different instantiations can be achieved by setting the context features to different
## Cite Us
If you use CARL in your research, please cite our paper on the benchmark:
```bibtex
@inproceedings{BenEim2021a,
title = {CARL: A Benchmark for Contextual and Adaptive Reinforcement Learning},
author = {Carolin Benjamins and Theresa Eimer and Frederik Schubert and André Biedenkapp and Bodo Rosenhahn and Frank Hutter and Marius Lindauer},
booktitle = {NeurIPS 2021 Workshop on Ecological Theory of Reinforcement Learning},
year = {2021},
month = dec
@article{BenEim2023a,
author = {Carolin Benjamins and
Theresa Eimer and
Frederik Schubert and
Aditya Mohan and
Sebastian Döhler and
André Biedenkapp and
Bodo Rosenhahn and
Frank Hutter and
Marius Lindauer},
title = {Contextualize Me - The Case for Context in Reinforcement Learning},
journal = {Transactions on Machine Learning Research},
year = {2023},
}
```
You can find the code and experiments for this paper in the `neurips_ecorl_workshop_2021` branch.

## References
[OpenAI gym, Brockman et al., 2016. arXiv preprint arXiv:1606.01540](https://arxiv.org/pdf/1606.01540.pdf)
15 changes: 15 additions & 0 deletions carl/envs/brax/__init__.py
@@ -3,6 +3,9 @@
from carl.envs.brax.carl_ant import CONTEXT_BOUNDS as CARLAnt_bounds
from carl.envs.brax.carl_ant import DEFAULT_CONTEXT as CARLAnt_defaults
from carl.envs.brax.carl_ant import CARLAnt
from carl.envs.brax.carl_double_pendulum import CONTEXT_BOUNDS as CARLInvertedDoublePendulum_bounds
from carl.envs.brax.carl_double_pendulum import DEFAULT_CONTEXT as CARLInvertedDoublePendulum_defaults
from carl.envs.brax.carl_double_pendulum import CARLInvertedDoublePendulum
from carl.envs.brax.carl_fetch import CONTEXT_BOUNDS as CARLFetch_bounds
from carl.envs.brax.carl_fetch import DEFAULT_CONTEXT as CARLFetch_defaults
from carl.envs.brax.carl_fetch import CARLFetch
@@ -12,9 +15,21 @@
from carl.envs.brax.carl_halfcheetah import CONTEXT_BOUNDS as CARLHalfcheetah_bounds
from carl.envs.brax.carl_halfcheetah import DEFAULT_CONTEXT as CARLHalfcheetah_defaults
from carl.envs.brax.carl_halfcheetah import CARLHalfcheetah
from carl.envs.brax.carl_hopper import CONTEXT_BOUNDS as CARLHopper_bounds
from carl.envs.brax.carl_hopper import DEFAULT_CONTEXT as CARLHopper_defaults
from carl.envs.brax.carl_hopper import CARLHopper
from carl.envs.brax.carl_humanoid import CONTEXT_BOUNDS as CARLHumanoid_bounds
from carl.envs.brax.carl_humanoid import DEFAULT_CONTEXT as CARLHumanoid_defaults
from carl.envs.brax.carl_humanoid import CARLHumanoid
from carl.envs.brax.carl_pusher import CONTEXT_BOUNDS as CARLPusher_bounds
from carl.envs.brax.carl_pusher import DEFAULT_CONTEXT as CARLPusher_defaults
from carl.envs.brax.carl_pusher import CARLPusher
from carl.envs.brax.carl_reacher import CONTEXT_BOUNDS as CARLReacher_bounds
from carl.envs.brax.carl_reacher import DEFAULT_CONTEXT as CARLReacher_defaults
from carl.envs.brax.carl_reacher import CARLReacher
from carl.envs.brax.carl_ur5e import CONTEXT_BOUNDS as CARLUr5e_bounds
from carl.envs.brax.carl_ur5e import DEFAULT_CONTEXT as CARLUr5e_defaults
from carl.envs.brax.carl_ur5e import CARLUr5e
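
For orientation, a minimal sketch of how the newly exported Brax environments can be instantiated with varying contexts. The `contexts` constructor argument and the `"gravity"` feature name are assumptions based on CARL's context API, not something this diff defines:

```python
from carl.envs.brax import CARLHopper, CARLHopper_defaults

# Two instances of the same env that differ only in their context
# (assumption: "gravity" is a context feature of CARLHopper).
context_0 = dict(CARLHopper_defaults)
context_1 = dict(CARLHopper_defaults)
context_1["gravity"] = 2 * context_1["gravity"]

env = CARLHopper(contexts={0: context_0, 1: context_1})
obs = env.reset()  # the active context is drawn from `contexts`
```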



148 changes: 148 additions & 0 deletions carl/envs/brax/brax_wrappers.py
@@ -0,0 +1,148 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrappers to convert brax envs to gym envs."""
from functools import partial
from typing import ClassVar, Optional

import gym
import jax
import numpy as np
from brax.envs import Env
from brax.io import image  # required by VectorGymWrapper.render
from gym import spaces
from gym.vector import utils


class GymWrapper(gym.Env):
"""A wrapper that converts Brax Env to one that follows Gym API."""

# Flag that prevents `gym.register` from misinterpreting the `_step` and
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(self,
env: Env,
seed: int = 0,
backend: Optional[str] = None):
self._env = env
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 1 / self._env.dt
}
self.seed(seed)
self.backend = backend
self._state = None

obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
self.observation_space = spaces.Box(-obs, obs, dtype='float32')

action = np.ones(self._env.action_size, dtype='float32')
self.action_space = spaces.Box(-action, action, dtype='float32')

def reset(key):
key1, key2 = jax.random.split(key)
state = self._env.reset(key2)
return state, state.obs, key1

        # Unlike upstream Brax's GymWrapper, this and `_step` below are wrapped
        # with functools.partial rather than jax.jit, presumably so per-context
        # changes to the underlying env are not baked into a cached, compiled fn.
        self._reset = partial(reset)

def step(state, action):
state = self._env.step(state, action)
info = {**state.metrics, **state.info}
return state, state.obs, state.reward, state.done, info

self._step = partial(step)

def reset(self):
self._state, obs, self._key = self._reset(self._key)
# We return device arrays for pytorch users.
return obs

def step(self, action):
self._state, obs, reward, done, info = self._step(self._state, action)
# We return device arrays for pytorch users.
return obs, reward, done, info

def seed(self, seed: int = 0):
self._key = jax.random.PRNGKey(seed)

def render(self, mode='human'):
return super().render(mode=mode) # just raise an exception


class VectorGymWrapper(gym.vector.VectorEnv):
"""A wrapper that converts batched Brax Env to one that follows Gym VectorEnv API."""

# Flag that prevents `gym.register` from misinterpreting the `_step` and
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(self,
env: Env,
seed: int = 0,
backend: Optional[str] = None):
self._env = env
self.metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 1 / self._env.dt
}
if not hasattr(self._env, 'batch_size'):
raise ValueError('underlying env must be batched')

self.num_envs = self._env.batch_size
self.seed(seed)
self.backend = backend
self._state = None

obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
obs_space = spaces.Box(-obs, obs, dtype='float32')
self.observation_space = utils.batch_space(obs_space, self.num_envs)

action = np.ones(self._env.action_size, dtype='float32')
action_space = spaces.Box(-action, action, dtype='float32')
self.action_space = utils.batch_space(action_space, self.num_envs)

def reset(key):
key1, key2 = jax.random.split(key)
state = self._env.reset(key2)
return state, state.obs, key1

self._reset = partial(reset)

def step(state, action):
state = self._env.step(state, action)
info = {**state.metrics, **state.info}
return state, state.obs, state.reward, state.done, info

self._step = partial(step)

def reset(self):
self._state, obs, self._key = self._reset(self._key)
return obs

def step(self, action):
self._state, obs, reward, done, info = self._step(self._state, action)
return obs, reward, done, info

def seed(self, seed: int = 0):
self._key = jax.random.PRNGKey(seed)

def render(self, mode='human'):
if mode == 'rgb_array':
sys, state = self._env.sys, self._state
if state is None:
raise RuntimeError('must call reset or step before rendering')
return image.render_array(sys, state.state.take(0), 256, 256)
else:
return super().render(mode=mode) # just raise an exception
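
A hypothetical usage sketch for the two wrappers above, assuming a Brax v1-style `brax.envs.create` registry with a `batch_size` option; neither is part of this diff:

```python
import brax.envs

# Single env, classic Gym API.
env = GymWrapper(brax.envs.create("hopper"), seed=0)
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())

# Batched env, Gym VectorEnv API (requires a batched Brax env).
venv = VectorGymWrapper(brax.envs.create("hopper", batch_size=8), seed=0)
obs = venv.reset()                    # jax array of shape (8, obs_size)
actions = venv.action_space.sample()  # shape (8, action_size)
obs, reward, done, info = venv.step(actions)
```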