-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #101 from automl/dev_rebase
Enhancement: Environment Reconciliation
- Loading branch information
Showing
26 changed files
with
713 additions
and
307 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Copyright 2023 The Brax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Wrappers to convert brax envs to gym envs.""" | ||
from typing import ClassVar, Optional | ||
|
||
from brax.envs import Env | ||
import gym | ||
from gym import spaces | ||
from gym.vector import utils | ||
import jax | ||
import numpy as np | ||
from functools import partial | ||
|
||
|
||
class GymWrapper(gym.Env): | ||
"""A wrapper that converts Brax Env to one that follows Gym API.""" | ||
|
||
# Flag that prevents `gym.register` from misinterpreting the `_step` and | ||
# `_reset` as signs of a deprecated gym Env API. | ||
_gym_disable_underscore_compat: ClassVar[bool] = True | ||
|
||
def __init__(self, | ||
env: Env, | ||
seed: int = 0, | ||
backend: Optional[str] = None): | ||
self._env = env | ||
self.metadata = { | ||
'render.modes': ['human', 'rgb_array'], | ||
'video.frames_per_second': 1 / self._env.dt | ||
} | ||
self.seed(seed) | ||
self.backend = backend | ||
self._state = None | ||
|
||
obs = np.inf * np.ones(self._env.observation_size, dtype='float32') | ||
self.observation_space = spaces.Box(-obs, obs, dtype='float32') | ||
|
||
action = np.ones(self._env.action_size, dtype='float32') | ||
self.action_space = spaces.Box(-action, action, dtype='float32') | ||
|
||
def reset(key): | ||
key1, key2 = jax.random.split(key) | ||
state = self._env.reset(key2) | ||
return state, state.obs, key1 | ||
|
||
self._reset = partial(reset) | ||
|
||
def step(state, action): | ||
state = self._env.step(state, action) | ||
info = {**state.metrics, **state.info} | ||
return state, state.obs, state.reward, state.done, info | ||
|
||
self._step = partial(step) | ||
|
||
def reset(self): | ||
self._state, obs, self._key = self._reset(self._key) | ||
# We return device arrays for pytorch users. | ||
return obs | ||
|
||
def step(self, action): | ||
self._state, obs, reward, done, info = self._step(self._state, action) | ||
# We return device arrays for pytorch users. | ||
return obs, reward, done, info | ||
|
||
def seed(self, seed: int = 0): | ||
self._key = jax.random.PRNGKey(seed) | ||
|
||
def render(self, mode='human'): | ||
return super().render(mode=mode) # just raise an exception | ||
|
||
|
||
class VectorGymWrapper(gym.vector.VectorEnv): | ||
"""A wrapper that converts batched Brax Env to one that follows Gym VectorEnv API.""" | ||
|
||
# Flag that prevents `gym.register` from misinterpreting the `_step` and | ||
# `_reset` as signs of a deprecated gym Env API. | ||
_gym_disable_underscore_compat: ClassVar[bool] = True | ||
|
||
def __init__(self, | ||
env: Env, | ||
seed: int = 0, | ||
backend: Optional[str] = None): | ||
self._env = env | ||
self.metadata = { | ||
'render.modes': ['human', 'rgb_array'], | ||
'video.frames_per_second': 1 / self._env.dt | ||
} | ||
if not hasattr(self._env, 'batch_size'): | ||
raise ValueError('underlying env must be batched') | ||
|
||
self.num_envs = self._env.batch_size | ||
self.seed(seed) | ||
self.backend = backend | ||
self._state = None | ||
|
||
obs = np.inf * np.ones(self._env.observation_size, dtype='float32') | ||
obs_space = spaces.Box(-obs, obs, dtype='float32') | ||
self.observation_space = utils.batch_space(obs_space, self.num_envs) | ||
|
||
action = np.ones(self._env.action_size, dtype='float32') | ||
action_space = spaces.Box(-action, action, dtype='float32') | ||
self.action_space = utils.batch_space(action_space, self.num_envs) | ||
|
||
def reset(key): | ||
key1, key2 = jax.random.split(key) | ||
state = self._env.reset(key2) | ||
return state, state.obs, key1 | ||
|
||
self._reset = partial(reset) | ||
|
||
def step(state, action): | ||
state = self._env.step(state, action) | ||
info = {**state.metrics, **state.info} | ||
return state, state.obs, state.reward, state.done, info | ||
|
||
self._step = partial(step) | ||
|
||
def reset(self): | ||
self._state, obs, self._key = self._reset(self._key) | ||
return obs | ||
|
||
def step(self, action): | ||
self._state, obs, reward, done, info = self._step(self._state, action) | ||
return obs, reward, done, info | ||
|
||
def seed(self, seed: int = 0): | ||
self._key = jax.random.PRNGKey(seed) | ||
|
||
def render(self, mode='human'): | ||
if mode == 'rgb_array': | ||
sys, state = self._env.sys, self._state | ||
if state is None: | ||
raise RuntimeError('must call reset or step before rendering') | ||
return image.render_array(sys, state.state.take(0), 256, 256) | ||
else: | ||
return super().render(mode=mode) # just raise an exception |
Oops, something went wrong.