Enhancement: Environment Reconciliation #100

Closed
wants to merge 717 commits into from
717 commits
477096f
Update exp name and use SAC
Nov 13, 2022
7de0844
New path
Nov 13, 2022
9540e45
Update README.md
Nov 13, 2022
47fe7bc
Update
Nov 13, 2022
2117e6c
Always save agent at the end of training
Nov 14, 2022
c7b6616
Update paths
Nov 14, 2022
02f2ef1
Create script
Nov 14, 2022
e4ddcfe
Can set snap_dir
Nov 14, 2022
00327e5
Add config
Nov 14, 2022
820cba8
Current
Nov 14, 2022
7cf165b
Always use absolute paths
Nov 14, 2022
e28fb94
Add context efficiency experiment
Nov 14, 2022
c8a81f1
Fix commands
Nov 14, 2022
7ee14a9
Ignore more
Nov 14, 2022
aa9d466
Fix paths
Nov 14, 2022
85962ff
Fix paths
Nov 15, 2022
105c95e
Increase wandb table length
Nov 15, 2022
e29a4b9
Add opt gap exp
Nov 15, 2022
1c0f223
Opt gap
Nov 15, 2022
34cfb1a
Fix number of table entries
Nov 15, 2022
2b3390a
switch dmc gravity to positive values
sebidoe Nov 16, 2022
5941af6
Merge branch 'train' of github.com:automl/CARL into train
sebidoe Nov 16, 2022
b264ca2
Add all visibilities
Nov 16, 2022
ca10726
Find all run subdirs from hydra
Nov 17, 2022
1ed8198
Don't log to wandb
Nov 17, 2022
6224817
Increase default network width 32 -> 256
Nov 17, 2022
fb33e2a
Fix path
Nov 17, 2022
bed148a
Add warning if context ids can't be merged
Nov 17, 2022
66b6203
Fix path building
Nov 17, 2022
926c595
Merge branch 'train' of https://github.com/automl/CARL into train
Nov 17, 2022
01e9a6c
Fix save dirs
Nov 18, 2022
19cf2f6
Fix commands
Nov 18, 2022
8e3a953
Add overrides
Nov 21, 2022
992539e
Add runcommand
Nov 21, 2022
86d1ffb
Remove custom network widths
Nov 21, 2022
31b112b
Consider n samples if contexts from path
Nov 21, 2022
8324396
Update bash
Nov 21, 2022
ed9b6f6
Update to new brax legacy spring
Nov 22, 2022
bb455a5
Update commands
Nov 22, 2022
d45da80
Add wandb.debug=true to training
Nov 22, 2022
5745472
Consider n samples
Nov 22, 2022
e31c3d2
Train for 1M steps
Nov 22, 2022
20fce4c
Print context set length
Nov 22, 2022
a8a4fcd
Add optgap specific loading
Nov 23, 2022
330c1bf
fix sac training
frederikschubert Nov 23, 2022
e01220e
Update Readme.md
benjamc Nov 24, 2022
6d015ee
Fix context branch width
Dec 11, 2022
25c5008
Fix n atoms
Dec 11, 2022
b1f184e
Fix network width
Dec 11, 2022
fd6224d
Current
Dec 13, 2022
e23acdc
Merge branch 'train' of https://github.com/automl/CARL into train
Dec 13, 2022
d180fee
Current
Dec 13, 2022
5634401
Fix sampling of uniform bounds if mean <= 0
Dec 15, 2022
c3cd94e
Fix import
Dec 15, 2022
328f3b4
Fix arg for carl env
Dec 15, 2022
ba0b261
update default context for halfcheetah brax env
frederikschubert Dec 20, 2022
04bdf42
fix checkpointing
frederikschubert Dec 20, 2022
34c8fae
Add hidden as default
Dec 21, 2022
739d7fe
Add ppo
Dec 21, 2022
93cba97
Add exp
Dec 21, 2022
ce52b42
fix serialization of policies and q functions
frederikschubert Dec 21, 2022
ccb22a9
Update ppo
Dec 21, 2022
f5a1297
Merge branch 'train' of https://github.com/automl/CARL into train
Dec 21, 2022
7e3f49d
Increase exp counter
Dec 21, 2022
d3b208a
DDPG draft
Dec 21, 2022
37c151e
Merge branch 'train' of https://github.com/automl/CARL into train
Dec 21, 2022
3360236
Add ddpg to algorithm choices
Dec 21, 2022
bf7bb02
Missing Hps
Dec 21, 2022
fdfa086
fix experience replay
frederikschubert Dec 21, 2022
6f00128
Merge branch 'train' of https://github.com/automl/CARL into train
Dec 21, 2022
6e4d675
Merge branch 'train' of https://github.com/automl/CARL into train
Dec 21, 2022
5034c8a
Add exp
Dec 21, 2022
bff68c6
Fix loading (TODO: all algos)
Dec 21, 2022
ef57599
Current
Dec 21, 2022
2086723
Current
Dec 22, 2022
f6bd3dc
Add new plots
Jan 3, 2023
bacd89c
Change n final eval episodes
Jan 4, 2023
34de70d
Fix var
Jan 5, 2023
9af18fa
Fix pi loading
Jan 5, 2023
bdbe611
Add todo
Jan 6, 2023
771ef13
Increase num frames
Jan 6, 2023
dd738c1
Add command
Jan 6, 2023
08d593f
Current
Jan 6, 2023
11779b3
Add hidden on variations experiment
Jan 7, 2023
1301f45
Fix filter
Jan 7, 2023
718092f
Add utility function
Jan 17, 2023
9afc0a7
Update
Jan 17, 2023
e1e73b4
Check wandb tags
Jan 17, 2023
ec9c0c4
Update
Jan 17, 2023
9b5a7c0
Current
Jan 17, 2023
292a1fe
Trial
Jan 17, 2023
5b080c3
Add loading policy for SAC
Feb 7, 2023
94e63ef
New nb
Feb 7, 2023
26cf124
Check more
Feb 8, 2023
79259bf
Add info about rendering
Feb 8, 2023
2a6cb2e
Update
Feb 8, 2023
2de9359
Adapt to all envs
Feb 14, 2023
b3e6789
Rename file
Feb 14, 2023
826bb5f
Add Pendulum infos
Feb 14, 2023
e288b5d
Current
Feb 14, 2023
8f6b55d
Fix r string
Feb 14, 2023
c60fc37
Fix loading
Feb 14, 2023
9a43344
Current
Feb 14, 2023
6097f75
Add new kirk ep experiment
Feb 16, 2023
0584469
Update
Feb 16, 2023
431a133
Current
Feb 19, 2023
d2abc2a
Current
Feb 20, 2023
696a966
Current
Feb 20, 2023
6b56c61
Update plotting
Feb 21, 2023
c9b79d5
Ignore more
Feb 21, 2023
5d14ff8
Some stuff
Feb 21, 2023
59fa98d
Squashed commit of the following:
Feb 22, 2023
489badf
Merge main into train
Feb 22, 2023
f187f1f
Print agent hps as latex table
Feb 22, 2023
7191879
Update kirk eval
Feb 22, 2023
f73efe5
Add optgap experiment setup
Feb 22, 2023
2aa50f8
Set default for Pendulum
Feb 22, 2023
a32f01c
Update commands
Feb 22, 2023
8dc58ec
Set default agent
Feb 23, 2023
2b61909
Update
Feb 24, 2023
4069b8a
remove spaces
TheEimer Feb 28, 2023
d43be6e
less bash errors
TheEimer Mar 1, 2023
ba81c71
Update README.md
benjamc Mar 12, 2023
349c7bb
Update
Apr 3, 2023
6421069
Update
Apr 3, 2023
74656dd
Update
Apr 3, 2023
8719756
Make dir
Apr 3, 2023
e42081d
Current
Apr 3, 2023
cc06efe
Update
Apr 3, 2023
b815854
Merge branch 'train' of https://github.com/automl/CARL into train
Apr 3, 2023
9043cd9
Current
Apr 4, 2023
4ffdd63
Merge branch 'main' into train
Apr 4, 2023
9b7f238
Update
Apr 14, 2023
8984c63
Current
Apr 14, 2023
5c21e68
Add proper statistic eval
Apr 20, 2023
bc47e2f
Current
Apr 20, 2023
888eeb6
correct walker data
TheEimer Apr 20, 2023
ce237de
some hps + vehicle architecture
TheEimer May 5, 2023
f3e265c
some reshaping, car racing conv works now
TheEimer May 5, 2023
600f4ad
fixed brax issues, rendering for car racing
TheEimer May 8, 2023
b0fd36c
settings for brax + carracing
TheEimer May 10, 2023
4ef683c
Update
May 11, 2023
917f88c
Add base hpo config
May 11, 2023
fb1dc91
Current
May 11, 2023
876046e
Merge branch 'train' of https://github.com/automl/CARL into train
May 11, 2023
7ae90a3
brax context fix
TheEimer May 11, 2023
1ffa020
improve CARLMario setup and align with current API
frederikschubert May 11, 2023
dccc783
update gitignore
frederikschubert May 11, 2023
e328400
clipping as hp
TheEimer May 12, 2023
8baf1a5
Adjust loading for PPO
May 12, 2023
ba5d11e
Update Bipedalwalker ppo HPs
May 12, 2023
b17e4f3
Switch order of wandb debug check
May 12, 2023
f12c7bc
Update eval dir for hidden on variations
May 12, 2023
be634c2
Current
May 12, 2023
7f8dff4
Remove imports
May 12, 2023
405bdab
Update requirements
May 12, 2023
4714a81
Add pip freeze of a working env
May 12, 2023
6ae3714
Update README
May 12, 2023
c213795
Merge branch 'train' of https://github.com/automl/CARL into train
May 12, 2023
ae5fba2
make carlmario work with new api
frederikschubert May 12, 2023
b89f4e5
revert setup.py
frederikschubert May 12, 2023
15f1666
carlmario fixes
frederikschubert May 12, 2023
4506599
Rename brax -> braxenvs
May 12, 2023
dc72818
Add CARL Brax env
May 12, 2023
a326fcb
Current
May 12, 2023
103d6f2
Merge branch 'train' of https://github.com/automl/CARL into train
May 12, 2023
7e58a73
brax ppo version
TheEimer May 12, 2023
7c5a68c
Merge branch 'train' of https://github.com/automl/CARL into train
TheEimer May 12, 2023
7474f6f
First fixes
May 12, 2023
e4a066c
Merge branch 'train' of https://github.com/automl/CARL into train
TheEimer May 12, 2023
b4c4791
improve context configuration for carl mario
frederikschubert May 12, 2023
980a99b
Fix brax imports
May 12, 2023
73219e2
several fixes
frederikschubert May 13, 2023
4107514
fix carlmario ppo training
frederikschubert May 13, 2023
cd8afdd
trainig updates
TheEimer May 15, 2023
6b4b275
Merge branch 'train' of https://github.com/automl/CARL into train
TheEimer May 15, 2023
e46d215
Clean notebook
May 15, 2023
dd2c2d1
Format
May 15, 2023
b1d915e
Adjust outdir + write data
May 15, 2023
3de40a8
Fix paths
May 15, 2023
87eb633
Fix colors and aspect ratio
May 15, 2023
864a3e9
Add latex str
May 15, 2023
e167ce5
Save figure tex
May 15, 2023
7c070c8
Current
May 15, 2023
821859c
Current
May 15, 2023
9001658
Current
May 15, 2023
4555a34
integrate virtual display into mario env
frederikschubert May 16, 2023
907f8ac
brax update
TheEimer May 16, 2023
33dbfea
Merge branch 'train' of https://github.com/automl/CARL into train
TheEimer May 16, 2023
123374a
brax env defaults, md train eval
TheEimer May 17, 2023
0739698
ppos coexist now
TheEimer May 17, 2023
fd1135b
ppos coexist now
TheEimer May 17, 2023
2866d1d
Add optuna config
May 19, 2023
c333ea9
Add base config for SMAC
May 19, 2023
c772588
Add config notes + increase budget
May 19, 2023
3495c9d
Increase time limit
May 19, 2023
3fe5127
Add option to add delta seed
May 19, 2023
3efcfe3
Current
May 19, 2023
be1b990
Add commands
May 19, 2023
8a22c73
Add note about HPO
May 19, 2023
6e34c8c
Fix Ant
May 19, 2023
41d34cb
Make env v3
May 19, 2023
111db77
Ax config (ax does not work on my setup)
May 19, 2023
7ac3beb
Merge branch 'train' of github.com:automl/CARL into train
May 19, 2023
136ef2f
add plot notebook for mario
frederikschubert May 22, 2023
ed5f16f
Do not use HydraConfig
May 22, 2023
6431494
Fix training
May 23, 2023
58d96c3
Add SAC Hps
May 23, 2023
bacd16e
Merge branch 'train' of github.com:automl/CARL into train
May 23, 2023
060fbcf
register bipedal walker as gym env
frederikschubert May 23, 2023
c5eb8b8
return context_id with gym info
frederikschubert May 23, 2023
f5fd324
return context key instead of id
frederikschubert May 24, 2023
06b3696
Add swig dependency for box2d
May 25, 2023
d4f6b5b
Current
May 25, 2023
08ac7d0
Update
May 25, 2023
430559b
Merge branch 'train' of github.com:automl/CARL into train
May 25, 2023
fd6fa75
Fix annotations
May 25, 2023
48ba0a3
Fix retrieval of dt for brax env
May 25, 2023
08ecfc3
Revert "Fix retrieval of dt for brax env"
May 25, 2023
36c30c0
Always hide context per default
May 26, 2023
2c8e8ca
Fix batch size
May 26, 2023
9bd5ce1
Add super brax env
May 26, 2023
070fc7f
Brax envs inherit from brax super env
May 26, 2023
5332866
Make format
May 26, 2023
1760706
Revert "Make format"
May 26, 2023
8e5fc57
Comment brax draft
May 26, 2023
8e3753a
Rename
May 26, 2023
1c85272
Reset env to test update context
May 26, 2023
eb3dfb8
Merge branch 'train' of github.com:automl/CARL into train
May 26, 2023
60ef3da
Current
May 30, 2023
cb6b55f
renamed braxenvs to brax
amsks Jun 16, 2023
20e983a
updated name to brax names in environments
amsks Jun 16, 2023
3b373f6
readded MARIO and TOAD-GUI to gitmodules
amsks Jul 13, 2023
4904a16
Added citations and removed ___carl_brax_env.py
amsks Jul 13, 2023
b411c42
Changed braxenvs to brax
amsks Jul 13, 2023
90f9b5f
removed requirements clone
amsks Jul 13, 2023
673e87f
removed notes
amsks Jul 13, 2023
4501bad
removed experiments
amsks Jul 13, 2023
f0dfbfd
readded docs
amsks Jul 13, 2023
36ffb0f
removed mario from main. Should be instaled with submodule
amsks Jul 13, 2023
2c25647
Changed Citations to Contextualize Me
amsks Jul 14, 2023
bb3d92e
removed wokrshop paper
amsks Jul 14, 2023
1fcf02f
Merge branch 'main' into train
amsks Jul 14, 2023
cc822d9
removed workshop branch reference
amsks Jul 14, 2023
ea06019
Added fetch
amsks Jul 14, 2023
2b39530
added grasp
amsks Jul 14, 2023
582d516
added ur5e
amsks Jul 14, 2023
7f8a48e
added docs
amsks Jul 14, 2023
3778c6f
added mario
amsks Jul 14, 2023
b13542b
Merge branch 'development' into train
amsks Jul 14, 2023
17 changes: 12 additions & 5 deletions .gitignore
Expand Up @@ -8,7 +8,7 @@
/dataSources.local.xml
/httpRequests/
tmp/
slurm.sh
slurm.shssssssssssssssssssssssssssssssssssssssssssssssssssss
carl/runscripts/generated
docs/html
docs/apidoc
Expand All @@ -20,12 +20,19 @@ carl.egg-info
exp_sweep
multirun
outputs
testvenv
*.egg-info
runs
*.tex
*.png
*.pdf
*.csv
*.json
*.pickle
*.egg-info
*code-workspace
*.ipynb_checkpoints
*optgap*
*smac3*
*.json
generated
core
*.tex
build
target
2 changes: 1 addition & 1 deletion .gitmodules
Expand Up @@ -6,4 +6,4 @@
url = https://github.com/Mawiszus/TOAD-GUI
[submodule "src/envs/mario/Mario-AI-Framework"]
path = src/envs/mario/Mario-AI-Framework
url = https://github.com/frederikschubert/Mario-AI-Framework
url = https://github.com/frederikschubert/Mario-AI-Framework
2 changes: 1 addition & 1 deletion CITATION.bib
Expand Up @@ -11,4 +11,4 @@ @inproceedings { BenEim2023a
title = {Contextualize Me - The Case for Context in Reinforcement Learning},
journal = {Transactions on Machine Learning Research},
year = {2023},
}
}
16 changes: 7 additions & 9 deletions README.md
Expand Up @@ -41,7 +41,7 @@ pip install .

This will only install the basic classic control environments, which should run on most operating systems. For the full set of environments, use the install options:
```bash
pip install -e .[box2d, brax, mario, dm_control]
pip install -e .[box2d,brax,mario,dm_control]
```

These may not be compatible with Windows systems. Box2D environment may need to be installed via conda on MacOS systems:
Expand All @@ -68,17 +68,15 @@ Different instiations can be achieved by setting the context features to differe
## Cite Us
If you use CARL in your research, please cite our paper on the benchmark:
```bibtex
@inproceedings{BenEim2021a,
title = {CARL: A Benchmark for Contextual and Adaptive Reinforcement Learning},
author = {Carolin Benjamins and Theresa Eimer and Frederik Schubert and André Biedenkapp and Bodo Rosenhahn and Frank Hutter and Marius Lindauer},
booktitle = {NeurIPS 2021 Workshop on Ecological Theory of Reinforcement Learning},
year = {2021},
month = dec
@inproceedings{Benjamins2023,
title = {Contextualize Me -- The Case for Context in Reinforcement Learning},
author = {Carolin Benjamins and Theresa Eimer and Frederik Schubert and Aditya Mohan and Sebastian Döhler and André Biedenkapp and Bodo Rosenhahn and Frank Hutter and Marius Lindauer},
booktitle = {Transactions on Machine Learning Research},
year = {2023},
month = Apr
}
```

You can find the code and experiments for this paper in the `neurips_ecorl_workshop_2021` branch.

## References
[OpenAI gym, Brockman et al., 2016. arXiv preprint arXiv:1606.01540](https://arxiv.org/pdf/1606.01540.pdf)

Expand Down
22 changes: 20 additions & 2 deletions carl/context/sampling.py
@@ -1,10 +1,11 @@
# flake8: noqa: W605
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union

import importlib

import numpy as np
from scipy.stats import norm, rv_continuous
from scipy.stats import norm, rv_continuous, uniform

import carl.envs
from carl.utils.types import Context, Contexts
Expand Down Expand Up @@ -59,6 +60,8 @@ def sample_contexts(
default_sample_std_percentage: float = 0.05,
fallback_sample_std: float = 0.1,
seed: Optional[int] = None,
uniform_distribution: bool = False,
uniform_bounds_rel: tuple[float, float] | None = None
) -> Dict[int, Dict[str, Any]]:
"""
Sample contexts.
Expand Down Expand Up @@ -158,7 +161,21 @@ def sample_contexts(
# the sample mean. Therefore we use a fallback sample standard deviation.
sample_std = fallback_sample_std # TODO change this back to sample_std

random_variable = norm(loc=sample_mean, scale=sample_std)
if not uniform_distribution:
random_variable = norm(loc=sample_mean, scale=sample_std)
else:
# bounds defined as [loc, loc+scale]
if sample_mean == 0:
# relative bounds are centered around 1 so subtract here for the percentages
loc = uniform_bounds_rel[0] - 1
scale = uniform_bounds_rel[1] - uniform_bounds_rel[0]
elif sample_mean < 0:
loc = uniform_bounds_rel[1] * sample_mean
scale = uniform_bounds_rel[0] * sample_mean - loc
else:
loc = uniform_bounds_rel[0] * sample_mean
scale = uniform_bounds_rel[1] * sample_mean - loc
random_variable = uniform(loc=loc, scale=scale)
context_feature_type = env_bounds[context_feature_name][2]
sample_dists[context_feature_name] = (random_variable, context_feature_type)

Expand All @@ -173,6 +190,7 @@ def sample_contexts(
random_variable = sample_dists[k][0]
context_feature_type = sample_dists[k][1]
lower_bound, upper_bound = env_bounds[k][0], env_bounds[k][1]
assert lower_bound <= upper_bound, f"context variable {k}: lower bound [{lower_bound}] is higher than upper bound [{upper_bound}]!"
if context_feature_type == list:
length = rng.integers(
5e5
Expand Down
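The uniform branch added to `sample_contexts` can be read in isolation. The following is a minimal sketch, not the PR's code: `uniform_loc_scale` is a hypothetical helper name, and it only mirrors the three branches of the hunk above under the assumption that `uniform_bounds_rel` holds a `(lower, upper)` pair of relative factors, with the result meant for `scipy.stats.uniform(loc=loc, scale=scale)`, which samples from `[loc, loc + scale]`.

```python
from __future__ import annotations


def uniform_loc_scale(
    sample_mean: float, bounds_rel: tuple[float, float]
) -> tuple[float, float]:
    """Return (loc, scale) so that samples fall in [loc, loc + scale].

    Hypothetical helper; it mirrors the three branches of the diff above.
    """
    lower_rel, upper_rel = bounds_rel
    if sample_mean == 0:
        # Relative bounds are centered around 1, so shift them to be
        # centered around 0 when the mean itself is 0.
        loc = lower_rel - 1
        scale = upper_rel - lower_rel
    elif sample_mean < 0:
        # For a negative mean the larger factor gives the smaller value,
        # so the roles of the two bounds swap to keep scale positive.
        loc = upper_rel * sample_mean
        scale = lower_rel * sample_mean - loc
    else:
        loc = lower_rel * sample_mean
        scale = upper_rel * sample_mean - loc
    return loc, scale


print(uniform_loc_scale(10.0, (0.5, 1.5)))   # (5.0, 10.0)  -> interval [5, 15]
print(uniform_loc_scale(-10.0, (0.5, 1.5)))  # (-15.0, 10.0) -> interval [-15, -5]
print(uniform_loc_scale(0.0, (0.5, 1.5)))    # (-0.5, 1.0)  -> interval [-0.5, 0.5]
```

The swap in the negative branch appears to be what the commit "Fix sampling of uniform bounds if mean <= 0" addresses: without it, `scale` would come out negative.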
2 changes: 1 addition & 1 deletion carl/context/selection.py
Expand Up @@ -88,7 +88,7 @@ def context_key(self) -> Any | None:
Any | None
The key of the current context or None
"""
if self.context_id:
if self.context_id is not None:
key = self.contexts_keys[self.context_id]
else:
key = None
Expand Down
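The change from `if self.context_id:` to `if self.context_id is not None:` matters because `0` is a valid context id but is falsy in Python. A minimal sketch of the difference, using two hypothetical free functions in place of the `context_key` property:

```python
from typing import Any, List, Optional


def context_key_truthy(context_id: Optional[int], keys: List[Any]) -> Optional[Any]:
    """Old check: treats context_id == 0 like a missing id."""
    if context_id:
        return keys[context_id]
    return None


def context_key_explicit(context_id: Optional[int], keys: List[Any]) -> Optional[Any]:
    """New check: only None means that no context has been selected."""
    if context_id is not None:
        return keys[context_id]
    return None


keys = ["ctx_a", "ctx_b"]
print(context_key_truthy(0, keys))    # None: the first context is silently dropped
print(context_key_explicit(0, keys))  # ctx_a
```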
40 changes: 30 additions & 10 deletions carl/envs/box2d/__init__.py
@@ -1,22 +1,42 @@
# flake8: noqa: F401
from carl.envs.box2d.carl_bipedal_walker import (
CONTEXT_BOUNDS as CARLBipedalWalkerEnv_bounds,
)
from carl.envs.box2d.carl_bipedal_walker import (
DEFAULT_CONTEXT as CARLBipedalWalkerEnv_defaults,
)
from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv

# Contextenvs.s and bounds by name
from carl.envs.box2d.carl_lunarlander import CONTEXT_BOUNDS as CARLLunarLanderEnv_bounds
from functools import partial
import warnings

import gym
from carl.envs.box2d.carl_lunarlander import CARLLunarLanderEnv
from carl.envs.box2d.carl_lunarlander import (
DEFAULT_CONTEXT as CARLLunarLanderEnv_defaults,
)
from carl.envs.box2d.carl_lunarlander import CARLLunarLanderEnv
from carl.envs.box2d.carl_vehicle_racing import (
CONTEXT_BOUNDS as CARLVehicleRacingEnv_bounds,
)

from carl.envs.box2d.carl_vehicle_racing import CARLVehicleRacingEnv
from carl.envs.box2d.carl_vehicle_racing import (
DEFAULT_CONTEXT as CARLVehicleRacingEnv_defaults,
)
from carl.envs.box2d.carl_vehicle_racing import CARLVehicleRacingEnv
from carl.envs.box2d.carl_vehicle_racing import (
CONTEXT_BOUNDS as CARLVehicleRacingEnv_bounds,
)

from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv
from carl.envs.box2d.carl_bipedal_walker import (
DEFAULT_CONTEXT as CARLBipedalWalkerEnv_defaults,
)
from carl.envs.box2d.carl_bipedal_walker import (
CONTEXT_BOUNDS as CARLBipedalWalkerEnv_bounds,
)

try:
from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv
from gym.envs.registration import register

def make_env(**kwargs):
return CARLBipedalWalkerEnv(**kwargs)
register("CARLBipedalWalkerEnv-v0", entry_point=make_env)
register("CARLBipedalWalkerHardcoreEnv-v0", entry_point=partial(make_env, env=gym.make("BipedalWalkerHardcore-v3")))
except Exception as e:
warnings.warn(
f"Could not load CARLBipedalWalkerEnv which is probably not installed ({e}).")
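One detail of the registration above is worth spelling out: `partial(make_env, env=gym.make(...))` evaluates `gym.make` once, at import time, and binds the resulting object. The sketch below illustrates that binding behaviour with a stand-in registry; `REGISTRY`, `register`, `make`, and `FakeWalker` are hypothetical names used only for illustration and are not gym's API.

```python
from functools import partial
from typing import Any, Callable, Dict

# Hypothetical stand-in for gym's registry, for illustration only.
REGISTRY: Dict[str, Callable[..., Any]] = {}


def register(env_id: str, entry_point: Callable[..., Any]) -> None:
    REGISTRY[env_id] = entry_point


def make(env_id: str, **kwargs: Any) -> Any:
    return REGISTRY[env_id](**kwargs)


class FakeWalker:
    """Placeholder for the wrapped base environment."""

    def __init__(self, hardcore: bool = False) -> None:
        self.hardcore = hardcore


def make_env(env: Any = None, **kwargs: Any) -> Dict[str, Any]:
    # Mirrors the default handling in the CARL constructor: build a fresh
    # base env only when none is passed in.
    if env is None:
        env = FakeWalker()
    return {"env": env, **kwargs}


register("Walker-v0", make_env)
# The FakeWalker below is created right here, once, and then reused.
register("WalkerHardcore-v0", partial(make_env, env=FakeWalker(hardcore=True)))

a, b = make("WalkerHardcore-v0"), make("WalkerHardcore-v0")
print(a["env"] is b["env"])  # True: both results share one base env
c, d = make("Walker-v0"), make("Walker-v0")
print(c["env"] is d["env"])  # False: a fresh base env per call
```

If a separate base environment per instance is wanted, the construction can be moved inside the factory function instead of being bound through `partial`.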
3 changes: 2 additions & 1 deletion carl/envs/box2d/carl_bipedal_walker.py
Expand Up @@ -105,7 +105,8 @@ def __init__(
instance_mode: str, optional
"""
if env is None:
env = bipedal_walker.BipedalWalker()
# env = bipedal_walker.BipedalWalker()
env = gym.make(id="BipedalWalker-v3")
if not contexts:
contexts = {0: DEFAULT_CONTEXT}
super().__init__(
Expand Down
39 changes: 21 additions & 18 deletions carl/envs/brax/__init__.py
@@ -1,20 +1,23 @@
# flake8: noqa: F401
# Contexts and bounds by name
from carl.envs.brax.carl_ant import CONTEXT_BOUNDS as CARLAnt_bounds
from carl.envs.brax.carl_ant import DEFAULT_CONTEXT as CARLAnt_defaults
from carl.envs.brax.carl_ant import CARLAnt
from carl.envs.brax.carl_fetch import CONTEXT_BOUNDS as CARLFetch_bounds
from carl.envs.brax.carl_fetch import DEFAULT_CONTEXT as CARLFetch_defaults
from carl.envs.brax.carl_fetch import CARLFetch
from carl.envs.brax.carl_grasp import CONTEXT_BOUNDS as CARLGrasp_bounds
from carl.envs.brax.carl_grasp import DEFAULT_CONTEXT as CARLGrasp_defaults
from carl.envs.brax.carl_grasp import CARLGrasp
from carl.envs.brax.carl_halfcheetah import CONTEXT_BOUNDS as CARLHalfcheetah_bounds
from carl.envs.brax.carl_halfcheetah import DEFAULT_CONTEXT as CARLHalfcheetah_defaults
from carl.envs.brax.carl_halfcheetah import CARLHalfcheetah
from carl.envs.brax.carl_humanoid import CONTEXT_BOUNDS as CARLHumanoid_bounds
from carl.envs.brax.carl_humanoid import DEFAULT_CONTEXT as CARLHumanoid_defaults
from carl.envs.brax.carl_humanoid import CARLHumanoid
from carl.envs.brax.carl_ur5e import CONTEXT_BOUNDS as CARLUr5e_bounds
from carl.envs.brax.carl_ur5e import DEFAULT_CONTEXT as CARLUr5e_defaults
from carl.envs.brax.carl_ur5e import CARLUr5e
from carl.envs.braxenvs.carl_ant import CONTEXT_BOUNDS as CARLAnt_bounds
from carl.envs.braxenvs.carl_ant import DEFAULT_CONTEXT as CARLAnt_defaults
from carl.envs.braxenvs.carl_ant import CARLAnt
from carl.envs.braxenvs.carl_halfcheetah import CONTEXT_BOUNDS as CARLHalfcheetah_bounds
from carl.envs.braxenvs.carl_halfcheetah import DEFAULT_CONTEXT as CARLHalfcheetah_defaults
from carl.envs.braxenvs.carl_halfcheetah import CARLHalfcheetah
from carl.envs.braxenvs.carl_humanoid import CONTEXT_BOUNDS as CARLHumanoid_bounds
from carl.envs.braxenvs.carl_humanoid import DEFAULT_CONTEXT as CARLHumanoid_defaults
from carl.envs.braxenvs.carl_humanoid import CARLHumanoid
from carl.envs.braxenvs.carl_hopper import CONTEXT_BOUNDS as CARLHopper_bounds
from carl.envs.braxenvs.carl_hopper import DEFAULT_CONTEXT as CARLHopper_defaults
from carl.envs.braxenvs.carl_hopper import CARLHopper
from carl.envs.braxenvs.carl_reacher import CONTEXT_BOUNDS as CARLReacher_bounds
from carl.envs.braxenvs.carl_reacher import DEFAULT_CONTEXT as CARLReacher_defaults
from carl.envs.braxenvs.carl_reacher import CARLReacher
from carl.envs.braxenvs.carl_pusher import CONTEXT_BOUNDS as CARLPusher_bounds
from carl.envs.braxenvs.carl_pusher import DEFAULT_CONTEXT as CARLPusher_defaults
from carl.envs.braxenvs.carl_pusher import CARLPusher
from carl.envs.braxenvs.carl_double_pendulum import CONTEXT_BOUNDS as CARLInvertedDoublePendulum_bounds
from carl.envs.braxenvs.carl_double_pendulum import DEFAULT_CONTEXT as CARLInvertedDoublePendulum_defaults
from carl.envs.braxenvs.carl_double_pendulum import CARLInvertedDoublePendulum
148 changes: 148 additions & 0 deletions carl/envs/brax/brax_wrappers.py
@@ -0,0 +1,148 @@
# Copyright 2023 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wrappers to convert brax envs to gym envs."""
from typing import ClassVar, Optional

from brax.envs import Env
import gym
from gym import spaces
from gym.vector import utils
import jax
import numpy as np
from functools import partial


class GymWrapper(gym.Env):
  """A wrapper that converts Brax Env to one that follows Gym API."""

  # Flag that prevents `gym.register` from misinterpreting the `_step` and
  # `_reset` as signs of a deprecated gym Env API.
  _gym_disable_underscore_compat: ClassVar[bool] = True

  def __init__(self,
               env: Env,
               seed: int = 0,
               backend: Optional[str] = None):
    self._env = env
    self.metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 1 / self._env.dt
    }
    self.seed(seed)
    self.backend = backend
    self._state = None

    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
    self.observation_space = spaces.Box(-obs, obs, dtype='float32')

    action = np.ones(self._env.action_size, dtype='float32')
    self.action_space = spaces.Box(-action, action, dtype='float32')

    def reset(key):
      key1, key2 = jax.random.split(key)
      state = self._env.reset(key2)
      return state, state.obs, key1

    self._reset = partial(reset)

    def step(state, action):
      state = self._env.step(state, action)
      info = {**state.metrics, **state.info}
      return state, state.obs, state.reward, state.done, info

    self._step = partial(step)

  def reset(self):
    self._state, obs, self._key = self._reset(self._key)
    # We return device arrays for pytorch users.
    return obs

  def step(self, action):
    self._state, obs, reward, done, info = self._step(self._state, action)
    # We return device arrays for pytorch users.
    return obs, reward, done, info

  def seed(self, seed: int = 0):
    self._key = jax.random.PRNGKey(seed)

  def render(self, mode='human'):
    return super().render(mode=mode)  # just raise an exception


class VectorGymWrapper(gym.vector.VectorEnv):
  """A wrapper that converts batched Brax Env to one that follows Gym VectorEnv API."""

  # Flag that prevents `gym.register` from misinterpreting the `_step` and
  # `_reset` as signs of a deprecated gym Env API.
  _gym_disable_underscore_compat: ClassVar[bool] = True

  def __init__(self,
               env: Env,
               seed: int = 0,
               backend: Optional[str] = None):
    self._env = env
    self.metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 1 / self._env.dt
    }
    if not hasattr(self._env, 'batch_size'):
      raise ValueError('underlying env must be batched')

    self.num_envs = self._env.batch_size
    self.seed(seed)
    self.backend = backend
    self._state = None

    obs = np.inf * np.ones(self._env.observation_size, dtype='float32')
    obs_space = spaces.Box(-obs, obs, dtype='float32')
    self.observation_space = utils.batch_space(obs_space, self.num_envs)

    action = np.ones(self._env.action_size, dtype='float32')
    action_space = spaces.Box(-action, action, dtype='float32')
    self.action_space = utils.batch_space(action_space, self.num_envs)

    def reset(key):
      key1, key2 = jax.random.split(key)
      state = self._env.reset(key2)
      return state, state.obs, key1

    self._reset = partial(reset)

    def step(state, action):
      state = self._env.step(state, action)
      info = {**state.metrics, **state.info}
      return state, state.obs, state.reward, state.done, info

    self._step = partial(step)

  def reset(self):
    self._state, obs, self._key = self._reset(self._key)
    return obs

  def step(self, action):
    self._state, obs, reward, done, info = self._step(self._state, action)
    return obs, reward, done, info

  def seed(self, seed: int = 0):
    self._key = jax.random.PRNGKey(seed)

  def render(self, mode='human'):
    if mode == 'rgb_array':
      # Local import: `image` is not imported at module level in this file.
      from brax.io import image
      sys, state = self._env.sys, self._state
      if state is None:
        raise RuntimeError('must call reset or step before rendering')
      return image.render_array(sys, state.state.take(0), 256, 256)
    else:
      return super().render(mode=mode)  # just raise an exception
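Both wrappers follow the same pattern: the Brax env is purely functional (state and PRNG key go in, new state comes out), and the wrapper holds the current state and key so that callers see the usual stateful `reset()`/`step(action)` interface. The following is a dependency-free sketch of that pattern, not Brax code: `FunctionalCounterEnv`, `State`, `split`, and `StatefulWrapper` are hypothetical stand-ins, and `split` only imitates the role of `jax.random.split` with the standard library.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class State:
    obs: float
    reward: float
    done: bool


class FunctionalCounterEnv:
    """Hypothetical pure env: all state is passed in and returned explicitly."""

    def reset(self, key: int) -> State:
        start = float(random.Random(key).randint(0, 3))
        return State(obs=start, reward=0.0, done=False)

    def step(self, state: State, action: float) -> State:
        obs = state.obs + action
        return State(obs=obs, reward=-abs(obs), done=abs(obs) >= 5)


def split(key: int) -> Tuple[int, int]:
    """Stand-in for jax.random.split: derive two new keys from one."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


class StatefulWrapper:
    """Keeps the key and state so callers get a gym-like interface."""

    def __init__(self, env: FunctionalCounterEnv, seed: int = 0) -> None:
        self._env = env
        self._key = seed
        self._state: Optional[State] = None

    def reset(self) -> float:
        # Keep one half of the split for the next reset, spend the other now.
        self._key, subkey = split(self._key)
        self._state = self._env.reset(subkey)
        return self._state.obs

    def step(self, action: float) -> Tuple[float, float, bool, Dict[str, Any]]:
        if self._state is None:
            raise RuntimeError("must call reset before step")
        self._state = self._env.step(self._state, action)
        return self._state.obs, self._state.reward, self._state.done, {}


env = StatefulWrapper(FunctionalCounterEnv(), seed=0)
first_obs = env.reset()
obs, reward, done, info = env.step(1.0)
print(obs == first_obs + 1.0)  # True
```

Because the key is advanced on every `reset`, two wrappers built with the same seed produce the same sequence of initial states, which is the reproducibility property the `seed`/`_key` handling in the diff provides.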