Skip to content

Commit

Permalink
(optionaly) Randomly initialize angle at initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Jun 6, 2024
1 parent f2e0a30 commit 92f59e4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/emevo/environments/circle_foraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def __init__(
n_physics_iter: int = 5,
max_place_attempts: int = 10,
n_max_food_regen: int = 20,
random_angle: bool = True, # False when debugging/testing
# Only for CircleForagingWithSmell, but placed here to keep config class simple
smell_decay_factor: float = 0.01,
smell_diff_max: float = 1.0,
Expand Down Expand Up @@ -525,6 +526,7 @@ def __init__(
self._n_max_foods = n_max_foods
self._max_place_attempts = max_place_attempts
self._n_max_food_regen = n_max_food_regen
self._random_angle = random_angle
# Physics
if isinstance(obstacles, str):
obs_list = Obstacle(obstacles).as_list(self._x_range, self._y_range)
Expand Down Expand Up @@ -918,8 +920,11 @@ def activate(
# To use .at[].add, append (0, 0) to sampled xy
new_xy_with_sentinel = jnp.concatenate((new_xy, jnp.zeros((1, 2))))
xy = circle.p.xy.at[replaced_indices].add(new_xy_with_sentinel[parent_indices])
new_angle = jax.random.uniform(keys[1]) * jnp.pi * 2.0
angle = jnp.where(is_replaced, new_angle, circle.p.angle)
if self._random_angle:
new_angle = jax.random.uniform(keys[1]) * jnp.pi * 2.0
angle = jnp.where(is_replaced, new_angle, circle.p.angle)
else:
angle = jnp.where(is_replaced, 0.0, circle.p.angle)
p = Position(angle=angle, xy=xy)
is_active = jnp.logical_or(is_replaced, circle.is_active)
physics = replace(
Expand Down Expand Up @@ -1023,7 +1028,7 @@ def _initialize_physics_state(
"static_circle.p.xy",
jnp.ones_like(stated.static_circle.p.xy) * NOWHERE,
)
keys = jax.random.split(key, self._n_initial_agents + self._n_food_sources)
keys = jax.random.split(key, self._n_initial_agents + self._n_food_sources + 1)
agent_failed = 0
agentloc_state = self._initial_agentloc_state
for i, key in enumerate(keys[: self._n_initial_agents]):
Expand All @@ -1046,10 +1051,14 @@ def _initialize_physics_state(
if agent_failed > 0:
warnings.warn(f"Failed to place {agent_failed} agents!", stacklevel=1)

if self._random_angle:
angle = jax.random.uniform(key, shape=stated.circle.p.angle.shape)
stated = stated.nested_replace("circle.p.angle", angle)

food_failed = 0
foodloc_states = [s for s in self._initial_foodloc_states]
foodnum_states = [s for s in self._initial_foodnum_states]
for i, key in enumerate(keys[self._n_initial_agents :]):
for i, key in enumerate(keys[self._n_initial_agents + 1 :]):
n_initial = self._food_num_fns[i].initial
xy, ok = self._place_food_fns[i](
loc_state=foodloc_states[i],
Expand Down
1 change: 1 addition & 0 deletions tests/test_observe.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def reset_env(key: chex.PRNGKey) -> tuple[CircleForaging, CFState, TimeStep[CFOb
foodloc_interval=20,
agent_radius=AGENT_RADIUS,
food_radius=FOOD_RADIUS,
random_angle=False,
)
state, timestep = env.reset(key)
return typing.cast(CircleForaging, env), state, timestep
Expand Down

0 comments on commit 92f59e4

Please sign in to comment.