
Commit 80554ef

Merge pull request #1124 from instadeepai/chore/ppo-system-cleanup
chore: PPO system cleanup
2 parents 2762b3d + 81c108d commit 80554ef

17 files changed: 323 additions, 561 deletions


README.md

Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used
 
 ## Advanced Usage 👽
 
-Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information.
+Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information.
 
 ## Contributing 🤝

examples/Quickstart.ipynb

Lines changed: 0 additions & 2 deletions

@@ -413,8 +413,6 @@
 " )\n",
 "\n",
 " # Compute the parallel mean (pmean) over the batch.\n",
-" # This calculation is inspired by the Anakin architecture demo notebook.\n",
-" # available at https://tinyurl.com/26tdzs5x\n",
 " # This pmean could be a regular mean as the batch axis is on the same device.\n",
 " actor_grads, actor_loss_info = jax.lax.pmean(\n",
 " (actor_grads, actor_loss_info), axis_name=\"batch\"\n",
File renamed without changes.

mava/advanced_usage/ff_ippo_store_experience.py renamed to examples/advanced_usage/ff_ippo_store_experience.py

Lines changed: 1 addition & 2 deletions

@@ -1,3 +1,4 @@
+# type: ignore
 # Copyright 2022 InstaDeep Ltd. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -225,8 +226,6 @@ def _critic_loss_fn(
 )
 
 # Compute the parallel mean (pmean) over the batch.
-# This calculation is inspired by the Anakin architecture demo notebook.
-# available at https://tinyurl.com/26tdzs5x
 # This pmean could be a regular mean as the batch axis is on the same device.
 actor_grads, actor_loss_info = jax.lax.pmean(
 (actor_grads, actor_loss_info), axis_name="batch"
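For context on the pmean retained in the hunk above (and in the Quickstart notebook hunk earlier): a minimal, self-contained sketch of jax.lax.pmean over a named vmap axis. This is not part of the commit; the function and variable names are illustrative. It shows the point of the surviving comment, namely that the pmean equals a plain mean when the whole batch axis lives on one device.

import jax
import jax.numpy as jnp

def mean_over_batch(x):
    # Collective mean over the axis named "batch".
    return jax.lax.pmean(x, axis_name="batch")

xs = jnp.array([1.0, 2.0, 3.0, 4.0])
# vmap gives the mapped axis a name so that lax.pmean can reduce over it.
out = jax.vmap(mean_over_batch, axis_name="batch")(xs)
# out == [2.5, 2.5, 2.5, 2.5]: the same value jnp.mean(xs) gives, because the
# "batch" axis is mapped on a single device.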

mava/evaluator.py

Lines changed: 2 additions & 2 deletions

@@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
 # find the first instance of done to get the metrics at that timestep, we don't
 # care about subsequent steps because we only the results from the first episode
 done_idx = np.argmax(timesteps.last(), axis=0)
-metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
+metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
 del metrics["is_terminal_step"] # uneeded for logging
 
 return key, metrics
@@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
 metrics_array.append(metric)
 
 # flatten metrics
-metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
+metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array)
 return metrics
 
 def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
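As a toy illustration of the evaluator pattern these hunks touch (not part of the commit; shapes and values are invented): metrics are stacked over time with shape [timesteps, n_envs], and only the value recorded at each environment's first done step is kept.

import numpy as np
from jax import tree

n_steps, n_envs = 5, 3
dones = np.zeros((n_steps, n_envs), dtype=bool)
dones[2, 0] = dones[4, 1] = dones[1, 2] = True  # first episode ends at different steps
metrics = {"episode_return": np.arange(n_steps * n_envs, dtype=float).reshape(n_steps, n_envs)}

# Index of the first True along the time axis, per environment.
done_idx = np.argmax(dones, axis=0)  # -> [2, 4, 1]
# Keep only the metric at that first terminal step for each environment.
first_episode = tree.map(lambda m: m[done_idx, np.arange(n_envs)], metrics)
# first_episode["episode_return"] -> [6., 13., 5.]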

mava/systems/mat/anakin/mat.py

Lines changed: 30 additions & 62 deletions

@@ -37,16 +37,13 @@
 ExperimentOutput,
 LearnerFn,
 MarlEnv,
+Metrics,
 TimeStep,
 )
 from mava.utils import make_env as environments
 from mava.utils.checkpointing import Checkpointer
 from mava.utils.config import check_total_timesteps
-from mava.utils.jax_utils import (
-merge_leading_dims,
-unreplicate_batch_dim,
-unreplicate_n_dims,
-)
+from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
 from mava.utils.logger import LogEvent, MavaLogger
 from mava.utils.network_utils import get_action_head
 from mava.utils.training import make_learning_rate
@@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
 _ (Any): The current metrics info.
 """
 
-def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
+def _env_step(
+learner_state: LearnerState, _: Any
+) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
 """Step the environment."""
 params, opt_state, key, env_state, last_timestep = learner_state
 
-# SELECT ACTION
+# Select action
 key, policy_key = jax.random.split(key)
 action, log_prob, value = actor_action_select_fn( # type: ignore
 params,
 last_timestep.observation,
 policy_key,
 )
-# STEP ENVIRONMENT
+# Step environment
 env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
 
-# LOG EPISODE METRICS
-# Repeat along the agent dimension. This is needed to handle the
-# shuffling along the agent dimension during training.
-info = tree.map(
-lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1),
-timestep.extras["episode_metrics"],
-)
-
-# SET TRANSITION
-done = tree.map(
-lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1),
-timestep.last(),
-)
+done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
 transition = PPOTransition(
-done,
-action,
-value,
-timestep.reward,
-log_prob,
-last_timestep.observation,
-info,
+done, action, value, timestep.reward, log_prob, last_timestep.observation
 )
 learner_state = LearnerState(params, opt_state, key, env_state, timestep)
-return learner_state, transition
+return learner_state, (transition, timestep.extras["episode_metrics"])
 
-# STEP ENVIRONMENT FOR ROLLOUT LENGTH
-learner_state, traj_batch = jax.lax.scan(
+# Step environment for rollout length
+learner_state, (traj_batch, episode_metrics) = jax.lax.scan(
 _env_step, learner_state, None, config.system.rollout_length
 )
 
-# CALCULATE ADVANTAGE
+# Calculate advantage
 params, opt_state, key, env_state, last_timestep = learner_state
 
 key, last_val_key = jax.random.split(key)
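A quick standalone check (illustrative only, not part of the commit) that the new one-liner for done matches the tree.map/jnp.repeat form it replaces, broadcasting a per-environment flag of shape (num_envs,) to (num_envs, num_agents):

import jax.numpy as jnp
from jax import tree

num_envs, num_agents = 4, 3
last = jnp.array([True, False, False, True])  # stand-in for timestep.last()

# Old form: tree.map applied to what is in fact a single array.
old_done = tree.map(lambda x: jnp.repeat(x, num_agents).reshape(num_envs, -1), last)
# New form: call the array methods directly.
new_done = last.repeat(num_agents).reshape(num_envs, -1)

assert bool((old_done == new_done).all())  # identical (num_envs, num_agents) arrays

The scan above then stacks these per-step outputs, including the episode metrics that _env_step now returns alongside the transition.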
@@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
 
 def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
 """Update the network for a single minibatch."""
-
-# UNPACK TRAIN STATE AND BATCH INFO
 params, opt_state, key = train_state
 traj_batch, advantages, targets = batch_info
 
@@ -184,52 +163,47 @@
 entropy_key: chex.PRNGKey,
 ) -> Tuple:
 """Calculate the actor loss."""
-# RERUN NETWORK
-
+# Rerun network
 log_prob, value, entropy = actor_apply_fn( # type: ignore
 params,
 traj_batch.obs,
 traj_batch.action,
 entropy_key,
 )
 
-# CALCULATE ACTOR LOSS
+# Calculate actor loss
 ratio = jnp.exp(log_prob - traj_batch.log_prob)
-
 # Nomalise advantage at minibatch level
 gae = (gae - gae.mean()) / (gae.std() + 1e-8)
-
-loss_actor1 = ratio * gae
-loss_actor2 = (
+actor_loss1 = ratio * gae
+actor_loss2 = (
 jnp.clip(
 ratio,
 1.0 - config.system.clip_eps,
 1.0 + config.system.clip_eps,
 )
 * gae
 )
-loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
-loss_actor = loss_actor.mean()
+actor_loss = -jnp.minimum(actor_loss1, actor_loss2)
+actor_loss = actor_loss.mean()
 entropy = entropy.mean()
 
-# CALCULATE VALUE LOSS
+# Clipped MSE loss
 value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
 -config.system.clip_eps, config.system.clip_eps
 )
-
-# MSE LOSS
 value_losses = jnp.square(value - value_targets)
 value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
 value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
 
 total_loss = (
-loss_actor
+actor_loss
 - config.system.ent_coef * entropy
 + config.system.vf_coef * value_loss
 )
-return total_loss, (loss_actor, entropy, value_loss)
+return total_loss, (actor_loss, entropy, value_loss)
 
-# CALCULATE ACTOR LOSS
+# Calculate loss
 key, entropy_key = jax.random.split(key)
 actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
 actor_loss_info, actor_grads = actor_grad_fn(
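For readers skimming the renames in the hunk above, a standalone sketch of the two losses involved: the clipped PPO surrogate (actor_loss) and the clipped value loss. This is not part of the commit; the argument names and the default epsilon are placeholders rather than Mava's configured values.

import jax.numpy as jnp

def clipped_ppo_losses(log_prob, old_log_prob, gae, value, old_value, value_targets, clip_eps=0.2):
    # Importance ratio between the new and old policies.
    ratio = jnp.exp(log_prob - old_log_prob)
    # Normalise advantages at the minibatch level.
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    actor_loss1 = ratio * gae
    actor_loss2 = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    actor_loss = -jnp.minimum(actor_loss1, actor_loss2).mean()

    # Clipped MSE value loss: penalise the worse of the clipped/unclipped predictions.
    value_pred_clipped = old_value + jnp.clip(value - old_value, -clip_eps, clip_eps)
    value_losses = jnp.square(value - value_targets)
    value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
    return actor_loss, value_loss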
@@ -248,15 +222,11 @@ def _loss_fn(
 (actor_grads, actor_loss_info), axis_name="device"
 )
 
-# UPDATE ACTOR PARAMS AND OPTIMISER STATE
+# Update params and optimiser state
 actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state)
 new_params = optax.apply_updates(params, actor_updates)
 
-# PACK LOSS INFO
-total_loss = actor_loss_info[0]
-value_loss = actor_loss_info[1][2]
-actor_loss = actor_loss_info[1][0]
-entropy = actor_loss_info[1][1]
+total_loss, (actor_loss, entropy, value_loss) = actor_loss_info
 loss_info = {
 "total_loss": total_loss,
 "value_loss": value_loss,
@@ -269,7 +239,7 @@ def _loss_fn(
 params, opt_state, traj_batch, advantages, targets, key = update_state
 key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)
 
-# SHUFFLE MINIBATCHES
+# Shuffle minibatches
 batch_size = config.system.rollout_length * config.arch.num_envs
 permutation = jax.random.permutation(batch_shuffle_key, batch_size)
 
@@ -286,7 +256,7 @@ def _loss_fn(
 shuffled_batch,
 )
 
-# UPDATE MINIBATCHES
+# Update minibatches
 (params, opt_state, entropy_key), loss_info = jax.lax.scan(
 _update_minibatch, (params, opt_state, entropy_key), minibatches
 )
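For orientation around the two hunks above (illustrative only, not part of the commit): a flattened rollout batch is shuffled with jax.random.permutation and then reshaped so that the leading axis enumerates minibatches, which jax.lax.scan iterates over. The shapes and minibatch count below are made up.

import jax
import jax.numpy as jnp
from jax import tree

rollout_length, num_envs, num_minibatches = 8, 4, 2
batch_size = rollout_length * num_envs  # 32

key = jax.random.PRNGKey(0)
permutation = jax.random.permutation(key, batch_size)

batch = {"obs": jnp.arange(batch_size * 3.0).reshape(batch_size, 3)}
# Shuffle every leaf with the same permutation, then carve out minibatches.
shuffled = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch)
minibatches = tree.map(lambda x: x.reshape(num_minibatches, -1, *x.shape[1:]), shuffled)
# Each leaf is now (num_minibatches, batch_size // num_minibatches, ...), ready to be
# consumed one minibatch at a time by jax.lax.scan.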
@@ -296,17 +266,15 @@ def _loss_fn(
 
 update_state = params, opt_state, traj_batch, advantages, targets, key
 
-# UPDATE EPOCHS
+# Update epochs
 update_state, loss_info = jax.lax.scan(
 _update_epoch, update_state, None, config.system.ppo_epochs
 )
 
 params, opt_state, traj_batch, advantages, targets, key = update_state
 learner_state = LearnerState(params, opt_state, key, env_state, last_timestep)
 
-metric = traj_batch.info
-
-return learner_state, (metric, loss_info)
+return learner_state, (episode_metrics, loss_info)
 
 def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
 """Learner function.
@@ -351,7 +319,7 @@ def learner_setup(
 # PRNG keys.
 key, actor_net_key = keys
 
-# Initialise observation: Obs for all agents.
+# Get mock inputs to initialise network.
 init_x = env.observation_spec().generate_value()
 init_x = tree.map(lambda x: x[None, ...], init_x)
 
0 commit comments
