Skip to content

Commit

Permalink
Merge pull request #95 from automl/train
Browse files Browse the repository at this point in the history
enhancement: Environment Reconciliation
  • Loading branch information
TheEimer committed Jul 14, 2023
2 parents af97301 + 1fcf02f commit 29dd0fd
Show file tree
Hide file tree
Showing 99 changed files with 1,080 additions and 2,499 deletions.
16 changes: 13 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,19 @@ carl.egg-info
exp_sweep
multirun
outputs
experiments
testvenv
*.egg-info
runs
*.png
*.pdf
*.csv
*.pickle
*.ipynb_checkpoints
*optgap*
*smac3*
*.json
generated
*egg*
core
*.png
*.tex
build
target
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
url = https://github.com/Mawiszus/TOAD-GUI
[submodule "src/envs/mario/Mario-AI-Framework"]
path = src/envs/mario/Mario-AI-Framework
url = https://github.com/frederikschubert/Mario-AI-Framework
url = https://github.com/frederikschubert/Mario-AI-Framework
2 changes: 1 addition & 1 deletion CITATION.bib
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ @inproceedings { BenEim2023a
title = {Contextualize Me - The Case for Context in Reinforcement Learning},
journal = {Transactions on Machine Learning Research},
year = {2023},
}
}
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pip install .

This will only install the basic classic control environments, which should run on most operating systems. For the full set of environments, use the install options:
```bash
pip install -e .[box2d, brax, mario, dm_control]
pip install -e .[box2d,brax,mario,dm_control]
```

These may not be compatible with Windows systems. Box2D environment may need to be installed via conda on MacOS systems:
Expand All @@ -68,12 +68,12 @@ Different instiations can be achieved by setting the context features to differe
## Cite Us
If you use CARL in your research, please cite our paper on the benchmark:
```bibtex
@inproceedings{BenEim2021a,
title = {CARL: A Benchmark for Contextual and Adaptive Reinforcement Learning},
author = {Carolin Benjamins and Theresa Eimer and Frederik Schubert and André Biedenkapp and Bodo Rosenhahn and Frank Hutter and Marius Lindauer},
booktitle = {NeurIPS 2021 Workshop on Ecological Theory of Reinforcement Learning},
year = {2021},
month = dec
@inproceedings{Benjamins2023,
title = {Contextualize Me -- The Case for Context in Reinforcement Learning},
author = {Carolin Benjamins and Theresa Eimer and Frederik Schubert and Aditya Mohan and Sebastian Döhler and André Biedenkapp and Bodo Rosenhan and Frank Hutter and Marius Lindauer},
booktitle = {Transactions on Machine Learning Research},
year = {2023},
month = Apr
}
```

Expand Down
79 changes: 58 additions & 21 deletions carl/context/sampling.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
# flake8: noqa: W605
from typing import Any, Dict, List, Tuple
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union

import importlib

import numpy as np
from scipy.stats import norm
from scipy.stats import norm, rv_continuous, uniform

from carl import envs
import carl.envs
from carl.utils.types import Context, Contexts


def get_default_context_and_bounds(
env_name: str,
) -> Tuple[Dict[Any, Any], Dict[Any, Any]]:
) -> Tuple[
Context,
Dict[
str,
Union[
Tuple[Any, Any, Union[type, Tuple[type, type]]], Tuple[Any, Any, str, list]
],
],
]:
"""
Get context feature defaults and bounds for environment.
Expand All @@ -35,11 +46,11 @@ def get_default_context_and_bounds(
categorical context features:
``"VEHICLE": (None, None, "categorical", np.arange(0, len(PARKING_GARAGE)))``
"""
# TODO make less hacky / make explicit
env_defaults = getattr(envs, f"{env_name}_defaults")
env_bounds = getattr(envs, f"{env_name}_bounds")

return env_defaults, env_bounds
env_cls = getattr(carl.envs, env_name)
env_module = importlib.import_module(env_cls.__module__)
context_def = getattr(env_module, "DEFAULT_CONTEXT")
context_bounds = getattr(env_module, "CONTEXT_BOUNDS")
return context_def, context_bounds


