diff --git a/smoke-tests/circle_ppo.py b/smoke-tests/circle_ppo.py index ae579014..f9a72c87 100644 --- a/smoke-tests/circle_ppo.py +++ b/smoke-tests/circle_ppo.py @@ -230,6 +230,8 @@ def train( minibatch_size: int = 128, n_rollout_steps: int = 1024, n_total_steps: int = 1024 * 1000, + n_sensors: int = 16, + sensor_length: float = 100.0, food_loc_fn: str = "gaussian", env_shape: str = "circle", reset_interval: Optional[int] = None, @@ -259,6 +261,8 @@ def train( angular_damping=angular_damping, max_force=max_force, min_force=min_force, + n_agent_sensors=n_sensors, + sensor_length=sensor_length, ) network = run_training( jax.random.PRNGKey(seed),