diff --git a/carl/envs/brax/brax_walker_goal_wrapper.py b/carl/envs/brax/brax_walker_goal_wrapper.py index ee912c5f..efdc9654 100644 --- a/carl/envs/brax/brax_walker_goal_wrapper.py +++ b/carl/envs/brax/brax_walker_goal_wrapper.py @@ -53,7 +53,7 @@ class BraxWalkerGoalWrapper(gym.Wrapper): """Adds a positional goal to brax walker envs""" - def __init__(self, env, env_name, asset_path) -> None: + def __init__(self, env: gym.Env, env_name: str, asset_path: str) -> None: super().__init__(env) self.env_name = env_name if ( @@ -66,6 +66,7 @@ def __init__(self, env, env_name, asset_path) -> None: self.context = None self.position = None self.goal_position = None + self.goal_radius = None self.direction_values = { 3: [0, -1], 1: [0, 1], @@ -116,6 +117,7 @@ def reset(self, seed=None, options={}): np.array(self.direction_values[self.context["target_direction"]]) * self.context["target_distance"] ) + self.goal_radius = self.context["target_radius"] info["success"] = 0 return state, info @@ -130,7 +132,7 @@ def step(self, action): previous_distance_to_goal = np.linalg.norm(self.goal_position - self.position) direction_reward = max(0, previous_distance_to_goal - current_distance_to_goal) self.position = new_position - if abs(current_distance_to_goal) <= 5: + if abs(current_distance_to_goal) <= self.goal_radius: te = True info["success"] = 1 else: @@ -168,8 +170,11 @@ def step(self, action): def get_goal_desc(self, context): if "target_radius" in context.keys(): target_distance = context["target_distance"] + target_direction = context["target_direction"] target_radius = context["target_radius"] - return f"The distance to the goal is {target_distance}m. Move within {target_radius} steps of the goal." + return f"""The distance to the goal is {target_distance}m + {DIRECTION_NAMES[target_direction]}. + Move within {target_radius} steps of the goal.""" else: target_distance = context["target_distance"] target_direction = context["target_direction"] diff --git a/carl/envs/brax/carl_ant.py b/carl/envs/brax/carl_ant.py index 68dca775..38711b43 100644 --- a/carl/envs/brax/carl_ant.py +++ b/carl/envs/brax/carl_ant.py @@ -43,4 +43,7 @@ def get_context_features() -> dict[str, ContextFeature]: "target_direction": CategoricalContextFeature( "target_direction", choices=directions, default_value=1 ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_brax_env.py b/carl/envs/brax/carl_brax_env.py index 8d970a5f..2c99d3d0 100644 --- a/carl/envs/brax/carl_brax_env.py +++ b/carl/envs/brax/carl_brax_env.py @@ -214,6 +214,18 @@ def __init__( "target_distance" in contexts[list(contexts.keys())[0]].keys() or "target_direction" in contexts[list(contexts.keys())[0]].keys() ): + assert all( + [ + "target_direction" in contexts[list(contexts.keys())[i]].keys() + for i in range(len(contexts)) + ] + ), "All contexts must have a 'target_direction' key" + assert all( + [ + "target_distance" in contexts[list(contexts.keys())[i]].keys() + for i in range(len(contexts)) + ] + ), "All contexts must have a 'target_distance' key" base_dir = contexts[list(contexts.keys())[0]]["target_direction"] base_dist = contexts[list(contexts.keys())[0]]["target_distance"] max_diff_dir = max( @@ -251,6 +263,7 @@ def _update_context(self) -> None: "elasticity", "target_distance", "target_direction", + "target_radius", ] check_context(context, registered_cfs) @@ -288,8 +301,6 @@ def reset( self._progress_instance() if self.context_id != last_context_id: self._update_context() - # if self.use_language_goals: - # self.env.env.context = self.context self.env.context = self.context state, info = self.env.reset(seed=seed, options=options) state = self._add_context_to_state(state) diff --git a/carl/envs/brax/carl_halfcheetah.py b/carl/envs/brax/carl_halfcheetah.py index a4c249a4..b5855cd6 100644 --- a/carl/envs/brax/carl_halfcheetah.py +++ b/carl/envs/brax/carl_halfcheetah.py @@ -61,4 +61,7 @@ def get_context_features() -> dict[str, ContextFeature]: "target_direction": CategoricalContextFeature( "target_direction", choices=directions, default_value=1 ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_hopper.py b/carl/envs/brax/carl_hopper.py index 759e08db..d9cdaf1a 100644 --- a/carl/envs/brax/carl_hopper.py +++ b/carl/envs/brax/carl_hopper.py @@ -52,4 +52,7 @@ def get_context_features() -> dict[str, ContextFeature]: "target_direction": CategoricalContextFeature( "target_direction", choices=directions, default_value=1 ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_humanoid.py b/carl/envs/brax/carl_humanoid.py index 763c5a08..ad4af4cf 100644 --- a/carl/envs/brax/carl_humanoid.py +++ b/carl/envs/brax/carl_humanoid.py @@ -79,4 +79,7 @@ def get_context_features() -> dict[str, ContextFeature]: "target_direction": CategoricalContextFeature( "target_direction", choices=directions, default_value=1 ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/brax/carl_pusher.py b/carl/envs/brax/carl_pusher.py index 2c5ec32c..19cdec86 100644 --- a/carl/envs/brax/carl_pusher.py +++ b/carl/envs/brax/carl_pusher.py @@ -1,5 +1,7 @@ from __future__ import annotations +from copy import deepcopy + import numpy as np from carl.context.context_space import ContextFeature, UniformFloatContextFeature @@ -89,8 +91,13 @@ def get_context_features() -> dict[str, ContextFeature]: } def _update_context(self) -> None: - super()._update_context() goal_x = self.context["goal_position_x"] goal_y = self.context["goal_position_y"] goal_z = self.context["goal_position_z"] + context = deepcopy(self.context) + del self.context["goal_position_x"] + del self.context["goal_position_y"] + del self.context["goal_position_z"] + super()._update_context() self.env._goal_pos = np.array([goal_x, goal_y, goal_z]) + self.context = context diff --git a/carl/envs/brax/carl_walker2d.py b/carl/envs/brax/carl_walker2d.py index 6b94b998..db08dbe2 100644 --- a/carl/envs/brax/carl_walker2d.py +++ b/carl/envs/brax/carl_walker2d.py @@ -61,4 +61,7 @@ def get_context_features() -> dict[str, ContextFeature]: "target_direction": CategoricalContextFeature( "target_direction", choices=directions, default_value=1 ), + "target_radius": UniformFloatContextFeature( + "target_radius", lower=0.1, upper=np.inf, default_value=5 + ), } diff --git a/carl/envs/dmc/dmc_tasks/pointmass.py b/carl/envs/dmc/dmc_tasks/pointmass.py index 2c076293..1eb80db5 100644 --- a/carl/envs/dmc/dmc_tasks/pointmass.py +++ b/carl/envs/dmc/dmc_tasks/pointmass.py @@ -235,9 +235,7 @@ def easy_pointmass( xml_string, assets = get_model_and_assets() xml_string = make_model(**context) if context != {}: - xml_string = adapt_context( - xml_string=xml_string, context=context, context_mask=context_mask - ) + xml_string = adapt_context(xml_string=xml_string, context=context) physics = Physics.from_xml_string(xml_string, assets) task = ContextualPointMass(randomize_gains=False, random=random) environment_kwargs = environment_kwargs or {} @@ -261,9 +259,7 @@ def hard_pointmass( xml_string, assets = get_model_and_assets() xml_string = make_model(**context) if context != {}: - xml_string = adapt_context( - xml_string=xml_string, context=context, context_mask=context_mask - ) + xml_string = adapt_context(xml_string=xml_string, context=context) physics = Physics.from_xml_string(xml_string, assets) task = ContextualPointMass(randomize_gains=True, random=random) environment_kwargs = environment_kwargs or {} diff --git a/setup.py b/setup.py index 1a84d934..d19673cc 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ def read_file(filepath: str) -> str: "brax": [ "brax==0.9.1", "protobuf>=3.17.3", + "mujoco==3.0.1" ], "dm_control": [ "dm_control>=1.0.3", diff --git a/test/test_language_goals.py b/test/test_language_goals.py index d22a2382..0ad5d09b 100644 --- a/test/test_language_goals.py +++ b/test/test_language_goals.py @@ -44,11 +44,19 @@ def test_uniform_sampling(self): ) contexts = context_sampler.sample_contexts(n_contexts=10) assert len(contexts.keys()) == 10 - assert "target_distance" in contexts[0].keys() - assert "target_direction" in contexts[0].keys() - assert all([contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]) - assert all([contexts[i]["target_distance"] <= 200 for i in range(10)]) - assert all([contexts[i]["target_distance"] >= 4 for i in range(10)]) + assert "target_distance" in contexts[0].keys(), "target_distance not in context" + assert ( + "target_direction" in contexts[0].keys() + ), "target_direction not in context" + assert all( + [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)] + ), "Not all directions are valid." + assert all( + [contexts[i]["target_distance"] <= 200 for i in range(10)] + ), "Not all distances are valid (too large)." + assert all( + [contexts[i]["target_distance"] >= 4 for i in range(10)] + ), "Not all distances are valid (too small)." def test_normal_sampling(self): context_distributions = [ @@ -61,12 +69,22 @@ def test_normal_sampling(self): seed=0, ) contexts = context_sampler.sample_contexts(n_contexts=10) - assert len(contexts.keys()) == 10 - assert "target_distance" in contexts[0].keys() - assert "target_direction" in contexts[0].keys() - assert all([contexts[i]["target_direction"] in DIRECTIONS for i in range(10)]) - assert all([contexts[i]["target_distance"] <= 200 for i in range(10)]) - assert all([contexts[i]["target_distance"] >= 4 for i in range(10)]) + assert ( + len(contexts.keys()) == 10 + ), "Number of sampled contexts does not match the requested number." + assert "target_distance" in contexts[0].keys(), "target_distance not in context" + assert ( + "target_direction" in contexts[0].keys() + ), "target_direction not in context" + assert all( + [contexts[i]["target_direction"] in DIRECTIONS for i in range(10)] + ), "Not all directions are valid." + assert all( + [contexts[i]["target_distance"] <= 200 for i in range(10)] + ), "Not all distances are valid (too large)." + assert all( + [contexts[i]["target_distance"] >= 4 for i in range(10)] + ), "Not all distances are valid (too small)." class TestGoalWrapper(unittest.TestCase): @@ -84,12 +102,12 @@ def test_reset(self): env = CARLBraxAnt(contexts=contexts) assert isinstance(env.env, BraxWalkerGoalWrapper) - assert env.position is None + assert env.position is None, "Position set before reset." state, info = env.reset() - assert state is not None - assert info is not None - assert env.position is not None + assert state is not None, "No state returned." + assert info is not None, "No info returned." + assert env.position is not None, "Position not set." context_distributions = [ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), @@ -103,13 +121,13 @@ def test_reset(self): contexts = context_sampler.sample_contexts(n_contexts=10) env = CARLBraxHalfcheetah(contexts=contexts, use_language_goals=True) - assert isinstance(env.env, BraxLanguageWrapper) - assert env.position is None + assert isinstance(env.env, BraxLanguageWrapper), "Language wrapper not used." + assert env.position is None, "Position set before reset." state, info = env.reset() - assert state is not None - assert info is not None - assert env.position is not None + assert state is not None, "No state returned." + assert info is not None, "No info returned." + assert env.position is not None, "Position not set." def test_reward_scale(self): context_distributions = [ @@ -129,7 +147,7 @@ def test_reward_scale(self): for _ in range(10): action = env.action_space.sample() _, wrapped_reward, _, _, _ = env.step(action) - assert wrapped_reward >= 0 + assert wrapped_reward >= 0, "Negative reward." context_distributions = [ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), @@ -148,7 +166,7 @@ def test_reward_scale(self): for _ in range(10): action = env.action_space.sample() _, wrapped_reward, _, _, _ = env.step(action) - assert wrapped_reward >= 0 + assert wrapped_reward >= 0, "Negative reward." class TestLanguageWrapper(unittest.TestCase): @@ -165,13 +183,15 @@ def test_reset(self) -> None: contexts = context_sampler.sample_contexts(n_contexts=10) env = CARLBraxAnt(contexts=contexts, use_language_goals=True) state, info = env.reset() - assert type(state) is dict - assert "obs" in state.keys() - assert "goal" in state["obs"].keys() - assert type(state["obs"]["goal"]) is str - assert str(env.context["target_distance"]) in state["obs"]["goal"] - assert "north north east" in state["obs"]["goal"] - assert info is not None + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" in state["obs"].keys(), "Goal not in observation." + assert type(state["obs"]["goal"]) is str, "Goal is not a string." + assert ( + str(env.context["target_distance"]) in state["obs"]["goal"] + ), "Distance not in goal." + assert "north north east" in state["obs"]["goal"], "Direction not in goal." + assert info is not None, "No info returned." def test_step(self): context_distributions = [ @@ -189,12 +209,14 @@ def test_step(self): for _ in range(10): action = env.action_space.sample() state, _, _, _, _ = env.step(action) - assert type(state) is dict - assert "obs" in state.keys() - assert "goal" in state["obs"].keys() - assert type(state["obs"]["goal"]) is str - assert "north north east" in state["obs"]["goal"] - assert str(env.context["target_distance"]) in state["obs"]["goal"] + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" in state["obs"].keys(), "Goal not in observation." + assert type(state["obs"]["goal"]) is str, "Goal is not a string." + assert "north north east" in state["obs"]["goal"], "Direction not in goal." + assert ( + str(env.context["target_distance"]) in state["obs"]["goal"] + ), "Distance not in goal." context_distributions = [ NormalFloatContextFeature("target_distance", mu=9.8, sigma=1), @@ -211,6 +233,6 @@ def test_step(self): for _ in range(10): action = env.action_space.sample() state, _, _, _, _ = env.step(action) - assert type(state) is dict - assert "obs" in state.keys() - assert "goal" not in state.keys() + assert type(state) is dict, "State is not a dictionary." + assert "obs" in state.keys(), "Observation not in state." + assert "goal" not in state.keys(), "Goal in observation."