From 7bee5705f5a85ce0f21e407ec11d11c976e90d37 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 2 Sep 2024 13:48:08 +0100 Subject: [PATCH] fix --- jaxmarl/environments/jaxnav/jaxnav_env.py | 5 ++++- tests/jaxnav/test_jaxnav_rand_acts.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/jaxmarl/environments/jaxnav/jaxnav_env.py b/jaxmarl/environments/jaxnav/jaxnav_env.py index 65e97f08..bbcb6232 100644 --- a/jaxmarl/environments/jaxnav/jaxnav_env.py +++ b/jaxmarl/environments/jaxnav/jaxnav_env.py @@ -229,7 +229,10 @@ def step_env( old_goal_reached = agent_states.goal_reached old_move_term = agent_states.move_term map_collisions = jax.vmap(self._map_obj.check_agent_map_collision, in_axes=(0, 0, None))(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool) - agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool) + if self.num_agents > 1: + agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool) + else: + agent_collisions = jnp.zeros((self.num_agents,), dtype=jnp.bool_) collisions = map_collisions | agent_collisions goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool) time_up = jnp.full((self.num_agents,), (step >= self.max_steps)) diff --git a/tests/jaxnav/test_jaxnav_rand_acts.py b/tests/jaxnav/test_jaxnav_rand_acts.py index ce8f350b..eef16cf0 100644 --- a/tests/jaxnav/test_jaxnav_rand_acts.py +++ b/tests/jaxnav/test_jaxnav_rand_acts.py @@ -3,14 +3,21 @@ TODO: replace this with proper unit tests. """ import jax -# import pytest +import pytest from jaxmarl.environments.jaxnav import JaxNav -env = JaxNav(4) - -def test_random_rollout(): +@pytest.mark.parametrize( + ("num_agents",), + [ + (1,), + (4,), + (9,), + ], +) +def test_random_rollout(num_agents: int): + env = JaxNav(num_agents=num_agents) rng = jax.random.PRNGKey(0) rng, rng_reset = jax.random.split(rng) @@ -22,6 +29,7 @@ def test_random_rollout(): actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)} _, state, _, _, _ = env.step(rng, state, actions) -test_random_rollout() +test_random_rollout(1) +test_random_rollout(4) \ No newline at end of file