Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in jaxnav's agent agent collisions for single agent case #113

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion jaxmarl/environments/jaxnav/jaxnav_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def step_env(
old_goal_reached = agent_states.goal_reached
old_move_term = agent_states.move_term
map_collisions = jax.vmap(self._map_obj.check_agent_map_collision, in_axes=(0, 0, None))(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
if self.num_agents > 1:
agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
else:
agent_collisions = jnp.zeros((self.num_agents,), dtype=jnp.bool_)
collisions = map_collisions | agent_collisions
goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool)
time_up = jnp.full((self.num_agents,), (step >= self.max_steps))
Expand Down
18 changes: 13 additions & 5 deletions tests/jaxnav/test_jaxnav_rand_acts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@
TODO: replace this with proper unit tests.
"""
import jax
# import pytest
import pytest

from jaxmarl.environments.jaxnav import JaxNav

env = JaxNav(4)

def test_random_rollout():

@pytest.mark.parametrize(
("num_agents",),
[
(1,),
(4,),
(9,),
],
)
def test_random_rollout(num_agents: int):
env = JaxNav(num_agents=num_agents)
rng = jax.random.PRNGKey(0)
rng, rng_reset = jax.random.split(rng)

Expand All @@ -22,6 +29,7 @@ def test_random_rollout():
actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
_, state, _, _, _ = env.step(rng, state, actions)

test_random_rollout()
test_random_rollout(1)
test_random_rollout(4)


Loading