Modify circle_ppo to support aversive learning
kngwyu committed Nov 22, 2024
1 parent 69731dc commit d22c28d
Showing 3 changed files with 88 additions and 8 deletions.
37 changes: 37 additions & 0 deletions config/env/20241122-small-poison.toml
@@ -0,0 +1,37 @@
n_initial_agents = 20
n_max_agents = 40
n_max_foods = 40
food_num_fn = [
["linear", 10, 0.1, 20],
["linear", 10, 0.1, 20],
]
food_loc_fn = ["uniform", "uniform"]
agent_loc_fn = "uniform"
food_color = [[254, 2, 162, 255], [2, 254, 162, 255]]
food_energy_coef = [1.0, -0.5]
n_food_sources = 2
observe_food_label = true
xlim = [0.0, 360.0]
ylim = [0.0, 240.0]
env_radius = 120.0
env_shape = "square"
neighbor_stddev = 40.0
n_agent_sensors = 16
sensor_length = 100.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
foodloc_interval = 1000
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 40.0
min_force = -20.0
init_energy = 40.0
energy_capacity = 100.0
force_energy_consumption = 0.00025
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
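
The distinguishing entries in this config are n_food_sources = 2 and the negative second entry of food_energy_coef, which presumably makes eating the second food source cost energy instead of providing it — the aversive "poison" food of the commit title. A minimal sketch (not part of the commit) that reads the file with Python's standard-library tomllib and checks those fields; how the emevo environment itself consumes the config is not shown in this diff.

# Sketch only (assumption: run from the repository root; tomllib is the
# Python 3.11+ standard-library TOML reader, unrelated to this repo's own loaders).
import tomllib

with open("config/env/20241122-small-poison.toml", "rb") as f:
    cfg = tomllib.load(f)

# Two labeled food sources; the second has a negative energy coefficient,
# so it acts as poison, and agents can observe which label they sense.
assert cfg["n_food_sources"] == 2
assert cfg["food_energy_coef"] == [1.0, -0.5]
assert cfg["observe_food_label"] is True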
30 changes: 30 additions & 0 deletions config/env/20241122-small-uniform.toml
@@ -0,0 +1,30 @@
n_initial_agents = 20
n_max_agents = 100
n_max_foods = 40
food_num_fn = ["logistic", 20, 0.01, 40]
food_loc_fn = "gaussian"
agent_loc_fn = "uniform"
xlim = [0.0, 360.0]
ylim = [0.0, 240.0]
env_radius = 120.0
env_shape = "square"
neighbor_stddev = 40.0
n_agent_sensors = 16
sensor_length = 100.0
sensor_range = "wide"
agent_radius = 10.0
food_radius = 4.0
foodloc_interval = 1000
dt = 0.1
linear_damping = 0.8
angular_damping = 0.6
max_force = 40.0
min_force = -20.0
init_energy = 40.0
energy_capacity = 100.0
force_energy_consumption = 0.00025
energy_share_ratio = 0.4
n_velocity_iter = 6
n_position_iter = 2
n_physics_iter = 5
max_place_attempts = 10
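
For comparison, 20241122-small-uniform.toml keeps a single food source: one logistic food_num_fn instead of a per-source list, Gaussian food placement, and none of the poison-specific keys (n_food_sources, food_energy_coef, observe_food_label, food_color). Continuing the sketch above under the same assumptions:

import tomllib

with open("config/env/20241122-small-uniform.toml", "rb") as f:
    uniform = tomllib.load(f)

# Single-source control: one growth rule, no per-source coefficients or labels.
assert uniform["food_num_fn"] == ["logistic", 20, 0.01, 40]
assert "food_energy_coef" not in uniform
assert "n_food_sources" not in uniform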
29 changes: 21 additions & 8 deletions smoke-tests/circle_ppo.py
@@ -2,7 +2,6 @@

import dataclasses
from pathlib import Path
from typing import Optional

import chex
import equinox as eqx
@@ -80,10 +79,13 @@ def exec_rollout(
prng_key: jax.Array,
n_rollout_steps: int,
action_reward_coef: float,
poison_reward_coef: float,
) -> tuple[State, Rollout, Obs, jax.Array]:
max_norm = jnp.sqrt(jnp.sum(env.act_space.high**2, axis=-1))
food_coef = jnp.array([[1.0, poison_reward_coef]])

def normalize_action(action: jax.Array) -> jax.Array:
scaled = env.act_space.sigmoid_scale(action)
max_norm = jnp.sqrt(jnp.sum(env.act_space.high**2, axis=-1, keepdims=True))
norm = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True))
return norm / max_norm

@@ -96,8 +98,12 @@ def step_rollout(
net_out = vmap_apply(network, obs_t_array)
actions = net_out.policy().sample(seed=key)
state_t1, timestep = env.step(state_t, env.act_space.sigmoid_scale(actions))
food_rewards = obs_t.collision[:, 1].astype(jnp.float32).reshape(-1, 1)
rewards = food_rewards - action_reward_coef * normalize_action(actions)
n_ate = timestep.info["n_ate_food"]
act_rewards = action_reward_coef * normalize_action(actions)
if n_ate.shape[1] == 2:
rewards = jnp.sum(n_ate * food_coef, axis=1, keepdims=True) - act_rewards
else:
rewards = n_ate - act_rewards
rollout = Rollout(
observations=obs_t_array,
actions=actions,
@@ -134,6 +140,7 @@ def training_step(
n_optim_epochs: int,
reset: jax.Array,
action_reward_coef: float,
poison_reward_coef: float,
entropy_weight: float,
) -> tuple[State, Obs, jax.Array, optax.OptState, NormalPPONet]:
keys = jax.random.split(prng_key, N_MAX_AGENTS + 1)
@@ -145,6 +152,7 @@
keys[0],
n_rollout_steps,
action_reward_coef,
poison_reward_coef,
)
rollout = rollout.replace(terminations=rollout.terminations.at[-1].set(reset))
batch = vmap_batch(rollout, next_value, gamma, gae_lambda)
@@ -174,15 +182,16 @@ def run_training(
n_rollout_steps: int,
n_total_steps: int,
action_reward_coef: float,
poison_reward_coef: float,
entropy_weight: float,
figsize: tuple[float, float],
reset_interval: int | None = None,
debug_vis: bool = False,
) -> tuple[NormalPPONet, jax.Array]:
key, net_key, reset_key = jax.random.split(key, 3)
obs_space = env.obs_space.flatten()
input_size = np.prod(obs_space.shape)
act_size = np.prod(env.act_space.shape)
input_size = int(np.prod(obs_space.shape))
act_size = int(np.prod(env.act_space.shape))
pponet = vmap_net(
input_size,
64,
@@ -218,14 +227,16 @@
n_optim_epochs,
jnp.array(reset),
action_reward_coef,
poison_reward_coef,
entropy_weight,
)
ri = jnp.sum(jnp.squeeze(rewards_i, axis=-1), axis=0)
ri = np.array(jnp.sum(rewards_i, axis=0))
rewards = rewards + ri
if visualizer is not None:
visualizer.render(env_state.physics) # type: ignore
visualizer.show()
print(f"Rewards: {[x.item() for x in ri[: n_agents]]}")
for i in range(ri.shape[1]):
print(f"Rewards for {i + 1}: {ri[:n_agents, i]}")
if reset:
env_state, timestep = env.reset(key)
obs = timestep.obs
@@ -251,6 +262,7 @@ def train(
n_rollout_steps: int = 1024,
n_total_steps: int = 1024 * 1000,
action_reward_coef: float = 1e-3,
poison_reward_coef: float = -1.0,
entropy_weight: float = 1e-4,
cfconfig_path: Path = PROJECT_ROOT / "config/env/20231214-square.toml",
env_override: str = "",
@@ -282,6 +294,7 @@
n_rollout_steps=n_rollout_steps,
n_total_steps=n_total_steps,
action_reward_coef=action_reward_coef,
poison_reward_coef=poison_reward_coef,
entropy_weight=entropy_weight,
figsize=(xsize, ysize),
reset_interval=reset_interval,
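Summing up the circle_ppo.py change: the per-step reward no longer reads a single collision flag but weights timestep.info["n_ate_food"] with the coefficient vector [1.0, poison_reward_coef] (default -1.0), so eating the second, poisonous food source is penalized while the action-norm penalty stays the same; environments reporting only one food source fall back to using n_ate directly. A stand-alone sketch of that arithmetic (the function name, array shapes, and example values here are illustrative, not part of the commit):

# Stand-alone sketch of the new reward rule, using only the arithmetic visible
# in the diff; shapes (n_agents, 2) for n_ate and (n_agents, act_dim) for the
# scaled actions are assumed for illustration.
import jax.numpy as jnp

def poison_rewards(n_ate, scaled_actions, act_high, action_reward_coef=1e-3, poison_reward_coef=-1.0):
    food_coef = jnp.array([[1.0, poison_reward_coef]])   # per-source weight: normal food +1, poison -1
    max_norm = jnp.sqrt(jnp.sum(act_high**2, axis=-1))   # largest achievable action norm
    act_norm = jnp.sqrt(jnp.sum(scaled_actions**2, axis=-1, keepdims=True)) / max_norm
    food_rewards = jnp.sum(n_ate * food_coef, axis=1, keepdims=True)
    return food_rewards - action_reward_coef * act_norm

# Agent 0 ate one normal food; agent 1 ate the poison twice and did not move.
n_ate = jnp.array([[1.0, 0.0], [0.0, 2.0]])
acts = jnp.array([[0.3, 0.1], [0.0, 0.0]])
print(poison_rewards(n_ate, acts, jnp.array([1.0, 1.0])))  # approx [[1.0], [-2.0]]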
