Skip to content

Commit abd1b1d

Browse files
viz example
1 parent 61d4e29 commit abd1b1d

File tree

3 files changed

+40
-17
lines changed

3 files changed

+40
-17
lines changed

jaxmarl/environments/jaxnav/README.md

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# 🧭 JaxNav
22

3-
2D geometric navigation for differential drive robots. Using distance readings to nearby obstacles (mimicking LiDAR readings), the direction to their goal and their current velocity, robots must navigate to their goal without colliding with obstacles.
3+
2D geometric navigation for differential drive robots. Using distance readings to nearby obstacles (mimicking LiDAR readings), the direction to their goal and their current velocity, robots must navigate to their goal without colliding with obstacles.
44

55
## Environment Details
66

7+
JaxNav was first introduced in "No Regrets: Investigating and Improving Regret Approximations for Curriculum Discovery", with a detailed specification given in the Appendix.
8+
79
### Map Types
810
The default map consists of square robots of width 0.5m moving within a world with grid-based obstacles, with cells of size 1m x 1m. The map cell size can be varied to produce obstacles of higher fidelity, or the robot structure can be changed into any polygon or a circle.
911

@@ -19,13 +21,41 @@ The environments default action space is a 2D continuous action, where the first
1921
By default, the reward function contains a sparse outcome based reward alongside a dense shaping term.
2022

2123
## Visualisation
24+
The visualiser is contained within `jaxnav_viz.py`; an example of its use is given below:
25+
26+
```python
27+
from jaxmarl.environments.jaxnav.jaxnav_env import JaxNav
28+
from jaxmarl.environments.jaxnav.jaxnav_viz import JaxNavVisualizer
29+
import jax
30+
31+
env = JaxNav(num_agents=4)
32+
33+
rng = jax.random.PRNGKey(0)
34+
rng, _rng = jax.random.split(rng)
35+
36+
obs, env_state = env.reset(_rng)
37+
38+
obs_list = [obs]
39+
env_state_list = [env_state]
40+
41+
for _ in range(10):
42+
rng, act_rng, step_rng = jax.random.split(rng, 3)
43+
act_rngs = jax.random.split(act_rng, env.num_agents)
44+
actions = {a: env.action_space(a).sample(act_rngs[i]) for i, a in enumerate(env.action_spaces.keys())}
45+
obs, env_state, _, _, _ = env.step(step_rng, env_state, actions)
46+
obs_list.append(obs)
47+
env_state_list.append(env_state)
48+
49+
viz = JaxNavVisualizer(env, obs_list, env_state_list)
50+
viz.animate("test.gif")
51+
```
2252

2353
## TODOs:
24-
- remove self.rad dependence
54+
- remove `self.rad` dependence for non circular agents
2555

2656
## Citation
2757
JaxNav was introduced by the following paper; if you use JaxNav in your work, please cite it as:
2858

29-
'''bibtex
59+
```bibtex
3060
TODO
31-
'''
61+
```

jaxmarl/environments/jaxnav/jaxnav_env.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -426,15 +426,7 @@ def _check_map_collisions(self, pos: chex.Array, theta: chex.Array, map_data: ch
426426
@partial(vmap, in_axes=(None, 0, 0))
def _check_goal_reached(self, pos: chex.Array, goal_pos: chex.Array) -> bool:
    """Return True when an agent is within `self.goal_radius` of its goal.

    Vmapped over the agent axis of `pos` and `goal_pos`; the Euclidean
    distance between the two positions is compared against the goal radius.
    """
    return jnp.linalg.norm(pos - goal_pos) <= self.goal_radius
429-
430-
@partial(vmap, in_axes=(None, 0, None, None))
431-
def _check_agent_collisions(self, agent_idx: int, agent_positions: chex.Array, dones: chex.Array) -> bool:
432-
# TODO this function is a little clunky FIX
433-
z = jnp.zeros(agent_positions.shape)
434-
z = z.at[agent_idx,:].set(jnp.ones(2)*self.rad*2.1)
435-
x = agent_positions + z
436-
return jnp.any(jnp.sqrt(jnp.sum((x - agent_positions[agent_idx,:])**2, axis=1)) <= self.rad*2)
437-
429+
438430
@partial(jax.jit, static_argnums=[0])
439431
def get_obs(self, state: State) -> chex.Array:
440432
obs_batch = self._get_obs(state)

jaxmarl/environments/jaxnav/jaxnav_viz.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import matplotlib.animation as animation
55
from typing import Optional, List
66

7-
from .jaxnav_env import JaxNav
7+
from .jaxnav_env import JaxNav, State
88
import jax.numpy as jnp
99

1010
class JaxNavVisualizer(object):
1111
def __init__(self,
1212
env: JaxNav,
1313
obs_seq: List,
14-
state_seq: List,
14+
state_seq: List[State],
1515
reward_seq: List=None,
1616
done_frames=None,
1717
title_text: str=None,
@@ -77,8 +77,9 @@ def update(self, frame):
7777
if self.plot_path:
7878
for a in range(self.env.num_agents):
7979
plot_frame = frame
80-
if self.done_frames[a] < frame:
81-
plot_frame = self.done_frames[a]
80+
if self.done_frames is not None:
81+
if (self.done_frames[a] < frame):
82+
plot_frame = self.done_frames[a]
8283
self.env.map_obj.plot_agent_path(self.ax, self.path_seq[:plot_frame, a, 0], self.path_seq[:plot_frame, a, 1])
8384
# self.ax.plot(self.path_seq[:frame, 0], self.path_seq[:frame, 1], color='b', linewidth=2.0, zorder=1)
8485
self.env.init_render(self.ax, self.state_seq[frame], self.obs_seq[frame], lidar=self.plot_lidar, agent=self.plot_agent)

0 commit comments

Comments
 (0)