fix: tests run again
TheEimer committed Feb 9, 2024
1 parent b0b9da1 commit a576d18
Showing 11 changed files with 107 additions and 50 deletions.
11 changes: 8 additions & 3 deletions carl/envs/brax/brax_walker_goal_wrapper.py
@@ -53,7 +53,7 @@
class BraxWalkerGoalWrapper(gym.Wrapper):
"""Adds a positional goal to brax walker envs"""

def __init__(self, env, env_name, asset_path) -> None:
def __init__(self, env: gym.Env, env_name: str, asset_path: str) -> None:
super().__init__(env)
self.env_name = env_name
if (
@@ -66,6 +66,7 @@ def __init__(self, env, env_name, asset_path) -> None:
self.context = None
self.position = None
self.goal_position = None
self.goal_radius = None
self.direction_values = {
3: [0, -1],
1: [0, 1],
@@ -116,6 +117,7 @@ def reset(self, seed=None, options={}):
np.array(self.direction_values[self.context["target_direction"]])
* self.context["target_distance"]
)
self.goal_radius = self.context["target_radius"]
info["success"] = 0
return state, info

@@ -130,7 +132,7 @@ def step(self, action):
previous_distance_to_goal = np.linalg.norm(self.goal_position - self.position)
direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal)
self.position = new_position
if abs(current_distance_to_goal) <= 5:
if abs(current_distance_to_goal) <= self.goal_radius:
te = True
info["success"] = 1
else:
@@ -168,8 +170,11 @@ def step(self, action):
def get_goal_desc(self, context):
if "target_radius" in context.keys():
target_distance = context["target_distance"]
target_direction = context["target_direction"]
target_radius = context["target_radius"]
return f"The distance to the goal is {target_distance}m. Move within {target_radius} steps of the goal."
return f"""The distance to the goal is {target_distance}m
{DIRECTION_NAMES[target_direction]}.
Move within {target_radius} steps of the goal."""
else:
target_distance = context["target_distance"]
target_direction = context["target_direction"]
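The wrapper change above swaps the hard-coded 5 m success threshold for the context's target_radius. A minimal sketch of that termination check, with illustrative numbers standing in for the wrapper's context-derived values (not taken from the repository):

import numpy as np

# Illustrative stand-ins for the wrapper's attributes.
direction = np.array([0, 1])          # self.direction_values[context["target_direction"]]
goal_position = direction * 10.0      # scaled by context["target_distance"]
goal_radius = 2.5                     # context["target_radius"]

position = np.array([0.0, 8.0])       # walker position after a step
distance_to_goal = np.linalg.norm(goal_position - position)
terminated = distance_to_goal <= goal_radius  # success once inside the radius
print(terminated)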
3 changes: 3 additions & 0 deletions carl/envs/brax/carl_ant.py
@@ -43,4 +43,7 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
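The same target_radius context feature is added to the other walker environments below (halfcheetah, hopper, humanoid, walker2d). A hedged sketch of a contexts dictionary that exercises it, using the nested-dict layout the tests rely on; the import path and values are assumptions for illustration only:

from carl.envs import CARLBraxAnt  # import path assumed

# Contexts are keyed by instance id; target_radius falls back to its
# default_value of 5 when omitted.
contexts = {
    0: {"target_direction": 1, "target_distance": 10.0, "target_radius": 5.0},
    1: {"target_direction": 3, "target_distance": 25.0, "target_radius": 2.5},
}
env = CARLBraxAnt(contexts=contexts)
state, info = env.reset()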
15 changes: 13 additions & 2 deletions carl/envs/brax/carl_brax_env.py
@@ -214,6 +214,18 @@ def __init__(
"target_distance" in contexts[list(contexts.keys())[0]].keys()
or "target_direction" in contexts[list(contexts.keys())[0]].keys()
):
assert all(
[
"target_direction" in contexts[list(contexts.keys())[i]].keys()
for i in range(len(contexts))
]
), "All contexts must have a 'target_direction' key"
assert all(
[
"target_distance" in contexts[list(contexts.keys())[i]].keys()
for i in range(len(contexts))
]
), "All contexts must have a 'target_distance' key"
base_dir = contexts[list(contexts.keys())[0]]["target_direction"]
base_dist = contexts[list(contexts.keys())[0]]["target_distance"]
max_diff_dir = max(
@@ -251,6 +263,7 @@ def _update_context(self) -> None:
"elasticity",
"target_distance",
"target_direction",
"target_radius",
]
check_context(context, registered_cfs)

@@ -288,8 +301,6 @@ def reset(
self._progress_instance()
if self.context_id != last_context_id:
self._update_context()
# if self.use_language_goals:
# self.env.env.context = self.context
self.env.context = self.context
state, info = self.env.reset(seed=seed, options=options)
state = self._add_context_to_state(state)
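The new assertions make the goal keys all-or-nothing: once any context specifies a goal, every context must carry both target_direction and target_distance. A small illustration with hypothetical contexts dictionaries:

# Passes the new checks: both keys present in every context.
ok = {
    0: {"target_direction": 1, "target_distance": 10.0},
    1: {"target_direction": 2, "target_distance": 12.0},
}

# Would trip "All contexts must have a 'target_distance' key" in __init__.
broken = {
    0: {"target_direction": 1, "target_distance": 10.0},
    1: {"target_direction": 2},
}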
3 changes: 3 additions & 0 deletions carl/envs/brax/carl_halfcheetah.py
@@ -61,4 +61,7 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
3 changes: 3 additions & 0 deletions carl/envs/brax/carl_hopper.py
@@ -52,4 +52,7 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
3 changes: 3 additions & 0 deletions carl/envs/brax/carl_humanoid.py
@@ -79,4 +79,7 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
9 changes: 8 additions & 1 deletion carl/envs/brax/carl_pusher.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from copy import deepcopy

import numpy as np

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
@@ -89,8 +91,13 @@ def get_context_features() -> dict[str, ContextFeature]:
}

def _update_context(self) -> None:
super()._update_context()
goal_x = self.context["goal_position_x"]
goal_y = self.context["goal_position_y"]
goal_z = self.context["goal_position_z"]
context = deepcopy(self.context)
del self.context["goal_position_x"]
del self.context["goal_position_y"]
del self.context["goal_position_z"]
super()._update_context()
self.env._goal_pos = np.array([goal_x, goal_y, goal_z])
self.context = context
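_update_context in the pusher now strips the goal keys before the generic context update and restores the full context afterwards, because the goal position is written directly to the wrapped env rather than through the physics parameters. A standalone sketch of that save/strip/restore pattern; the function and parameter names are illustrative, not part of the repository:

from copy import deepcopy

import numpy as np


def apply_goal_context(context, apply_physics):
    """Pop the goal keys, run the generic update, then restore the context."""
    full_context = deepcopy(context)
    goal = np.array(
        [context.pop(f"goal_position_{axis}") for axis in ("x", "y", "z")]
    )
    apply_physics(context)        # stands in for super()._update_context()
    context.update(full_context)  # restore the goal keys for later use
    return goal


goal_pos = apply_goal_context(
    {"goal_position_x": 0.1, "goal_position_y": 0.2, "goal_position_z": 0.0},
    apply_physics=lambda c: None,
)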
3 changes: 3 additions & 0 deletions carl/envs/brax/carl_walker2d.py
@@ -61,4 +61,7 @@ def get_context_features() -> dict[str, ContextFeature]:
"target_direction": CategoricalContextFeature(
"target_direction", choices=directions, default_value=1
),
"target_radius": UniformFloatContextFeature(
"target_radius", lower=0.1, upper=np.inf, default_value=5
),
}
8 changes: 2 additions & 6 deletions carl/envs/dmc/dmc_tasks/pointmass.py
@@ -235,9 +235,7 @@ def easy_pointmass(
xml_string, assets = get_model_and_assets()
xml_string = make_model(**context)
if context != {}:
xml_string = adapt_context(
xml_string=xml_string, context=context, context_mask=context_mask
)
xml_string = adapt_context(xml_string=xml_string, context=context)
physics = Physics.from_xml_string(xml_string, assets)
task = ContextualPointMass(randomize_gains=False, random=random)
environment_kwargs = environment_kwargs or {}
@@ -261,9 +259,7 @@ def hard_pointmass(
xml_string, assets = get_model_and_assets()
xml_string = make_model(**context)
if context != {}:
xml_string = adapt_context(
xml_string=xml_string, context=context, context_mask=context_mask
)
xml_string = adapt_context(xml_string=xml_string, context=context)
physics = Physics.from_xml_string(xml_string, assets)
task = ContextualPointMass(randomize_gains=True, random=random)
environment_kwargs = environment_kwargs or {}
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@ def read_file(filepath: str) -> str:
"brax": [
"brax==0.9.1",
"protobuf>=3.17.3",
"mujoco==3.0.1"
],
"dm_control": [
"dm_control>=1.0.3",
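The brax extra now pins mujoco==3.0.1 alongside brax==0.9.1. A quick, hedged way to confirm which versions actually resolved in an installed environment (package names taken from setup.py above):

import importlib.metadata as metadata

print(metadata.version("brax"))    # expected 0.9.1 per setup.py
print(metadata.version("mujoco"))  # expected 3.0.1 per setup.py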
98 changes: 60 additions & 38 deletions test/test_language_goals.py
@@ -44,11 +44,19 @@ def test_uniform_sampling(self):
)
contexts = context_sampler.sample_contexts(n_contexts=10)
assert len(contexts.keys()) == 10
assert "target_distance" in contexts[0].keys()
assert "target_direction" in contexts[0].keys()
assert all([contexts[i]["target_direction"] in DIRECTIONS for i in range(10)])
assert all([contexts[i]["target_distance"] <= 200 for i in range(10)])
assert all([contexts[i]["target_distance"] >= 4 for i in range(10)])
assert "target_distance" in contexts[0].keys(), "target_distance not in context"
assert (
"target_direction" in contexts[0].keys()
), "target_direction not in context"
assert all(
[contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]
), "Not all directions are valid."
assert all(
[contexts[i]["target_distance"] <= 200 for i in range(10)]
), "Not all distances are valid (too large)."
assert all(
[contexts[i]["target_distance"] >= 4 for i in range(10)]
), "Not all distances are valid (too small)."

def test_normal_sampling(self):
context_distributions = [
@@ -61,12 +69,22 @@ def test_normal_sampling(self):
seed=0,
)
contexts = context_sampler.sample_contexts(n_contexts=10)
assert len(contexts.keys()) == 10
assert "target_distance" in contexts[0].keys()
assert "target_direction" in contexts[0].keys()
assert all([contexts[i]["target_direction"] in DIRECTIONS for i in range(10)])
assert all([contexts[i]["target_distance"] <= 200 for i in range(10)])
assert all([contexts[i]["target_distance"] >= 4 for i in range(10)])
assert (
len(contexts.keys()) == 10
), "Number of sampled contexts does not match the requested number."
assert "target_distance" in contexts[0].keys(), "target_distance not in context"
assert (
"target_direction" in contexts[0].keys()
), "target_direction not in context"
assert all(
[contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]
), "Not all directions are valid."
assert all(
[contexts[i]["target_distance"] <= 200 for i in range(10)]
), "Not all distances are valid (too large)."
assert all(
[contexts[i]["target_distance"] >= 4 for i in range(10)]
), "Not all distances are valid (too small)."


class TestGoalWrapper(unittest.TestCase):
@@ -84,12 +102,12 @@ def test_reset(self):
env = CARLBraxAnt(contexts=contexts)

assert isinstance(env.env, BraxWalkerGoalWrapper)
assert env.position is None
assert env.position is None, "Position set before reset."

state, info = env.reset()
assert state is not None
assert info is not None
assert env.position is not None
assert state is not None, "No state returned."
assert info is not None, "No info returned."
assert env.position is not None, "Position not set."

context_distributions = [
NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
@@ -103,13 +121,13 @@
contexts = context_sampler.sample_contexts(n_contexts=10)
env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True)

assert isinstance(env.env, BraxLanguageWrapper)
assert env.position is None
assert isinstance(env.env, BraxLanguageWrapper), "Language wrapper not used."
assert env.position is None, "Position set before reset."

state, info = env.reset()
assert state is not None
assert info is not None
assert env.position is not None
assert state is not None, "No state returned."
assert info is not None, "No info returned."
assert env.position is not None, "Position not set."

def test_reward_scale(self):
context_distributions = [
@@ -129,7 +147,7 @@
for _ in range(10):
action = env.action_space.sample()
_, wrapped_reward, _, _, _ = env.step(action)
assert wrapped_reward >= 0
assert wrapped_reward >= 0, "Negative reward."

context_distributions = [
NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
@@ -148,7 +166,7 @@
for _ in range(10):
action = env.action_space.sample()
_, wrapped_reward, _, _, _ = env.step(action)
assert wrapped_reward >= 0
assert wrapped_reward >= 0, "Negative reward."


class TestLanguageWrapper(unittest.TestCase):
@@ -165,13 +183,15 @@ def test_reset(self) -> None:
contexts = context_sampler.sample_contexts(n_contexts=10)
env = CARLBraxAnt(contexts=contexts, use_language_goals=True)
state, info = env.reset()
assert type(state) is dict
assert "obs" in state.keys()
assert "goal" in state["obs"].keys()
assert type(state["obs"]["goal"]) is str
assert str(env.context["target_distance"]) in state["obs"]["goal"]
assert "north north east" in state["obs"]["goal"]
assert info is not None
assert type(state) is dict, "State is not a dictionary."
assert "obs" in state.keys(), "Observation not in state."
assert "goal" in state["obs"].keys(), "Goal not in observation."
assert type(state["obs"]["goal"]) is str, "Goal is not a string."
assert (
str(env.context["target_distance"]) in state["obs"]["goal"]
), "Distance not in goal."
assert "north north east" in state["obs"]["goal"], "Direction not in goal."
assert info is not None, "No info returned."

def test_step(self):
context_distributions = [
@@ -189,12 +209,14 @@ def test_step(self):
for _ in range(10):
action = env.action_space.sample()
state, _, _, _, _ = env.step(action)
assert type(state) is dict
assert "obs" in state.keys()
assert "goal" in state["obs"].keys()
assert type(state["obs"]["goal"]) is str
assert "north north east" in state["obs"]["goal"]
assert str(env.context["target_distance"]) in state["obs"]["goal"]
assert type(state) is dict, "State is not a dictionary."
assert "obs" in state.keys(), "Observation not in state."
assert "goal" in state["obs"].keys(), "Goal not in observation."
assert type(state["obs"]["goal"]) is str, "Goal is not a string."
assert "north north east" in state["obs"]["goal"], "Direction not in goal."
assert (
str(env.context["target_distance"]) in state["obs"]["goal"]
), "Distance not in goal."

context_distributions = [
NormalFloatContextFeature("target_distance", mu=9.8, sigma=1),
@@ -211,6 +233,6 @@ def test_step(self):
for _ in range(10):
action = env.action_space.sample()
state, _, _, _, _ = env.step(action)
assert type(state) is dict
assert "obs" in state.keys()
assert "goal" not in state.keys()
assert type(state) is dict, "State is not a dictionary."
assert "obs" in state.keys(), "Observation not in state."
assert "goal" not in state.keys(), "Goal in observation."