def sample_contexts(
Expand All @@ -48,6 +59,9 @@ def sample_contexts(
num_contexts: int,
default_sample_std_percentage: float = 0.05,
fallback_sample_std: float = 0.1,
seed: Optional[int] = None,
uniform_distribution: bool = False,
uniform_bounds_rel: tuple(float, float) | None = None
) -> Dict[int, Dict[str, Any]]:
"""
Sample contexts.
Expand Down Expand Up @@ -102,6 +116,8 @@ def sample_contexts(
0.05.
fallback_sample_std: float, optional
The fallback relative standard deviation. Defaults to 0.1.
seed: int, optional
The seed for the sampling of the random variables.
Returns
-------
Expand All @@ -110,11 +126,15 @@ def sample_contexts(
names as keys and context feature values as values, e.g.,
"""
rng = np.random.default_rng(seed=seed)

# Get default context features and bounds
env_defaults, env_bounds = get_default_context_and_bounds(env_name=env_name)

# Create sample distributions/rules
sample_dists = {}
sample_dists: Dict[
str, Tuple[rv_continuous, Union[str, type, Tuple[type, type]]]
] = {}
for context_feature_name in env_defaults.keys():
if context_feature_name in context_feature_args:
if f"{context_feature_name}_mean" in context_feature_args:
Expand All @@ -141,7 +161,21 @@ def sample_contexts(
# the sample mean. Therefore we use a fallback sample standard deviation.
sample_std = fallback_sample_std # TODO change this back to sample_std

random_variable = norm(loc=sample_mean, scale=sample_std)
if not uniform_distribution:
random_variable = norm(loc=sample_mean, scale=sample_std)
else:
# bounds defined as [loc, loc+scale]
if sample_mean == 0:
# relative bounds are centered around 1 so subtract here for the percentages
loc = uniform_bounds_rel[0] - 1
scale = uniform_bounds_rel[1] - uniform_bounds_rel[0]
elif sample_mean < 0:
loc = uniform_bounds_rel[1] * sample_mean
scale = uniform_bounds_rel[0] * sample_mean - loc
else:
loc = uniform_bounds_rel[0] * sample_mean
scale = uniform_bounds_rel[1] * sample_mean - loc
random_variable = uniform(loc=loc, scale=scale)
context_feature_type = env_bounds[context_feature_name][2]
sample_dists[context_feature_name] = (random_variable, context_feature_type)

Expand All @@ -156,27 +190,30 @@ def sample_contexts(
random_variable = sample_dists[k][0]
context_feature_type = sample_dists[k][1]
lower_bound, upper_bound = env_bounds[k][0], env_bounds[k][1]
assert lower_bound <= upper_bound, f"context variable {k}: lower bound [{lower_bound}] is higher than upper bound [{upper_bound}]!"
if context_feature_type == list:
length = np.random.randint(
500000
) # TODO should we allow lists to be this long? or should we parametrize this?
arg_class = sample_dists[k][1][1]
context_list = random_variable.rvs(size=length)
arg_class = sample_dists[k][1][1] # type: ignore [index]
context_list = random_variable.rvs(size=length, random_state=rng)
context_list = np.clip(context_list, lower_bound, upper_bound)
c[k] = [arg_class(c) for c in context_list]
c[k] = [arg_class(c) for c in context_list] # type: ignore [operator]
elif context_feature_type == "categorical":
choices = env_bounds[k][3]
choice = np.random.choice(choices)
choices = env_bounds[k][3] # type: ignore [misc]
choice = rng.choice(choices)
c[k] = choice
elif context_feature_type == "conditional":
condition = env_bounds[k][4]
choices = env_bounds[k][3][condition]
choice = np.random.choice(choices)
condition = env_bounds[k][4] # type: ignore [misc]
choices = env_bounds[k][3][condition] # type: ignore [misc]
choice = rng.choice(choices)
c[k] = choice
else:
c[k] = random_variable.rvs(size=1)[0] # sample variable
c[k] = random_variable.rvs(size=1, random_state=rng)[
0
] # sample variable
c[k] = np.clip(c[k], lower_bound, upper_bound) # check bounds
c[k] = context_feature_type(c[k]) # cast to given type
c[k] = context_feature_type(c[k]) # type: ignore [operator] # cast to given type
else:
# No special sampling rule for context feature k, use the default context feature value
c[k] = env_defaults[k]
Expand Down
2 changes: 1 addition & 1 deletion carl/context/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def context_key(self) -> Any | None:
Any | None
The key of the current context or None
"""
if self.context_id:
if self.context_id is not None:
key = self.contexts_keys[self.context_id]
else:
key = None
Expand Down
40 changes: 30 additions & 10 deletions carl/envs/box2d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
# flake8: noqa: F401
from carl.envs.box2d.carl_bipedal_walker import (
CONTEXT_BOUNDS as CARLBipedalWalkerEnv_bounds,
)
from carl.envs.box2d.carl_bipedal_walker import (
DEFAULT_CONTEXT as CARLBipedalWalkerEnv_defaults,
)
from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv

# Contextenvs.s and bounds by name
from carl.envs.box2d.carl_lunarlander import CONTEXT_BOUNDS as CARLLunarLanderEnv_bounds
from functools import partial
import warnings

import gym
from carl.envs.box2d.carl_lunarlander import CARLLunarLanderEnv
from carl.envs.box2d.carl_lunarlander import (
DEFAULT_CONTEXT as CARLLunarLanderEnv_defaults,
)
from carl.envs.box2d.carl_lunarlander import CARLLunarLanderEnv
from carl.envs.box2d.carl_vehicle_racing import (
CONTEXT_BOUNDS as CARLVehicleRacingEnv_bounds,
)

from carl.envs.box2d.carl_vehicle_racing import CARLVehicleRacingEnv
from carl.envs.box2d.carl_vehicle_racing import (
DEFAULT_CONTEXT as CARLVehicleRacingEnv_defaults,
)
from carl.envs.box2d.carl_vehicle_racing import CARLVehicleRacingEnv
from carl.envs.box2d.carl_vehicle_racing import (
CONTEXT_BOUNDS as CARLVehicleRacingEnv_bounds,
)

from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv
from carl.envs.box2d.carl_bipedal_walker import (
DEFAULT_CONTEXT as CARLBipedalWalkerEnv_defaults,
)
from carl.envs.box2d.carl_bipedal_walker import (
CONTEXT_BOUNDS as CARLBipedalWalkerEnv_bounds,
)

try:
from carl.envs.box2d.carl_bipedal_walker import CARLBipedalWalkerEnv
from gym.envs.registration import register

def make_env(**kwargs):
return CARLBipedalWalkerEnv(**kwargs)
register("CARLBipedalWalkerEnv-v0", entry_point=make_env)
register("CARLBipedalWalkerHardcoreEnv-v0", entry_point=partial(make_env, env=gym.make("BipedalWalkerHardcore-v3")))
except Exception as e:
warnings.warn(
f"Could not load CARLMarioEnv which is probably not installed ({e}).")
4 changes: 3 additions & 1 deletion carl/envs/box2d/carl_bipedal_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from Box2D.b2 import edgeShape, fixtureDef, polygonShape
import gym
from gym.envs.box2d import bipedal_walker
from gym.envs.box2d import bipedal_walker as bpw

Expand Down Expand Up @@ -105,7 +106,8 @@ def __init__(
instance_mode: str, optional
"""
if env is None:
env = bipedal_walker.BipedalWalker()
# env = bipedal_walker.BipedalWalker()
env = gym.make(id="BipedalWalker-v3")
if not contexts:
contexts = {0: DEFAULT_CONTEXT}
super().__init__(
Expand Down
39 changes: 21 additions & 18 deletions carl/envs/brax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
# flake8: noqa: F401
# Contexts and bounds by name
from carl.envs.brax.carl_ant import CONTEXT_BOUNDS as CARLAnt_bounds
from carl.envs.brax.carl_ant import DEFAULT_CONTEXT as CARLAnt_defaults
from carl.envs.brax.carl_ant import CARLAnt
from carl.envs.brax.carl_fetch import CONTEXT_BOUNDS as CARLFetch_bounds
from carl.envs.brax.carl_fetch import DEFAULT_CONTEXT as CARLFetch_defaults
from carl.envs.brax.carl_fetch import CARLFetch
from carl.envs.brax.carl_grasp import CONTEXT_BOUNDS as CARLGrasp_bounds
from carl.envs.brax.carl_grasp import DEFAULT_CONTEXT as CARLGrasp_defaults
from carl.envs.brax.carl_grasp import CARLGrasp
from carl.envs.brax.carl_halfcheetah import CONTEXT_BOUNDS as CARLHalfcheetah_bounds
from carl.envs.brax.carl_halfcheetah import DEFAULT_CONTEXT as CARLHalfcheetah_defaults
from carl.envs.brax.carl_halfcheetah import CARLHalfcheetah
from carl.envs.brax.carl_humanoid import CONTEXT_BOUNDS as CARLHumanoid_bounds
from carl.envs.brax.carl_humanoid import DEFAULT_CONTEXT as CARLHumanoid_defaults
from carl.envs.brax.carl_humanoid import CARLHumanoid
from carl.envs.brax.carl_ur5e import CONTEXT_BOUNDS as CARLUr5e_bounds
from carl.envs.brax.carl_ur5e import DEFAULT_CONTEXT as CARLUr5e_defaults
from carl.envs.brax.carl_ur5e import CARLUr5e
from carl.envs.braxenvs.carl_ant import CONTEXT_BOUNDS as CARLAnt_bounds
from carl.envs.braxenvs.carl_ant import DEFAULT_CONTEXT as CARLAnt_defaults
from carl.envs.braxenvs.carl_ant import CARLAnt
from carl.envs.braxenvs.carl_halfcheetah import CONTEXT_BOUNDS as CARLHalfcheetah_bounds
from carl.envs.braxenvs.carl_halfcheetah import DEFAULT_CONTEXT as CARLHalfcheetah_defaults
from carl.envs.braxenvs.carl_halfcheetah import CARLHalfcheetah
from carl.envs.braxenvs.carl_humanoid import CONTEXT_BOUNDS as CARLHumanoid_bounds
from carl.envs.braxenvs.carl_humanoid import DEFAULT_CONTEXT as CARLHumanoid_defaults
from carl.envs.braxenvs.carl_humanoid import CARLHumanoid
from carl.envs.braxenvs.carl_hopper import CONTEXT_BOUNDS as CARLHopper_bounds
from carl.envs.braxenvs.carl_hopper import DEFAULT_CONTEXT as CARLHopper_defaults
from carl.envs.braxenvs.carl_hopper import CARLHopper
from carl.envs.braxenvs.carl_reacher import CONTEXT_BOUNDS as CARLReacher_bounds
from carl.envs.braxenvs.carl_reacher import DEFAULT_CONTEXT as CARLReacher_defaults
from carl.envs.braxenvs.carl_reacher import CARLReacher
from carl.envs.braxenvs.carl_pusher import CONTEXT_BOUNDS as CARLPusher_bounds
from carl.envs.braxenvs.carl_pusher import DEFAULT_CONTEXT as CARLPusher_defaults
from carl.envs.braxenvs.carl_pusher import CARLPusher
from carl.envs.braxenvs.carl_double_pendulum import CONTEXT_BOUNDS as CARLInvertedDoublePendulum_bounds
from carl.envs.braxenvs.carl_double_pendulum import DEFAULT_CONTEXT as CARLInvertedDoublePendulum_defaults
from carl.envs.braxenvs.carl_double_pendulum import CARLInvertedDoublePendulum
Loading

0 comments on commit 29dd0fd

Please sign in to comment.