Skip to content

Commit

Permalink
Update smax_env.py
Browse files Browse the repository at this point in the history
Start positions are currently hard-coded to assume that the map_width and height is 32. If the map width is set to 128, the units will this start in the top-left corner.

Make team_0 and team_1 start y coordinates halfway up the map_height (instead of hard-coded 16). The x coordinate will be a fourth of the map_width, and three fourth of the map_width respectively.
  • Loading branch information
syrkis authored Jun 12, 2024
1 parent e37547c commit b302198
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jaxmarl/environments/smax/smax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,12 @@ def _get_obs_size(self):
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
"""Environment-specific reset."""
key, team_0_key, team_1_key = jax.random.split(key, num=3)
team_0_start = jnp.stack([jnp.array([8.0, 16.0])] * self.num_allies)
team_0_start = jnp.stack([jnp.array([self.map_width / 4, self.map_height / 2])] * self.num_allies)
team_0_start_noise = jax.random.uniform(
team_0_key, shape=(self.num_allies, 2), minval=-2, maxval=2
)
team_0_start = team_0_start + team_0_start_noise
team_1_start = jnp.stack([jnp.array([24.0, 16.0])] * self.num_enemies)
team_1_start = jnp.stack([jnp.array([self.map_width / 4 * 3, self.map_height / 2])] * self.num_enemies)
team_1_start_noise = jax.random.uniform(
team_1_key, shape=(self.num_enemies, 2), minval=-2, maxval=2
)
Expand Down

0 comments on commit b302198

Please sign in to comment.