@@ -37,16 +37,13 @@
     ExperimentOutput,
     LearnerFn,
     MarlEnv,
+    Metrics,
     TimeStep,
 )
 from mava.utils import make_env as environments
 from mava.utils.checkpointing import Checkpointer
 from mava.utils.config import check_total_timesteps
-from mava.utils.jax_utils import (
-    merge_leading_dims,
-    unreplicate_batch_dim,
-    unreplicate_n_dims,
-)
+from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
 from mava.utils.logger import LogEvent, MavaLogger
 from mava.utils.network_utils import get_action_head
 from mava.utils.training import make_learning_rate
@@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
             _ (Any): The current metrics info.
         """
 
-        def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
+        def _env_step(
+            learner_state: LearnerState, _: Any
+        ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
             """Step the environment."""
             params, opt_state, key, env_state, last_timestep = learner_state
 
-            # SELECT ACTION
+            # Select action
             key, policy_key = jax.random.split(key)
             action, log_prob, value = actor_action_select_fn(  # type: ignore
                 params,
                 last_timestep.observation,
                 policy_key,
             )
-            # STEP ENVIRONMENT
+            # Step environment
             env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)
 
-            # LOG EPISODE METRICS
-            # Repeat along the agent dimension. This is needed to handle the
-            # shuffling along the agent dimension during training.
-            info = tree.map(
-                lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1),
-                timestep.extras["episode_metrics"],
-            )
-
-            # SET TRANSITION
-            done = tree.map(
-                lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1),
-                timestep.last(),
-            )
+            done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
             transition = PPOTransition(
-                done,
-                action,
-                value,
-                timestep.reward,
-                log_prob,
-                last_timestep.observation,
-                info,
+                done, action, value, timestep.reward, log_prob, last_timestep.observation
             )
             learner_state = LearnerState(params, opt_state, key, env_state, timestep)
-            return learner_state, transition
+            return learner_state, (transition, timestep.extras["episode_metrics"])
 
-        # STEP ENVIRONMENT FOR ROLLOUT LENGTH
-        learner_state, traj_batch = jax.lax.scan(
+        # Step environment for rollout length
+        learner_state, (traj_batch, episode_metrics) = jax.lax.scan(
             _env_step, learner_state, None, config.system.rollout_length
         )
 
-        # CALCULATE ADVANTAGE
+        # Calculate advantage
         params, opt_state, key, env_state, last_timestep = learner_state
 
         key, last_val_key = jax.random.split(key)
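A note on the new `_env_step` return type: when the scanned function returns a `(transition, metrics)` tuple, `jax.lax.scan` stacks each element of that tuple along the leading (time) axis, so the episode metrics come back already batched over the rollout and no longer need to be repeated per agent inside the transition. A minimal, self-contained sketch of the pattern (the `step` function and dictionary keys below are illustrative stand-ins, not Mava code):

import jax
import jax.numpy as jnp

def step(carry, _):
    # `carry` is a simple counter standing in for the learner state.
    carry = carry + 1
    transition = {"value": carry.astype(jnp.float32)}              # stands in for PPOTransition
    metrics = {"episode_return": carry.astype(jnp.float32) * 2.0}  # stands in for episode_metrics
    # Returning a (transition, metrics) tuple makes scan stack both pytrees
    # along the leading (time) axis, mirroring the new _env_step signature.
    return carry, (transition, metrics)

carry, (traj_batch, episode_metrics) = jax.lax.scan(step, jnp.int32(0), None, length=4)
print(traj_batch["value"].shape)          # (4,)
print(episode_metrics["episode_return"])  # [2. 4. 6. 8.]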
@@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
 
            def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
                """Update the network for a single minibatch."""
-
-                # UNPACK TRAIN STATE AND BATCH INFO
                params, opt_state, key = train_state
                traj_batch, advantages, targets = batch_info
 
@@ -184,52 +163,47 @@ def _loss_fn(
                     entropy_key: chex.PRNGKey,
                 ) -> Tuple:
                     """Calculate the actor loss."""
-                    # RERUN NETWORK
-
+                    # Rerun network
                     log_prob, value, entropy = actor_apply_fn(  # type: ignore
                         params,
                         traj_batch.obs,
                         traj_batch.action,
                         entropy_key,
                     )
 
-                    # CALCULATE ACTOR LOSS
+                    # Calculate actor loss
                     ratio = jnp.exp(log_prob - traj_batch.log_prob)
-
                     # Nomalise advantage at minibatch level
                     gae = (gae - gae.mean()) / (gae.std() + 1e-8)
-
-                    loss_actor1 = ratio * gae
-                    loss_actor2 = (
+                    actor_loss1 = ratio * gae
+                    actor_loss2 = (
                         jnp.clip(
                             ratio,
                             1.0 - config.system.clip_eps,
                             1.0 + config.system.clip_eps,
                         )
                         * gae
                     )
-                    loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
-                    loss_actor = loss_actor.mean()
+                    actor_loss = -jnp.minimum(actor_loss1, actor_loss2)
+                    actor_loss = actor_loss.mean()
                     entropy = entropy.mean()
 
-                    # CALCULATE VALUE LOSS
+                    # Clipped MSE loss
                     value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                         -config.system.clip_eps, config.system.clip_eps
                     )
-
-                    # MSE LOSS
                     value_losses = jnp.square(value - value_targets)
                     value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
                     value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
 
                     total_loss = (
-                        loss_actor
+                        actor_loss
                         - config.system.ent_coef * entropy
                         + config.system.vf_coef * value_loss
                     )
-                    return total_loss, (loss_actor, entropy, value_loss)
+                    return total_loss, (actor_loss, entropy, value_loss)
 
-                # CALCULATE ACTOR LOSS
+                # Calculate loss
                key, entropy_key = jax.random.split(key)
                actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                actor_loss_info, actor_grads = actor_grad_fn(
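For reference on the renamed loss terms, here is a standalone sketch of the clipped PPO surrogate plus the clipped value loss using the new `actor_loss1`/`actor_loss2` naming; `clip_eps`, `ent_coef` and `vf_coef` are plain function arguments standing in for the corresponding `config.system` entries:

import jax.numpy as jnp

def ppo_loss(log_prob, old_log_prob, gae, value, old_value, value_targets, entropy,
             clip_eps=0.2, ent_coef=0.01, vf_coef=0.5):
    """Clipped PPO surrogate with a clipped value loss (illustrative sketch)."""
    ratio = jnp.exp(log_prob - old_log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # normalise advantage at minibatch level

    actor_loss1 = ratio * gae
    actor_loss2 = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    actor_loss = -jnp.minimum(actor_loss1, actor_loss2).mean()

    # Clipped MSE value loss, as in the hunk above.
    value_pred_clipped = old_value + (value - old_value).clip(-clip_eps, clip_eps)
    value_losses = jnp.square(value - value_targets)
    value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
    value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

    total_loss = actor_loss - ent_coef * entropy.mean() + vf_coef * value_loss
    return total_loss, (actor_loss, entropy.mean(), value_loss)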
@@ -248,15 +222,11 @@ def _loss_fn(
                     (actor_grads, actor_loss_info), axis_name="device"
                 )
 
-                # UPDATE ACTOR PARAMS AND OPTIMISER STATE
+                # Update params and optimiser state
                actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state)
                new_params = optax.apply_updates(params, actor_updates)
 
-                # PACK LOSS INFO
-                total_loss = actor_loss_info[0]
-                value_loss = actor_loss_info[1][2]
-                actor_loss = actor_loss_info[1][0]
-                entropy = actor_loss_info[1][1]
+                total_loss, (actor_loss, entropy, value_loss) = actor_loss_info
                loss_info = {
                    "total_loss": total_loss,
                    "value_loss": value_loss,
@@ -269,7 +239,7 @@ def _loss_fn(
             params, opt_state, traj_batch, advantages, targets, key = update_state
             key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)
 
-            # SHUFFLE MINIBATCHES
+            # Shuffle minibatches
             batch_size = config.system.rollout_length * config.arch.num_envs
             permutation = jax.random.permutation(batch_shuffle_key, batch_size)
 
@@ -286,7 +256,7 @@ def _loss_fn(
                 shuffled_batch,
             )
 
-            # UPDATE MINIBATCHES
+            # Update minibatches
             (params, opt_state, entropy_key), loss_info = jax.lax.scan(
                 _update_minibatch, (params, opt_state, entropy_key), minibatches
             )
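The shuffle-then-scan minibatching that these hunks touch only cosmetically follows this pattern; the sketch below uses toy shapes and a dummy `_update_minibatch` purely for illustration:

import jax
import jax.numpy as jnp

batch_size, num_minibatches, feature_dim = 8, 2, 3
batch = jnp.arange(batch_size * feature_dim, dtype=jnp.float32).reshape(batch_size, feature_dim)

# Shuffle the batch dimension, then split it into equally sized minibatches.
key = jax.random.PRNGKey(0)
permutation = jax.random.permutation(key, batch_size)
shuffled = jnp.take(batch, permutation, axis=0)
minibatches = shuffled.reshape(num_minibatches, -1, feature_dim)

def _update_minibatch(train_state, minibatch):
    # Dummy update: report the minibatch mean as a stand-in "loss".
    return train_state, minibatch.mean()

train_state, losses = jax.lax.scan(_update_minibatch, 0.0, minibatches)
print(losses.shape)  # (2,) -- one entry per minibatch, like the real loss_info pytree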
@@ -296,17 +266,15 @@ def _loss_fn(
 
         update_state = params, opt_state, traj_batch, advantages, targets, key
 
-        # UPDATE EPOCHS
+        # Update epochs
         update_state, loss_info = jax.lax.scan(
             _update_epoch, update_state, None, config.system.ppo_epochs
         )
 
         params, opt_state, traj_batch, advantages, targets, key = update_state
         learner_state = LearnerState(params, opt_state, key, env_state, last_timestep)
 
-        metric = traj_batch.info
-
-        return learner_state, (metric, loss_info)
+        return learner_state, (episode_metrics, loss_info)
 
     def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
         """Learner function.
@@ -351,7 +319,7 @@ def learner_setup(
     # PRNG keys.
     key, actor_net_key = keys
 
-    # Initialise observation: Obs for all agents.
+    # Get mock inputs to initialise network.
     init_x = env.observation_spec().generate_value()
     init_x = tree.map(lambda x: x[None, ...], init_x)
 