Skip to content

Commit

Permalink
Merge pull request #113 from FLAIROx/jaxnav-agent_agent_collision_fix
Browse files Browse the repository at this point in the history
fix bug in jaxnav's agent agent collisions for single agent case
  • Loading branch information
amacrutherford authored Sep 2, 2024
2 parents ded3239 + 7bee570 commit 03058ae
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
5 changes: 4 additions & 1 deletion jaxmarl/environments/jaxnav/jaxnav_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def step_env(
old_goal_reached = agent_states.goal_reached
old_move_term = agent_states.move_term
map_collisions = jax.vmap(self._map_obj.check_agent_map_collision, in_axes=(0, 0, None))(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
if self.num_agents > 1:
agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
else:
agent_collisions = jnp.zeros((self.num_agents,), dtype=jnp.bool_)
collisions = map_collisions | agent_collisions
goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool)
time_up = jnp.full((self.num_agents,), (step >= self.max_steps))
Expand Down
18 changes: 13 additions & 5 deletions tests/jaxnav/test_jaxnav_rand_acts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
TODO: replace this with proper unit tests.
"""
import jax
# import pytest
import pytest

from jaxmarl.environments.jaxnav import JaxNav

env = JaxNav(4)

def test_random_rollout():

@pytest.mark.parametrize(
("num_agents",),
[
(1,),
(4,),
(9,),
],
)
def test_random_rollout(num_agents: int):
env = JaxNav(num_agents=num_agents)
rng = jax.random.PRNGKey(0)
rng, rng_reset = jax.random.split(rng)

Expand All @@ -22,6 +29,7 @@ def test_random_rollout():
actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
_, state, _, _, _ = env.step(rng, state, actions)

test_random_rollout()
test_random_rollout(1)
test_random_rollout(4)


0 comments on commit 03058ae

Please sign in to comment.