Merge pull request #52 from FLAIROx/qlearning
Browse files Browse the repository at this point in the history
better use of flashbax - fix SMAX typo in qlearning scripts - use log wrappers to log info
amacrutherford authored Dec 8, 2023
2 parents eb83d25 + 3ddb2f9 commit 5b9bd55
Showing 19 changed files with 340 additions and 176 deletions.
7 changes: 5 additions & 2 deletions baselines/QLearning/README.md
@@ -23,6 +23,7 @@ pip install -r requirements/requirements-qlearning.txt
❗The implementations were tested in the following environments:
- MPE
- SMAX
- Hanabi
```

## 🔎 Implementation Details
@@ -57,9 +58,11 @@ If you have cloned JaxMARL and you are in the repository root, you can run the a
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
# VDN with MPE spread
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMIX with SMAX
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
# QMIX against pretrained agents
# QMix with Hanabi
python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
# QMix against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
```

5 changes: 2 additions & 3 deletions baselines/QLearning/config/alg/iql_mpe.yaml
@@ -1,5 +1,5 @@
"ENV_NAME": "MPE_simple_spread_v3"
"NUM_ENVS": 8
"NUM_STEPS": 25
"BUFFER_SIZE": 5000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2050000
@@ -20,5 +20,4 @@
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
"ALG_NAME": iql
"TEST_INTERVAL": 50000
2 changes: 1 addition & 1 deletion baselines/QLearning/config/alg/iql_smax.yaml
@@ -1,4 +1,5 @@
"NUM_ENVS": 8
"NUM_STEPS": 128
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 20000000
@@ -20,4 +21,3 @@
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 100000
"ALG_NAME": iql
27 changes: 27 additions & 0 deletions baselines/QLearning/config/alg/qlearn_hanabi.yaml
@@ -0,0 +1,27 @@
"NUM_ENVS": 8
"NUM_STEPS": 21
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2e8
"AGENT_HIDDEN_DIM": 64
"AGENT_INIT_SCALE": 2.
"PARAMETERS_SHARING": True
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.05
"EPSILON_ANNEAL_TIME": 100000
"MIXER_EMBEDDING_DIM": 32
"MIXER_HYPERNET_HIDDEN_DIM": 64
"MIXER_INIT_SCALE": 0.00001
"MAX_GRAD_NORM": 10
"TARGET_UPDATE_INTERVAL": 200
"LR": 0.005
"LR_LINEAR_DECAY": True
"EPS_ADAM": 0.001
"WEIGHT_DECAY_ADAM": 0.00001
"TD_LAMBDA_LOSS": True
"TD_LAMBDA": 0.6
"GAMMA": 0.99
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
27 changes: 27 additions & 0 deletions baselines/QLearning/config/alg/qlearn_overcooked.yaml
@@ -0,0 +1,27 @@
"NUM_ENVS": 32
"NUM_STEPS": 256
"BUFFER_SIZE": 1024
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 10e7
"AGENT_HIDDEN_DIM": 64
"AGENT_INIT_SCALE": 2.
"PARAMETERS_SHARING": True
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.1
"EPSILON_ANNEAL_TIME": 1e6
"MIXER_EMBEDDING_DIM": 32
"MIXER_HYPERNET_HIDDEN_DIM": 64
"MIXER_INIT_SCALE": 0.0001
"MAX_GRAD_NORM": 10
"TARGET_UPDATE_INTERVAL": 200
"LR": 0.0001
"LR_LINEAR_DECAY": False
"EPS_ADAM": 0.00001
"WEIGHT_DECAY_ADAM": 0.00001
"TD_LAMBDA_LOSS": False
"TD_LAMBDA": 0.6
"GAMMA": 0.99
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
4 changes: 2 additions & 2 deletions baselines/QLearning/config/alg/qmix_mpe.yaml
@@ -1,4 +1,5 @@
"NUM_ENVS": 8
"NUM_STEPS": 25
"BUFFER_SIZE": 5000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2050000
@@ -23,5 +24,4 @@
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
"ALG_NAME": qmix
"TEST_INTERVAL": 50000
4 changes: 2 additions & 2 deletions baselines/QLearning/config/alg/qmix_smax.yaml
@@ -1,4 +1,5 @@
"NUM_ENVS": 8
"NUM_STEPS": 128
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 20000000
@@ -23,5 +24,4 @@
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 100000
"ALG_NAME": qmix
"TEST_INTERVAL": 100000
4 changes: 2 additions & 2 deletions baselines/QLearning/config/alg/vdn_mpe.yaml
@@ -1,4 +1,5 @@
"NUM_ENVS": 8
"NUM_STEPS": 25
"BUFFER_SIZE": 5000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2050000
@@ -19,5 +20,4 @@
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
"ALG_NAME": vdn
"TEST_INTERVAL": 50000
4 changes: 2 additions & 2 deletions baselines/QLearning/config/alg/vdn_smax.yaml
@@ -1,4 +1,5 @@
"NUM_ENVS": 8
"NUM_STEPS": 128
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 20000000
@@ -19,5 +20,4 @@
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 100000
"ALG_NAME": vdn
"TEST_INTERVAL": 100000
2 changes: 1 addition & 1 deletion baselines/QLearning/config/config.yaml
@@ -1,5 +1,5 @@
# experiment params
"NUM_SEEDS": 10
"NUM_SEEDS": 2
"SEED": 30

# wandb params
2 changes: 2 additions & 0 deletions baselines/QLearning/config/env/hanabi.yaml
@@ -0,0 +1,2 @@
"ENV_NAME": "hanabi"
"ENV_KWARGS": {}
3 changes: 3 additions & 0 deletions baselines/QLearning/config/env/overcooked.yaml
@@ -0,0 +1,3 @@
"ENV_NAME": "overcooked"
"ENV_KWARGS":
"layout" : "cramped_room"
2 changes: 1 addition & 1 deletion baselines/QLearning/config/env/smax.yaml
@@ -1,5 +1,5 @@
"ENV_NAME": "HeuristicEnemySMAX"
"MAP_NAME": "smacv2_5_units"
"MAP_NAME": "5m_vs_6m"
"ENV_KWARGS":
"see_enemy_actions": True
"walls_cause_death": True
85 changes: 55 additions & 30 deletions baselines/QLearning/iql.py
@@ -37,8 +37,9 @@
from flax.traverse_util import flatten_dict

from jaxmarl import make
from jaxmarl.wrappers.baselines import CTRolloutManager
from jaxmarl.wrappers.baselines import LogWrapper, SMAXLogWrapper, CTRolloutManager
from jaxmarl.environments.smax import map_name_to_scenario
from jaxmarl.environments.overcooked import overcooked_layouts


class ScannedRNN(nn.Module):
@@ -107,6 +108,7 @@ class Transition(NamedTuple):
actions: dict
rewards: dict
dones: dict
infos: dict


class AgentRNN(nn.Module):
@@ -151,20 +153,21 @@ def _env_sample_step(env_state, unused):
key_a = jax.random.split(key_a, env.num_agents)
actions = {agent: wrapped_env.batch_sample(key_a[i], agent) for i, agent in enumerate(env.agents)}
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
transition = Transition(obs, actions, rewards, dones)
transition = Transition(obs, actions, rewards, dones, infos)
return env_state, transition
_, sample_traj = jax.lax.scan(
_env_sample_step, env_state, None, config["NUM_STEPS"]
)
sample_traj_unbatched = jax.tree_map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim
buffer = fbx.make_flat_buffer(
max_length=config['BUFFER_SIZE'],
min_length=config['BUFFER_BATCH_SIZE'],
buffer = fbx.make_trajectory_buffer(
max_length_time_axis=config['BUFFER_SIZE']//config['NUM_ENVS'],
min_length_time_axis=config['BUFFER_BATCH_SIZE'],
sample_batch_size=config['BUFFER_BATCH_SIZE'],
add_sequences=True,
add_batch_size=None,
add_batch_size=config['NUM_ENVS'],
sample_sequence_length=1,
period=1,
)
buffer_state = buffer.init(sample_traj_unbatched)
buffer_state = buffer.init(sample_traj_unbatched)

# INIT NETWORK
agent = AgentRNN(action_dim=wrapped_env.max_action_space, hidden_dim=config["AGENT_HIDDEN_DIM"], init_scale=config['AGENT_INIT_SCALE'])
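A quick sanity check on the trajectory-buffer sizing above: with add_batch_size=NUM_ENVS and max_length_time_axis=BUFFER_SIZE // NUM_ENVS, the total capacity stays at roughly BUFFER_SIZE timesteps, matching the old flat buffer. A sketch with illustrative MPE-style values (not taken from a specific config):

```python
BUFFER_SIZE, NUM_ENVS = 5000, 8

max_length_time_axis = BUFFER_SIZE // NUM_ENVS        # 625 rows along the time axis
capacity_timesteps = max_length_time_axis * NUM_ENVS  # 5000 stored timesteps in total
print(max_length_time_axis, capacity_timesteps)
```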
@@ -264,7 +267,7 @@ def _env_step(step_state, unused):

# STEP ENV
obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions)
transition = Transition(last_obs, actions, rewards, dones)
transition = Transition(last_obs, actions, rewards, dones, infos)

step_state = (params, env_state, obs, dones, hstate, rng, t+1)
return step_state, transition
@@ -292,7 +295,10 @@ def _env_step(step_state, unused):
)

# BUFFER UPDATE: save the collected trajectory in the buffer
buffer_traj_batch = jax.tree_util.tree_map(lambda x:jnp.swapaxes(x, 0, 1), traj_batch) # put the batch size (num envs) in first axis
buffer_traj_batch = jax.tree_util.tree_map(
lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
traj_batch
) # (num_envs, 1, time_steps, ...)
buffer_state = buffer.add(buffer_state, buffer_traj_batch)

# LEARN PHASE
Expand Down Expand Up @@ -368,8 +374,11 @@ def _loss_fn(params, target_agent_params, init_hs, learn_traj):

# sample a batched trajectory from the buffer and set the time step dim in first axis
rng, _rng = jax.random.split(rng)
learn_traj = buffer.sample(buffer_state, _rng).experience.first # (batch_size, max_time_steps, ...)
learn_traj = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), learn_traj) # (max_time_steps, batch_size, ...)
learn_traj = buffer.sample(buffer_state, _rng).experience # (batch_size, 1, max_time_steps, ...)
learn_traj = jax.tree_map(
lambda x: jnp.swapaxes(x[:, 0], 0, 1), # remove the dummy sequence dim (1) and swap batch and temporal dims
learn_traj
) # (max_time_steps, batch_size, ...)

# for iql the loss must be computed differently with or without parameters sharing
if config.get('PARAMETERS_SHARING', True):
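The two reshapes above are the core of the flashbax switch: collected trajectories go into the buffer batch-first with a dummy sequence axis, and sampled experience comes back out time-first for the recurrent loss. A standalone sketch of that round trip with illustrative shapes (not tied to a specific config):

```python
import jax.numpy as jnp
import numpy as np

NUM_STEPS, NUM_ENVS, BATCH, FEAT = 25, 8, 32, 5

# rollout produced by the scan over time across vectorised envs: (time, num_envs, ...)
traj = jnp.zeros((NUM_STEPS, NUM_ENVS, FEAT))

# before buffer.add: batch dim first plus a dummy sequence dim of length 1
buffer_traj = jnp.swapaxes(traj, 0, 1)[:, np.newaxis]
assert buffer_traj.shape == (NUM_ENVS, 1, NUM_STEPS, FEAT)

# after buffer.sample: experience is (batch, 1, time, ...); drop the dummy dim
# and put time first again for the RNN scan
sampled = jnp.zeros((BATCH, 1, NUM_STEPS, FEAT))
learn_traj = jnp.swapaxes(sampled[:, 0], 0, 1)
assert learn_traj.shape == (NUM_STEPS, BATCH, FEAT)
```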
@@ -441,20 +450,25 @@ def _loss_fn(params, target_params, init_hs, obs, dones, actions, valid_actions,
'loss': loss,
'rewards': jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
}
metrics.update(test_metrics) # add the test metrics dictionary
metrics['test_metrics'] = test_metrics # add the test metrics dictionary

if config.get('WANDB_ONLINE_REPORT', False):
def callback(metrics):
def callback(metrics, infos):
info_metrics = {
k:v[...,0][infos["returned_episode"][..., 0]].mean()
for k,v in infos.items() if k!="returned_episode"
}
wandb.log(
{
"returns": metrics['rewards']['__all__'].mean(),
"test_returns": metrics['test_returns']['__all__'].mean(),
"timestep": metrics['timesteps'],
"updates": metrics['updates'],
"loss": metrics['loss'],
**{'return_'+k:v.mean() for k, v in metrics['rewards'].items()},
**info_metrics,
**{k:v.mean() for k, v in metrics['test_metrics'].items()}
}
)
jax.debug.callback(callback, metrics)
jax.debug.callback(callback, metrics, traj_batch.infos)
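The callback above reduces each LogWrapper info field to a scalar by averaging only over steps where an episode actually finished (where returned_episode is True). A minimal sketch of that masking on host-side arrays, using a made-up returned_episode_returns field and toy numbers:

```python
import numpy as np

# infos reach the host callback as arrays of shape (time, env[, agent]);
# the agent axis is sliced away with [..., 0] before masking, as above
returned_episode = np.array([[False, False],
                             [True,  False],
                             [False, True ]])
returned_episode_returns = np.array([[0.0, 0.0],
                                     [3.0, 0.0],
                                     [0.0, 5.0]])

masked_mean = returned_episode_returns[returned_episode].mean()
print(masked_mean)  # (3.0 + 5.0) / 2 = 4.0
```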

runner_state = (
train_state,
@@ -479,10 +493,10 @@ def _greedy_env_step(step_state, unused):
obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_)
dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones)
hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, wrapped_env.valid_actions)
actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
step_state = (params, env_state, obs, dones, hstate, rng)
return step_state, (rewards, dones)
return step_state, (rewards, dones, infos)
rng, _rng = jax.random.split(rng)
init_obs, env_state = test_env.batch_reset(_rng)
init_dones = {agent:jnp.zeros((config["NUM_TEST_EPISODES"]), dtype=bool) for agent in env.agents+['__all__']}
@@ -499,23 +513,25 @@
hstate,
_rng,
)
step_state, rews_dones = jax.lax.scan(
step_state, (rewards, dones, infos) = jax.lax.scan(
_greedy_env_step, step_state, None, config["NUM_STEPS"]
)
# compute the episode returns of the first episode that is done for each parallel env
# compute the metrics of the first episode that is done for each parallel env
def first_episode_returns(rewards, dones):
first_done = jax.lax.select(jnp.argmax(dones)==0., dones.size, jnp.argmax(dones))
first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False)
return jnp.where(first_episode_mask, rewards, 0.).sum()
all_dones = rews_dones[1]['__all__']
returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rews_dones[0])
all_dones = dones['__all__']
first_returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards)
first_infos = jax.tree_map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos)
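first_episode_returns above sums the rewards up to and including the first done flag of each test rollout; when argmax(dones) is 0 (no episode finished, or a done at step 0) it falls back to the full rollout length. A small worked example that reuses the helper with toy inputs:

```python
import jax
import jax.numpy as jnp

def first_episode_returns(rewards, dones):
    # index of the first done flag; argmax == 0 means "use the whole rollout"
    first_done = jax.lax.select(jnp.argmax(dones) == 0., dones.size, jnp.argmax(dones))
    first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False)
    return jnp.where(first_episode_mask, rewards, 0.).sum()

rewards = jnp.array([1.0, 2.0, 3.0, 4.0])
dones = jnp.array([False, False, True, False])
print(first_episode_returns(rewards, dones))  # 1 + 2 + 3 = 6.0
```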
metrics = {
'test_returns': returns # episode returns
**{'test_returns_'+k:v.mean() for k, v in first_returns.items()},
**{'test_'+k:v for k,v in first_infos.items()}
}
if config.get('VERBOSE', False):
def callback(timestep, val):
print(f"Timestep: {timestep}, return: {val}")
jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], returns['__all__'].mean())
jax.debug.callback(callback, time_state['timesteps']*config['NUM_ENVS'], first_returns['__all__'].mean())
return metrics

time_state = {
@@ -555,12 +571,21 @@ def main(config):
alg_name = f'iql_{"ps" if config["alg"].get("PARAMETERS_SHARING", True) else "ns"}'

# smax init needs a scenario
if 'SMAC' in env_name:
if 'smax' in env_name.lower():
config['env']['ENV_KWARGS']['scenario'] = map_name_to_scenario(config['env']['MAP_NAME'])
env_name = 'jaxmarl_'+config['env']['MAP_NAME']

env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
config["alg"]["NUM_STEPS"] = config["alg"].get("NUM_STEPS", env.max_steps) # default steps defined by the env
env_name = f"{config['env']['ENV_NAME']}_{config['env']['MAP_NAME']}"
env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
env = SMAXLogWrapper(env)
# overcooked needs a layout
elif 'overcooked' in env_name.lower():
config['env']["ENV_KWARGS"]["layout"] = overcooked_layouts[config['env']["ENV_KWARGS"]["layout"]]
env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
env = LogWrapper(env)
else:
env = make(config["env"]["ENV_NAME"], **config['env']['ENV_KWARGS'])
env = LogWrapper(env)

#config["alg"]["NUM_STEPS"] = config["alg"].get("NUM_STEPS", env.max_steps) # default steps defined by the env

wandb.init(
entity=config["ENTITY"],