Merge pull request #112 from FLAIROx/jaxnav-docs
Jaxnav docs & improvements
amacrutherford authored Aug 29, 2024
2 parents 29fbca7 + 17594ee commit ded3239
Showing 11 changed files with 331 additions and 275 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -50,7 +50,7 @@ For more details, take a look at our [blog post](https://blog.foersterlab.com/ja
| 🎆 Hanabi | [Paper](https://arxiv.org/abs/1902.00506) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/hanabi) | Fully-cooperative partially-observable multiplayer card game |
| 👾 SMAX | Novel | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/smax) | Simplified cooperative StarCraft micro-management environment |
| 🧮 STORM: Spatial-Temporal Representations of Matrix Games | [Paper](https://openreview.net/forum?id=54F8woU8vhq) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/storm) | Matrix games represented as grid world scenarios
- | 🧭 JaxNav | Paper coming | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/jaxnav) | 2D geometric navigation for differential drive robots
+ | 🧭 JaxNav | [Paper](https://www.arxiv.org/abs/2408.15099) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/jaxnav) | 2D geometric navigation for differential drive robots
| 🪙 Coin Game | [Paper](https://arxiv.org/abs/1802.09640) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/coin_game) | Two-player grid world environment which emulates social dilemmas
| 💡 Switch Riddle | [Paper](https://proceedings.neurips.cc/paper_files/paper/2016/hash/c7635bfd99248a2cdef8249ef7bfbef4-Abstract.html) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/switch_riddle) | Simple cooperative communication game included for debugging

70 changes: 68 additions & 2 deletions jaxmarl/environments/jaxnav/README.md
@@ -1,5 +1,71 @@
# 🧭 JaxNav

2D geometric navigation for differential drive robots. Using distance readings to nearby obstacles (mimicking LiDAR readings), the direction to their goal, and their current velocity, robots must navigate to their goals without colliding with obstacles.

- MORE TO COME.
## Environment Details

JaxNav was first introduced in ["No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery"](https://www.arxiv.org/abs/2408.15099), with a detailed specification given in the paper's appendix.

### Map Types
By default, square robots of width 0.5 m move within a world of grid-based obstacles, with cells of size 1 m x 1 m. The map cell size can be varied to produce obstacles of higher fidelity, and the robot shape can be changed to any polygon or a circle.

We also include a map type that uses polygon obstacles, but note that this code has not been used in a while, so there may well be issues with it.
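
The map and agent geometry are selected through the `JaxNav` constructor. A minimal construction sketch, assuming only the argument names visible in the `jaxnav_env.py` diff further down ("Grid-Rand-Poly" is the default `map_id`; other map IDs are not listed here):

```python
from jaxmarl.environments.jaxnav import JaxNav

# A minimal sketch, assuming the constructor defaults shown in jaxnav_env.py.
env = JaxNav(
    num_agents=4,             # number of robots
    map_id="Grid-Rand-Poly",  # default grid-based map type
    rad=0.3,                  # agent radius in metres
)
```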

### Observation Space
By default, each robot receives 200 range readings from a 360-degree arc centered on its forward axis. These readings have a maximum range of 6 m, no minimum range, and are discretised with a resolution of 0.05 m. Alongside the range readings, each robot receives its current linear and angular velocities and the direction to its goal. The goal direction is given as a vector in polar form, where the distance is the true distance if the goal is within lidar range, or the maximum lidar range if the goal is beyond this "line of sight". There is no communication between agents.
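
Under these defaults the per-agent observation is a flat vector of 204 values: 200 range readings plus two velocity terms and the two-element polar goal vector. A sketch of this arithmetic, assuming this ordering (the authoritative layout lives in `JaxNav.get_obs`):

```python
import jax.numpy as jnp

# Sketch of the default per-agent observation size; the ordering is an assumption.
lidar_num_beams = 200              # range readings over a 360-degree arc
obs_dim = lidar_num_beams + 2 + 2  # + (linear, angular) velocity + polar goal vector
assert obs_dim == 204

# The goal distance saturates at the lidar max range when out of "line of sight":
lidar_max_range = 6.0
goal_dist_obs = jnp.minimum(7.5, lidar_max_range)  # a goal 7.5 m away reads as 6.0
```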

### Action Space
The environment's default action space is a 2D continuous action, where the first dimension is the desired linear velocity and the second is the desired angular velocity. Discrete actions are also supported, with the possible combinations of linear and angular velocity discretised into 15 options.
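
A hedged sketch of what a 15-way discretisation can look like; the actual mapping is defined by `discrete_act_map` in `jaxnav_env.py`, and the 3 x 5 split and velocity limits below are assumptions:

```python
import jax.numpy as jnp

# Hypothetical 3 x 5 grid of (linear, angular) velocity pairs.
lin_vels = jnp.array([0.0, 0.5, 1.0])              # desired linear velocity (m/s)
ang_vels = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0])  # desired angular velocity (rad/s)
grid = jnp.stack(jnp.meshgrid(lin_vels, ang_vels), axis=-1).reshape(-1, 2)
assert grid.shape == (15, 2)  # 15 discrete (linear, angular) options
```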

### Reward Function
By default, the reward function combines a sparse, outcome-based reward with a dense shaping term.
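
The `jaxnav_env.py` diff further down also shows individual and shared rewards being mixed via a `rew_lambda` coefficient. A minimal sketch of this structure, with placeholder coefficients rather than the values used in the source:

```python
import jax.numpy as jnp

# Illustrative only: the 1.0 / -1.0 / 0.1 coefficients are placeholders.
def sketch_reward(reached_goal, collided, dist_to_goal, prev_dist_to_goal,
                  rew_lambda=0.5, shared_rew=0.0):
    sparse = jnp.where(reached_goal, 1.0, 0.0) + jnp.where(collided, -1.0, 0.0)
    shaping = 0.1 * (prev_dist_to_goal - dist_to_goal)  # dense progress term
    individual = sparse + shaping
    # mix individual and shared components, as in jaxnav_env.py
    return rew_lambda * individual + (1 - rew_lambda) * shared_rew
```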

## Visualisation
The visualiser is contained within `jaxnav_viz.py`; an example is below:

```python
from jaxmarl.environments.jaxnav.jaxnav_env import JaxNav
from jaxmarl.environments.jaxnav.jaxnav_viz import JaxNavVisualizer
import jax

env = JaxNav(num_agents=4)

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)

obs, env_state = env.reset(_rng)

obs_list = [obs]
env_state_list = [env_state]

for _ in range(10):
    rng, act_rng, step_rng = jax.random.split(rng, 3)
    act_rngs = jax.random.split(act_rng, env.num_agents)
    actions = {a: env.action_space(a).sample(act_rngs[i]) for i, a in enumerate(env.action_spaces.keys())}
    obs, env_state, _, _, _ = env.step(step_rng, env_state, actions)
    obs_list.append(obs)
    env_state_list.append(env_state)

viz = JaxNavVisualizer(env, obs_list, env_state_list)
viz.animate("test.gif")
```

## TODOs:
- remove `self.rad` dependence for non-circular agents
- more unit tests
- add tests for non-square agents

## Citation
JaxNav was introduced by the following paper. If you use JaxNav in your work, please cite it as:

```bibtex
@misc{rutherford2024noregrets,
    title={No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery},
    author={Alexander Rutherford and Michael Beukman and Timon Willi and Bruno Lacerda and Nick Hawes and Jakob Foerster},
    year={2024},
    eprint={2408.15099},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2408.15099},
}
```
3 changes: 2 additions & 1 deletion jaxmarl/environments/jaxnav/__init__.py
@@ -1,2 +1,3 @@
from .jaxnav_env import JaxNav
from .jaxnav_singletons import make_jaxnav_singleton, make_jaxnav_singleton_collection, JaxNavSingleton
+ from .jaxnav_viz import JaxNavVisualizer
84 changes: 33 additions & 51 deletions jaxmarl/environments/jaxnav/jaxnav_env.py
@@ -1,27 +1,20 @@
"""
- Rob sim that follows the JaxMARL interface
+ 2D robot navigation simulator that follows the JaxMARL interface
"""

import jax
import jax.numpy as jnp
- from jax import random, jit, vmap
- import numpy as np
from functools import partial
import chex
from flax import struct
from typing import Tuple, Dict
- #from gymnax.environments import spaces
- import os, pathlib
- import matplotlib.pyplot as plt
- import matplotlib.axes._axes as axes

from jaxmarl.environments import MultiAgentEnv
from jaxmarl.environments.spaces import Box, Discrete

from .maps import make_map, Map
- from .jaxnav_utils import pol2cart, wrap, unitvec, cart2pol
- import jaxmarl.environments.jaxnav.jaxnav_graph_utils as _graph_utils
+ from .jaxnav_utils import wrap, cart2pol

NUM_REWARD_COMPONENTS = 2
REWARD_COMPONENT_SPARSE = 0
@@ -98,15 +91,19 @@ def discrete_act_map(action: int) -> jnp.ndarray:

## ---- Environment ----
class JaxNav(MultiAgentEnv):
"""
Current assumptions:
- homogenous agents
"""

def __init__(self,
num_agents: int, # Number of agents
act_type="Continuous", # Action type, either Continuous or Discrete
normalise_obs=True,
- rad=0.3, # Agent radius
+ rad=0.3, # Agent radius, TODO remove dependency on this
evaporating=False, # Whether agents evaporate (disappear) when they reach the goal
map_id="Grid-Rand-Poly", # Map type
map_params=MAP_PARAMS, # Map parameters
lidar_num_beams=200,
lidar_range_resolution=0.05,
lidar_max_range=6.0,
Expand Down Expand Up @@ -153,12 +150,11 @@ def __init__(self,
self.lidar_num_beams = lidar_num_beams
self.lidar_max_range = lidar_max_range
self.lidar_min_range = lidar_min_range
- assert self.lidar_min_range == 0.0, "lidar_min_range must be 0.0 FOR NOW"
+ assert self.lidar_min_range == 0.0, "lidar_min_range must be 0.0" # TODO for now
self.lidar_range_resolution = lidar_range_resolution
self.lidar_angle_factor = lidar_angle_factor
self.lidar_max_angle = jnp.pi * self.lidar_angle_factor
self.lidar_angles = jnp.linspace(-jnp.pi * self.lidar_angle_factor, jnp.pi * self.lidar_angle_factor, self.lidar_num_beams)
- #self.lidar_ranges = jnp.arange(self.lidar_min_range, self.lidar_max_range, self.lidar_range_resolution)
num_lidar_samples = int((self.lidar_max_range - self.lidar_min_range) / self.lidar_range_resolution)
self.lidar_ranges = jnp.linspace(self.lidar_min_range, self.lidar_max_range, num_lidar_samples)

@@ -232,8 +228,8 @@ def step_env(
# 2) Check collisions, goal and time
old_goal_reached = agent_states.goal_reached
old_move_term = agent_states.move_term
- map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
- agent_collisions = self._check_agent_collisions(jnp.arange(agent_states.pos.shape[0]), new_pos, agent_states.done)*(1- agent_states.done).astype(bool)
+ map_collisions = jax.vmap(self._map_obj.check_agent_map_collision, in_axes=(0, 0, None))(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
+ agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
collisions = map_collisions | agent_collisions
goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool)
time_up = jnp.full((self.num_agents,), (step >= self.max_steps))
@@ -293,7 +289,6 @@ def step_env(
rew_batch = self.rew_lambda * rew_individual + (1 - self.rew_lambda) * shared_rew

rew = {a: rew_batch[i] for i, a in enumerate(self.agents)}

obs = {a: obs_batch[i] for i, a in enumerate(self.agents)}

if self.evaporating:
@@ -325,10 +320,8 @@ def step_env(
"AgentC": agent_c,
"MapC": map_c,
"TimeO": time_o,
- # reward
- "Return": rew_info,
- # whether action was valid
- "terminated": term,
+ "Return": rew_info, # reward
+ "terminated": term, # whether action was valid
}
if self.do_sep_reward:
raise NotImplementedError("Separate reward not implemented")
@@ -419,22 +412,14 @@ def sample_lambda(self, key):
rew_lambda = jax.random.uniform(key, (1,), minval=self.lambda_range[0], maxval=self.lambda_range[1])
return rew_lambda

- @partial(vmap, in_axes=(None, 0, 0, None))
+ @partial(jax.vmap, in_axes=(None, 0, 0, None))
def _check_map_collisions(self, pos: chex.Array, theta: chex.Array, map_data: chex.Array) -> bool:
return self._map_obj.check_agent_map_collision(pos, theta, map_data)

- @partial(vmap, in_axes=(None, 0, 0))
+ @partial(jax.vmap, in_axes=(None, 0, 0))
def _check_goal_reached(self, pos: chex.Array, goal_pos: chex.Array) -> bool:
return jnp.sqrt(jnp.sum((pos - goal_pos)**2)) <= self.goal_radius

- @partial(vmap, in_axes=(None, 0, None, None))
- def _check_agent_collisions(self, agent_idx: int, agent_positions: chex.Array, dones: chex.Array) -> bool:
- # TODO this function is a little clunky FIX
- z = jnp.zeros(agent_positions.shape)
- z = z.at[agent_idx,:].set(jnp.ones(2)*self.rad*2.1)
- x = agent_positions + z
- return jnp.any(jnp.sqrt(jnp.sum((x - agent_positions[agent_idx,:])**2, axis=1)) <= self.rad*2)


@partial(jax.jit, static_argnums=[0])
def get_obs(self, state: State) -> chex.Array:
obs_batch = self._get_obs(state)
@@ -485,7 +470,7 @@ def get_world_state(self, state: State) -> chex.Array:

return jnp.concatenate([agent_idx, concat, obs], axis=1)

- @partial(vmap, in_axes=(None, 0, 0, 0, 0, 0))
+ @partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, 0))
def update_state(self, pos: chex.Array, theta: float, speed: chex.Array, action: chex.Array, done: chex.Array) -> chex.Array:
""" Update agent's state, if `done` the current position and velocity are returned"""
if self.evaporating:
@@ -672,29 +657,26 @@ def get_env_metrics(self, state: State) -> dict:
# n_walls = state.map_data.sum() - state.map_data.shape[0]*2 - state.map_data.shape[1]*2 + 4
inside = state.map_data.astype(jnp.bool_)[1:-1, 1:-1]
n_walls = jnp.sum(inside)
- passability = jax.vmap(
- self.map_obj.passable_check,
- in_axes=(0, 0, None)
+ passable, path_len = jax.vmap(
+ self.map_obj.dikstra_path,
+ in_axes=(None, 0, 0)
)(
+ state.map_data,
state.pos,
state.goal,
- state.map_data,
)

- # shortest_path_lengths = jax.vmap( # BUG in the minimax code somewhere
- # _graph_util.shortest_path_len,
- # in_axes=(None, 0, 0),
- # )(
- # inside.astype(jnp.bool_),
- # jnp.floor(state.pos-1).astype(jnp.int32),
- # jnp.floor(state.goal-1).astype(jnp.int32),
- # )


+ shortest_path_lengths_stderr = jax.lax.select(
+ jnp.sum(passable) > 0,
+ jnp.std(path_len, where=passable)/jnp.sqrt(jnp.sum(passable)),
+ 0.0
+ )
return dict(
n_walls=n_walls,
- # shortest_path_length_mean=jnp.mean(shortest_path_lengths),
- # shortest_path_lengths_stderr=jnp.std(shortest_path_lengths)/jnp.sqrt(self.num_agents),
- passable=jnp.mean(passability),
+ shortest_path_length_mean=jnp.mean(path_len, where=passable),
+ shortest_path_lengths_stderr=shortest_path_lengths_stderr,
+ passable=jnp.mean(passable),
)

### === VISUALISATION === ###
3 changes: 2 additions & 1 deletion jaxmarl/environments/jaxnav/jaxnav_graph_utils.py
@@ -1,4 +1,5 @@
"""
""" Lifted from MiniMax
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
Expand Down
10 changes: 1 addition & 9 deletions jaxmarl/environments/jaxnav/jaxnav_singletons.py
@@ -1,4 +1,4 @@
- import jax
+ """ Hand crafted test cases for JaxNav """
import jax.numpy as jnp
import chex
from typing import NamedTuple, List, Tuple
@@ -708,11 +708,6 @@ def make_jaxnav_singleton_collection(collection_id: str, **env_kwargs) -> Tuple[
"NarrowChicane2b",
"Chicane4",
],
"new": [
"NarrowChicane2a",
"NarrowChicane2b",
"Chicane4",
],
"corridor": [
"BlankCross4",
"LongCorridor2",
@@ -722,8 +717,5 @@ def make_jaxnav_singleton_collection(collection_id: str, **env_kwargs) -> Tuple[
"just-long-corridor": [
"LongCorridor2",
],
"just-single2": [
"SingleNav2",
],
}
