Skip to content

Commit

Permalink
Merge pull request #97 from automl/language_goals
Browse files Browse the repository at this point in the history
Goals for CARL
  • Loading branch information
TheEimer authored Feb 9, 2024
2 parents bc6ef39 + 0a7e4c9 commit 43f1d21
Show file tree
Hide file tree
Showing 19 changed files with 1,207 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ format-isort:
format: format-black format-isort

test:
$(PYTEST) test
$(PYTEST) --disable-warnings test

cov-report:
coverage html -d coverage_html
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
<img align="left" width="80" src="./docs/source/figures/CARL_logo.png" alt="CARL">

# – The Benchmark Library
[![PyPI Version](https://img.shields.io/pypi/v/carl-bench.svg)](https://pypi.python.org/pypi/carl-bench)
[![Test](https://github.com/automl/carl/actions/workflows/tests.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/tests.yaml)
[![Doc Status](https://github.com/automl/carl/actions/workflows/docs.yaml/badge.svg)](https://github.com/automl/carl/actions/workflows/docs.yaml)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

CARL (context adaptive RL) provides highly configurable contextual extensions
to several well-known RL environments.
It's designed to test your agent's generalization capabilities
Expand Down
180 changes: 180 additions & 0 deletions carl/envs/brax/brax_walker_goal_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import gym
import numpy as np
from brax.io import mjcf
from etils import epath

# Per-environment pair of observation indices that the goal wrapper reads in
# ``step``: both entries are multiplied by the simulation time step and added
# to the tracked planar position.
STATE_INDICES = dict(
    ant=[13, 14],
    humanoid=[22, 23],
    halfcheetah=[14, 15],
    hopper=[5, 6],
    walker2d=[8, 9],
)

# Integer direction codes and their human-readable compass names.
# The digits spell the name: 1 = north, 2 = east, 3 = south, 4 = west,
# e.g. 112 reads "north north east".
DIRECTION_NAMES = dict(
    [
        (1, "north"),
        (3, "south"),
        (2, "east"),
        (4, "west"),
        (12, "north east"),
        (32, "south east"),
        (14, "north west"),
        (34, "south west"),
        (112, "north north east"),
        (332, "south south east"),
        (114, "north north west"),
        (334, "south south west"),
        (212, "east north east"),
        (232, "east south east"),
        (414, "west north west"),
        (434, "west south west"),
    ]
)

# All valid direction codes, in the insertion order of DIRECTION_NAMES.
directions = list(DIRECTION_NAMES)


class BraxWalkerGoalWrapper(gym.Wrapper):
    """Adds a positional goal to brax walker envs.

    The wrapper keeps a 2D position estimate that starts at the origin on
    every ``reset`` and is advanced in ``step`` by adding two observation
    entries (selected through ``STATE_INDICES``) multiplied by the simulation
    time step ``dt`` (presumably the planar velocities -- confirm against the
    respective Brax observation layout).

    The goal lies ``context["target_distance"]`` away from the origin along
    the unit vector registered for ``context["target_direction"]``. The reward
    returned by ``step`` is the non-negative reduction of the distance to that
    goal; the reward of the wrapped env is discarded. The episode terminates
    once the tracked position is within ``context["target_radius"]`` of the
    goal.

    ``self.context`` must be assigned from outside before ``reset`` is called.
    """

    def __init__(self, env: gym.Env, env_name: str, asset_path: str) -> None:
        """Set up goal bookkeeping and read the time step from the asset.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap.
        env_name : str
            Key into ``STATE_INDICES`` (e.g. ``"ant"``).
        asset_path : str
            Path of the MJCF asset, relative to the ``brax`` package
            resources; it is loaded only to obtain ``dt``.
        """
        super().__init__(env)
        self.env_name = env_name
        if self.env_name in ("humanoid", "halfcheetah", "hopper", "walker2d"):
            # Zero the forward-reward weight on these envs; step() below
            # returns the goal-progress reward instead of the env's reward.
            self.env._forward_reward_weight = 0
        self.context = None
        self.position = None
        self.goal_position = None
        self.goal_radius = None

        # Unit vectors as [east, north] components (east = +x, north = +y).
        diagonal = np.sqrt(0.5)
        minor = np.sin(22.5 * np.pi / 180)
        major = np.cos(22.5 * np.pi / 180)
        # The three-digit codes lie 22.5 degrees off the cardinal direction
        # named first (and twice) in DIRECTION_NAMES, so the component along
        # that cardinal axis is the larger one: "north north east" (112) is
        # mostly north, "east north east" (212) is mostly east.
        self.direction_values = {
            3: [0, -1],
            1: [0, 1],
            2: [1, 0],
            4: [-1, 0],
            34: [-diagonal, -diagonal],
            14: [-diagonal, diagonal],
            32: [diagonal, -diagonal],
            12: [diagonal, diagonal],
            334: [-minor, -major],  # south south west
            434: [-major, -minor],  # west south west
            114: [-minor, major],  # north north west
            414: [-major, minor],  # west north west
            332: [minor, -major],  # south south east
            232: [major, -minor],  # east south east
            112: [minor, major],  # north north east
            212: [major, minor],  # east north east
        }

        path = epath.resource_path("brax") / asset_path
        system = mjcf.load(path)
        self.dt = system.dt

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and re-derive the goal from ``self.context``.

        Parameters
        ----------
        seed : optional
            Forwarded unchanged to the wrapped env.
        options : optional
            Forwarded unchanged to the wrapped env. Defaults to ``None``
            rather than a dict literal so that no dict object is shared
            between calls.

        Returns
        -------
        tuple
            ``(state, info)`` of the wrapped env with ``info["success"] = 0``.
        """
        state, info = self.env.reset(seed=seed, options=options)
        self.position = (0, 0)
        unit_vector = np.array(self.direction_values[self.context["target_direction"]])
        self.goal_position = unit_vector * self.context["target_distance"]
        self.goal_radius = self.context["target_radius"]
        info["success"] = 0
        return state, info

    def step(self, action):
        """Step the wrapped env and replace its reward by goal progress.

        Returns
        -------
        tuple
            ``(state, reward, terminated, truncated, info)`` where ``reward``
            is ``max(0, previous_distance - current_distance)`` to the goal,
            ``terminated`` is forced to ``True`` inside the goal radius and
            ``info["success"]`` is 1 in that case and 0 otherwise.
        """
        state, _, te, tr, info = self.env.step(action)
        first, second = STATE_INDICES[self.env_name]
        previous_position = np.asarray(self.position)
        new_position = previous_position + np.array([state[first], state[second]]) * self.dt
        current_distance_to_goal = np.linalg.norm(self.goal_position - new_position)
        previous_distance_to_goal = np.linalg.norm(
            self.goal_position - previous_position
        )
        # Only moving closer is rewarded; moving away yields 0, not a penalty.
        direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal)
        self.position = new_position
        # A norm is never negative, so it can be compared to the radius directly.
        reached = current_distance_to_goal <= self.goal_radius
        if reached:
            te = True
        info["success"] = 1 if reached else 0
        return state, direction_reward, te, tr, info


class BraxLanguageWrapper(gym.Wrapper):
    """Adds a natural-language description of the goal to every observation.

    The description is built from the context features ``target_distance``,
    ``target_direction`` and, when present, ``target_radius``. Observations
    that are already dicts get an extra ``"goal"`` entry; any other
    observation is wrapped as ``{"obs": state, "goal": text}``.

    ``self.context`` must be assigned from outside before ``reset`` is called.
    """

    def __init__(self, env) -> None:
        super().__init__(env)
        self.context = None

    def reset(self, seed=None, options=None):
        """Pass the context on to the wrapped env, reset it and add the goal.

        ``options`` defaults to ``None`` rather than a dict literal so that no
        dict object is shared between calls; it is forwarded unchanged.
        """
        # The wrapped env reads its goal from its own ``context`` attribute.
        self.env.context = self.context
        state, info = self.env.reset(seed=seed, options=options)
        return self._attach_goal(state), info

    def step(self, action):
        """Step the wrapped env and add the goal text to the observation."""
        state, reward, te, tr, info = self.env.step(action)
        return self._attach_goal(state), reward, te, tr, info

    def _attach_goal(self, state):
        """Return ``state`` extended by the goal description of the context."""
        goal_str = self.get_goal_desc(self.context)
        if isinstance(state, dict):
            state["goal"] = goal_str
            return state
        return {"obs": state, "goal": goal_str}

    def get_goal_desc(self, context):
        """Build the goal sentence for ``context``.

        Parameters
        ----------
        context : mapping
            Must contain ``"target_distance"`` and ``"target_direction"`` (a
            key of ``DIRECTION_NAMES``); ``"target_radius"`` is optional and
            selects the longer wording.

        Returns
        -------
        str
            The goal description.
        """
        target_distance = context["target_distance"]
        target_direction = context["target_direction"]
        if "target_radius" in context:
            target_radius = context["target_radius"]
            # Lines are joined with explicit newlines so that the indentation
            # of this source file can never leak into the emitted text.
            return (
                f"The distance to the goal is {target_distance}m\n"
                f"{DIRECTION_NAMES[target_direction]}.\n"
                f"Move within {target_radius} steps of the goal."
            )
        return f"Move {target_distance}m {DIRECTION_NAMES[target_direction]}."
16 changes: 15 additions & 1 deletion carl/envs/brax/carl_ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import numpy as np

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.context.context_space import (
CategoricalContextFeature,
ContextFeature,
UniformFloatContextFeature,
)
from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv


Expand Down Expand Up @@ -32,4 +37,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"viscosity": UniformFloatContextFeature(
"viscosity", lower=0, upper=np.inf, default_value=0
),
"target_distance": UniformFloatContextFeature(
"target_distance", lower=0, upper=np.inf, default_value=100
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
86 changes: 85 additions & 1 deletion carl/envs/brax/carl_brax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
from jax import numpy as jp

from carl.context.selection import AbstractSelector
from carl.envs.brax.brax_walker_goal_wrapper import (
BraxLanguageWrapper,
BraxWalkerGoalWrapper,
)
from carl.envs.brax.wrappers import GymWrapper, VectorGymWrapper
from carl.envs.carl_env import CARLEnv
from carl.utils.types import Contexts
from carl.utils.types import Context, Contexts


def set_geom_attr(
Expand Down Expand Up @@ -152,6 +156,7 @@ def __init__(
obs_context_as_dict: bool = True,
context_selector: AbstractSelector | type[AbstractSelector] | None = None,
context_selector_kwargs: dict = None,
use_language_goals: bool = False,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -204,6 +209,37 @@ def __init__(
dtype=np.float32,
)

if contexts is not None:
if (
"target_distance" in contexts[list(contexts.keys())[0]].keys()
or "target_direction" in contexts[list(contexts.keys())[0]].keys()
):
assert all(
[
"target_direction" in contexts[list(contexts.keys())[i]].keys()
for i in range(len(contexts))
]
), "All contexts must have a 'target_direction' key"
assert all(
[
"target_distance" in contexts[list(contexts.keys())[i]].keys()
for i in range(len(contexts))
]
), "All contexts must have a 'target_distance' key"
base_dir = contexts[list(contexts.keys())[0]]["target_direction"]
base_dist = contexts[list(contexts.keys())[0]]["target_distance"]
max_diff_dir = max(
[c["target_direction"] - base_dir for c in contexts.values()]
)
max_diff_dist = max(
[c["target_distance"] - base_dist for c in contexts.values()]
)
if max_diff_dir > 0.1 or max_diff_dist > 0.1:
env = BraxWalkerGoalWrapper(env, self.env_name, self.asset_path)
if use_language_goals:
env = BraxLanguageWrapper(env)
self.use_language_goals = use_language_goals

super().__init__(
env=env,
contexts=contexts,
Expand All @@ -213,6 +249,7 @@ def __init__(
context_selector_kwargs=context_selector_kwargs,
**kwargs,
)
self.env.context = self.context

def _update_context(self) -> None:
context = self.context
Expand All @@ -224,6 +261,9 @@ def _update_context(self) -> None:
"gravity",
"viscosity",
"elasticity",
"target_distance",
"target_direction",
"target_radius",
]
check_context(context, registered_cfs)

Expand Down Expand Up @@ -252,3 +292,47 @@ def _update_context(self) -> None:
sys = sys.replace(geoms=updated_geoms)

self.env.unwrapped.sys = sys

def reset(
    self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[Any, dict[str, Any]]:
    """Overwrites reset in super to update context in wrapper.

    Parameters
    ----------
    seed : int | None
        Forwarded unchanged to the wrapped env's ``reset``.
    options : dict[str, Any] | None
        Forwarded unchanged to the wrapped env's ``reset``.

    Returns
    -------
    tuple[Any, dict[str, Any]]
        The state returned by ``_add_context_to_state`` and the info dict of
        the wrapped env, extended by ``info["context_id"]``.
    """
    # Remember the id so a context switch made by _progress_instance() can be detected.
    last_context_id = self.context_id
    self._progress_instance()
    if self.context_id != last_context_id:
        self._update_context()
    # Hand the active context to the wrapped env before it is reset.
    # NOTE(review): nesting level of this assignment relative to the `if`
    # above should be confirmed against the repository.
    self.env.context = self.context
    state, info = self.env.reset(seed=seed, options=options)
    state = self._add_context_to_state(state)
    info["context_id"] = self.context_id
    return state, info

@classmethod
def get_default_context(cls) -> Context:
    """Return the default context with all goal features removed.

    Returns
    -------
    Context
        Default context of the context space, without the entries
        ``target_distance``, ``target_direction`` and ``target_radius``.
    """
    context = cls.get_context_space().get_default_context()
    for goal_feature in ("target_distance", "target_direction", "target_radius"):
        if goal_feature in context:
            del context[goal_feature]
    return context

@classmethod
def get_default_goal_context(cls) -> Context:
    """Return the default context including the goal features.

    Returns
    -------
    Context
        Default context exactly as produced by the context space.
    """
    return cls.get_context_space().get_default_context()
16 changes: 15 additions & 1 deletion carl/envs/brax/carl_halfcheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import numpy as np

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.context.context_space import (
CategoricalContextFeature,
ContextFeature,
UniformFloatContextFeature,
)
from carl.envs.brax.brax_walker_goal_wrapper import directions
from carl.envs.brax.carl_brax_env import CARLBraxEnv


Expand Down Expand Up @@ -50,4 +55,13 @@ def get_context_features() -> dict[str, ContextFeature]:
"mass_ffoot": UniformFloatContextFeature(
"mass_ffoot", lower=1e-6, upper=np.inf, default_value=0.8845188
),
"target_distance": UniformFloatContextFeature(
"target_distance", lower=0, upper=np.inf, default_value=100
),
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
Loading

0 comments on commit 43f1d21

Please sign in to comment.