diff --git a/Dockerfile b/Dockerfile index c4cb6440..94986d2a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvcr.io/nvidia/jax:23.10-py3 +FROM nvcr.io/nvidia/jax:24.10-py3 # Create user ARG UID diff --git a/baselines/IPPO/ippo_cnn_overcooked.py b/baselines/IPPO/ippo_cnn_overcooked.py index 8d47e37d..6bde1f49 100644 --- a/baselines/IPPO/ippo_cnn_overcooked.py +++ b/baselines/IPPO/ippo_cnn_overcooked.py @@ -237,9 +237,9 @@ def _env_step(runner_state, unused): shaped_reward = info.pop("shaped_reward") current_timestep = update_step*config["NUM_STEPS"]*config["NUM_ENVS"] - reward = jax.tree_map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, shaped_reward) + reward = jax.tree.map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, shaped_reward) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), action, @@ -345,13 +345,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), @@ -375,7 +375,7 @@ def callback(metric): wandb.log(metric) update_step = update_step + 1 - metric = jax.tree_map(lambda x: x.mean(), metric) + metric = jax.tree.map(lambda x: x.mean(), metric) metric["update_step"] = update_step metric["env_step"] = update_step * config["NUM_STEPS"] * config["NUM_ENVS"] jax.debug.callback(callback, metric) @@ -413,7 +413,7 @@ def single_run(config): print("** Saving Results **") filename = f'{config["ENV_NAME"]}_{layout_name}_seed{config["SEED"]}' - train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0]) + train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0]) state_seq = get_rollout(train_state.params, config) viz = OvercookedVisualizer() # agent_view_size is hardcoded as it determines the padding around the layout. 
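The change applied throughout this diff is mechanical: the deprecated top-level `jax.tree_map` alias and the longer `jax.tree_util.tree_map` spelling are replaced with `jax.tree.map`, the spelling supported by the newer JAX that ships in the `24.10` base image above. A minimal sketch of the equivalence, using an illustrative per-agent reward pytree (the names `reward`, `shaped`, and `anneal` below are stand-ins, not taken from the repo):

```python
import jax
import jax.numpy as jnp

# Illustrative per-agent reward pytrees, analogous in shape to the
# `reward` / `shaped_reward` dicts used in the Overcooked baselines.
reward = {"agent_0": jnp.array([1.0, 0.0]), "agent_1": jnp.array([0.5, 0.5])}
shaped = {"agent_0": jnp.array([0.2, 0.2]), "agent_1": jnp.array([0.1, 0.3])}
anneal = 0.5  # stand-in for rew_shaping_anneal(current_timestep)

# Older spellings replaced by this diff:
#   jax.tree_map(f, *trees)           (deprecated top-level alias)
#   jax.tree_util.tree_map(f, *trees)
# Spelling used after the change, available in recent JAX releases:
combined = jax.tree.map(lambda x, y: x + y * anneal, reward, shaped)

print(combined["agent_0"])  # [1.1 0.1]
print(combined["agent_1"])  # [0.55 0.65]
```

All three spellings call the same underlying tree-map, so the replacement is behavior-preserving; only the access path changes.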
diff --git a/baselines/IPPO/ippo_ff_hanabi.py b/baselines/IPPO/ippo_ff_hanabi.py index cf937cbf..ea608bb3 100644 --- a/baselines/IPPO/ippo_ff_hanabi.py +++ b/baselines/IPPO/ippo_ff_hanabi.py @@ -138,7 +138,7 @@ def _env_step(runner_state, unused): action = pi.sample(seed=_rng) log_prob = pi.log_prob(action) env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents) - env_act = jax.tree_map(lambda x: x.squeeze(), env_act) + env_act = jax.tree.map(lambda x: x.squeeze(), env_act) # STEP ENV rng, _rng = jax.random.split(rng) @@ -146,7 +146,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))( rng_step, env_state, env_act ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( done_batch, @@ -258,11 +258,11 @@ def _loss_fn(params, traj_batch, gae, targets): batch = (traj_batch, advantages.squeeze(), targets.squeeze()) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py index f0b05f06..e731298d 100644 --- a/baselines/IPPO/ippo_ff_mabrax.py +++ b/baselines/IPPO/ippo_ff_mabrax.py @@ -143,7 +143,7 @@ def _env_step(runner_state, unused): rng_step, env_state, env_act, ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), action, @@ -258,13 +258,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), @@ -292,8 +292,8 @@ def callback(metric): update_count = update_count + 1 r0 = {"ratio0": loss_info["ratio"][0,0].mean()} - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) - metric = jax.tree_map(lambda x: x.mean(), metric) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) + metric = jax.tree.map(lambda x: x.mean(), metric) metric["update_step"] = update_count metric["env_step"] = update_count * config["NUM_STEPS"] * config["NUM_ENVS"] metric = {**metric, **loss_info, **r0} diff --git a/baselines/IPPO/ippo_ff_mpe.py b/baselines/IPPO/ippo_ff_mpe.py index fd86072e..379c0c85 100644 --- a/baselines/IPPO/ippo_ff_mpe.py +++ b/baselines/IPPO/ippo_ff_mpe.py @@ -140,7 +140,7 @@ def _env_step(runner_state, unused): rng_step, env_state, env_act, ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), action, @@ 
-255,13 +255,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), @@ -288,8 +288,8 @@ def callback(metric): r0 = {"ratio0": loss_info["ratio"][0,0].mean()} # jax.debug.print('ratio0 {x}', x=r0["ratio0"]) - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) - metric = jax.tree_map(lambda x: x.mean(), metric) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) + metric = jax.tree.map(lambda x: x.mean(), metric) metric = {**metric, **loss_info, **r0} jax.experimental.io_callback(callback, None, metric) runner_state = (train_state, env_state, last_obs, rng) diff --git a/baselines/IPPO/ippo_ff_mpe_facmac.py b/baselines/IPPO/ippo_ff_mpe_facmac.py index d5262fdd..5cfc48a0 100644 --- a/baselines/IPPO/ippo_ff_mpe_facmac.py +++ b/baselines/IPPO/ippo_ff_mpe_facmac.py @@ -146,7 +146,7 @@ def _env_step(runner_state, unused): rng_step, env_state, env_act, ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), action, @@ -252,13 +252,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), diff --git a/baselines/IPPO/ippo_ff_overcooked.py b/baselines/IPPO/ippo_ff_overcooked.py index 116c0273..33bef493 100644 --- a/baselines/IPPO/ippo_ff_overcooked.py +++ b/baselines/IPPO/ippo_ff_overcooked.py @@ -212,7 +212,7 @@ def _env_step(runner_state, unused): info["reward"] = reward["agent_0"] current_timestep = update_step*config["NUM_STEPS"]*config["NUM_ENVS"] - reward = jax.tree_map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, info["shaped_reward"]) + reward = jax.tree.map(lambda x,y: x+y*rew_shaping_anneal(current_timestep), reward, info["shaped_reward"]) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), @@ -318,13 +318,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, 
[config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), @@ -353,7 +353,7 @@ def callback(metric): metric ) update_step = update_step + 1 - metric = jax.tree_map(lambda x: x.mean(), metric) + metric = jax.tree.map(lambda x: x.mean(), metric) metric["update_step"] = update_step metric["env_step"] = update_step*config["NUM_STEPS"]*config["NUM_ENVS"] jax.debug.callback(callback, metric) @@ -393,7 +393,7 @@ def main(config): out = jax.vmap(train_jit)(rngs) filename = f'{config["ENV_NAME"]}_{layout_name}' - train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0]) + train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0]) state_seq = get_rollout(train_state, config) viz = OvercookedVisualizer() # agent_view_size is hardcoded as it determines the padding around the layout. @@ -415,7 +415,7 @@ def main(config): plt.savefig(f'{filename}.png') # animate first seed - train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0]) + train_state = jax.tree.map(lambda x: x[0], out["runner_state"][0]) state_seq = get_rollout(train_state, config) viz = OvercookedVisualizer() # agent_view_size is hardcoded as it determines the padding around the layout. diff --git a/baselines/IPPO/ippo_ff_switch_riddle.py b/baselines/IPPO/ippo_ff_switch_riddle.py index 2d4b65a7..2a4fb97c 100644 --- a/baselines/IPPO/ippo_ff_switch_riddle.py +++ b/baselines/IPPO/ippo_ff_switch_riddle.py @@ -141,7 +141,7 @@ def _env_step(runner_state, unused): rng_step, env_state, env_act, ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) transition = Transition( batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(), action, @@ -247,13 +247,13 @@ def _loss_fn(params, traj_batch, gae, targets): ), "batch size must be equal to number of steps * number of actors" permutation = jax.random.permutation(_rng, batch_size) batch = (traj_batch, advantages, targets) - batch = jax.tree_util.tree_map( + batch = jax.tree.map( lambda x: x.reshape((batch_size,) + x.shape[2:]), batch ) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=0), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.reshape( x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:]) ), diff --git a/baselines/IPPO/ippo_rnn_hanabi.py b/baselines/IPPO/ippo_rnn_hanabi.py index 0d5f382c..d59f0197 100644 --- a/baselines/IPPO/ippo_rnn_hanabi.py +++ b/baselines/IPPO/ippo_rnn_hanabi.py @@ -183,14 +183,14 @@ def _env_step(runner_state, unused): action = pi.sample(seed=_rng) log_prob = pi.log_prob(action) env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents) - env_act = jax.tree_map(lambda x: x.squeeze(), env_act) + env_act = jax.tree.map(lambda x: x.squeeze(), env_act) # STEP ENV rng, _rng = jax.random.split(rng) rng_step = jax.random.split(_rng, config["NUM_ENVS"]) obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))( rng_step, env_state, env_act ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -312,11 +312,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): batch = (init_hstate, traj_batch, advantages.squeeze(), targets.squeeze()) permutation = jax.random.permutation(_rng, 
config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -342,7 +342,7 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): train_state = update_state[0] metric = traj_batch.info ratio_0 = loss_info[1][3].at[0,0].get().mean() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) metric["loss"] = { "total_loss": loss_info[0], "value_loss": loss_info[1][0], diff --git a/baselines/IPPO/ippo_rnn_mpe.py b/baselines/IPPO/ippo_rnn_mpe.py index 1837c2ed..a8ea020a 100644 --- a/baselines/IPPO/ippo_rnn_mpe.py +++ b/baselines/IPPO/ippo_rnn_mpe.py @@ -198,7 +198,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -334,11 +334,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -377,14 +377,14 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): ) train_state = update_state[0] metric = traj_batch.info - metric = jax.tree_map( + metric = jax.tree.map( lambda x: x.reshape( (config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents) ), traj_batch.info, ) ratio_0 = loss_info[1][3].at[0,0].get().mean() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) metric["loss"] = { "total_loss": loss_info[0], "value_loss": loss_info[1][0], diff --git a/baselines/IPPO/ippo_rnn_smax.py b/baselines/IPPO/ippo_rnn_smax.py index 5d52470d..14ffd1b8 100644 --- a/baselines/IPPO/ippo_rnn_smax.py +++ b/baselines/IPPO/ippo_rnn_smax.py @@ -204,7 +204,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -346,11 +346,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -389,14 +389,14 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets): ) train_state = update_state[0] metric = traj_batch.info - metric = jax.tree_map( + metric = jax.tree.map( lambda x: x.reshape( (config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents) ), traj_batch.info, ) ratio_0 = loss_info[1][3].at[0,0].get().mean() 
- loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) metric["loss"] = { "total_loss": loss_info[0], "value_loss": loss_info[1][0], diff --git a/baselines/MAPPO/mappo_ff_hanabi.py b/baselines/MAPPO/mappo_ff_hanabi.py index 61da33e1..8c79d90f 100644 --- a/baselines/MAPPO/mappo_ff_hanabi.py +++ b/baselines/MAPPO/mappo_ff_hanabi.py @@ -239,7 +239,7 @@ def _env_step(runner_state, unused): env_act = unbatchify( action, env.agents, config["NUM_ENVS"], env.num_agents ) - env_act = jax.tree_map(lambda x: x.squeeze(), env_act) + env_act = jax.tree.map(lambda x: x.squeeze(), env_act) # VALUE world_state = last_obs["world_state"].swapaxes(0,1) @@ -252,7 +252,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -406,11 +406,11 @@ def _critic_loss_fn(critic_params, traj_batch, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -449,7 +449,7 @@ def _critic_loss_fn(critic_params, traj_batch, targets): train_states = update_state[0] metric = traj_batch.info loss_info["ratio_0"] = loss_info["ratio"].at[0,0].get() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) metric["loss"] = loss_info rng = update_state[-1] diff --git a/baselines/MAPPO/mappo_rnn_hanabi.py b/baselines/MAPPO/mappo_rnn_hanabi.py index 734ccc87..63b106f4 100644 --- a/baselines/MAPPO/mappo_rnn_hanabi.py +++ b/baselines/MAPPO/mappo_rnn_hanabi.py @@ -275,7 +275,7 @@ def _env_step(runner_state, unused): env_act = unbatchify( action, env.agents, config["NUM_ENVS"], env.num_agents ) - env_act = jax.tree_map(lambda x: x.squeeze(), env_act) + env_act = jax.tree.map(lambda x: x.squeeze(), env_act) # VALUE # output of wrapper is (num_envs, num_agents, world_state_size) @@ -294,7 +294,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -447,7 +447,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) = update_state rng, _rng = jax.random.split(rng) - init_hstates = jax.tree_map(lambda x: jnp.reshape( + init_hstates = jax.tree.map(lambda x: jnp.reshape( x, (1, config["NUM_ACTORS"], -1) ), initial_hstates) @@ -460,11 +460,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( 
jnp.reshape( x, @@ -482,7 +482,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) update_state = ( train_states, - jax.tree_map(lambda x: x.squeeze(), init_hstates), + jax.tree.map(lambda x: x.squeeze(), init_hstates), traj_batch, advantages, targets, @@ -502,7 +502,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): _update_epoch, update_state, None, config["UPDATE_EPOCHS"] ) loss_info["ratio_0"] = loss_info["ratio"].at[0,0].get() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) train_states = update_state[0] metric = traj_batch.info diff --git a/baselines/MAPPO/mappo_rnn_mpe.py b/baselines/MAPPO/mappo_rnn_mpe.py index 96141731..f7a2b337 100644 --- a/baselines/MAPPO/mappo_rnn_mpe.py +++ b/baselines/MAPPO/mappo_rnn_mpe.py @@ -283,7 +283,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -435,7 +435,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) = update_state rng, _rng = jax.random.split(rng) - init_hstates = jax.tree_map(lambda x: jnp.reshape( + init_hstates = jax.tree.map(lambda x: jnp.reshape( x, (1, config["NUM_ACTORS"], -1) ), init_hstates) @@ -448,11 +448,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -470,7 +470,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) update_state = ( train_states, - jax.tree_map(lambda x: x.squeeze(), init_hstates), + jax.tree.map(lambda x: x.squeeze(), init_hstates), traj_batch, advantages, targets, @@ -490,7 +490,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): _update_epoch, update_state, None, config["UPDATE_EPOCHS"] ) loss_info["ratio_0"] = loss_info["ratio"].at[0,0].get() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) train_states = update_state[0] metric = traj_batch.info diff --git a/baselines/MAPPO/mappo_rnn_smax.py b/baselines/MAPPO/mappo_rnn_smax.py index 8cd73199..e9210e6d 100644 --- a/baselines/MAPPO/mappo_rnn_smax.py +++ b/baselines/MAPPO/mappo_rnn_smax.py @@ -314,7 +314,7 @@ def _env_step(runner_state, unused): obsv, env_state, reward, done, info = jax.vmap( env.step, in_axes=(0, 0, 0) )(rng_step, env_state, env_act) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) done_batch = batchify(done, env.agents, config["NUM_ACTORS"]).squeeze() transition = Transition( jnp.tile(done["__all__"], env.num_agents), @@ -466,7 +466,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) = update_state rng, _rng = jax.random.split(rng) - init_hstates = jax.tree_map(lambda x: jnp.reshape( + init_hstates = jax.tree.map(lambda x: jnp.reshape( x, (1, 
config["NUM_ACTORS"], -1) ), init_hstates) @@ -479,11 +479,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) permutation = jax.random.permutation(_rng, config["NUM_ACTORS"]) - shuffled_batch = jax.tree_util.tree_map( + shuffled_batch = jax.tree.map( lambda x: jnp.take(x, permutation, axis=1), batch ) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: jnp.swapaxes( jnp.reshape( x, @@ -502,7 +502,7 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): ) update_state = ( train_states, - jax.tree_map(lambda x: x.squeeze(), init_hstates), + jax.tree.map(lambda x: x.squeeze(), init_hstates), traj_batch, advantages, targets, @@ -522,11 +522,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets): _update_epoch, update_state, None, config["UPDATE_EPOCHS"] ) loss_info["ratio_0"] = loss_info["ratio"].at[0,0].get() - loss_info = jax.tree_map(lambda x: x.mean(), loss_info) + loss_info = jax.tree.map(lambda x: x.mean(), loss_info) train_states = update_state[0] metric = traj_batch.info - metric = jax.tree_map( + metric = jax.tree.map( lambda x: x.reshape( (config["NUM_STEPS"], config["NUM_ENVS"], env.num_agents) ), diff --git a/baselines/QLearning/iql_cnn_overcooked.py b/baselines/QLearning/iql_cnn_overcooked.py index bddadec1..38011c14 100644 --- a/baselines/QLearning/iql_cnn_overcooked.py +++ b/baselines/QLearning/iql_cnn_overcooked.py @@ -224,7 +224,7 @@ def create_agent(rng): rewards=_rewards, dones=_dones, ) - _tiemstep_unbatched = jax.tree_map( + _tiemstep_unbatched = jax.tree.map( lambda x: x[0], _timestep ) # remove the NUM_ENV dim buffer_state = buffer.init(_tiemstep_unbatched) @@ -260,7 +260,7 @@ def _step_env(carry, _): # add shaped reward shaped_reward = infos.pop("shaped_reward") shaped_reward["__all__"] = batchify(shaped_reward).sum(axis=0) - rewards = jax.tree_map( + rewards = jax.tree.map( lambda x, y: x + y * rew_shaping_anneal(train_state.timesteps), rewards, shaped_reward, @@ -291,7 +291,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - timesteps = jax.tree_util.tree_map( + timesteps = jax.tree.map( lambda x: x.reshape(-1, *x.shape[2:]), timesteps ) # (num_envs*num_steps, ...) 
buffer_state = buffer.add(buffer_state, timesteps) @@ -385,7 +385,7 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -545,7 +545,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/iql_rnn.py b/baselines/QLearning/iql_rnn.py index 83aaca94..f8f2facf 100644 --- a/baselines/QLearning/iql_rnn.py +++ b/baselines/QLearning/iql_rnn.py @@ -233,7 +233,7 @@ def _env_sample_step(env_state, unused): _, sample_traj = jax.lax.scan( _env_sample_step, _env_state, None, config["NUM_STEPS"] ) - sample_traj_unbatched = jax.tree_map( + sample_traj_unbatched = jax.tree.map( lambda x: x[:, 0], sample_traj ) # remove the NUM_ENV dim buffer = fbx.make_trajectory_buffer( @@ -288,7 +288,7 @@ def _step_env(carry, _): timestep = Timestep( obs=last_obs, actions=actions, - rewards=jax.tree_map(lambda x:config.get("REW_SCALE", 1)*x, rewards), + rewards=jax.tree.map(lambda x:config.get("REW_SCALE", 1)*x, rewards), dones=dones, avail_actions=avail_actions, ) @@ -319,7 +319,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - buffer_traj_batch = jax.tree_util.tree_map( + buffer_traj_batch = jax.tree.map( lambda x: jnp.swapaxes(x, 0, 1)[ :, np.newaxis ], # put the batch dim first and add a dummy sequence dim @@ -333,7 +333,7 @@ def _learn_phase(carry, _): train_state, rng = carry rng, _rng = jax.random.split(rng) minibatch = buffer.sample(buffer_state, _rng).experience - minibatch = jax.tree_map( + minibatch = jax.tree.map( lambda x: jnp.swapaxes( x[:, 0], 0, 1 ), # remove the dummy sequence dim (1) and swap batch and temporal dims @@ -455,10 +455,10 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("LOG_AGENTS_SEPARATELY", False): for i, a in enumerate(env.agents): - m = jax.tree_map( + m = jax.tree.map( lambda x: x[..., i].mean(), infos, ) @@ -544,7 +544,7 @@ def _greedy_env_step(step_state, unused): if config.get("LOG_AGENTS_SEPARATELY", False): metrics = {} for i, a in enumerate(env.agents): - m = jax.tree_map( + m = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"][..., i], @@ -557,7 +557,7 @@ def _greedy_env_step(step_state, unused): m = {k + f"_{a}": v for k, v in m.items()} metrics.update(m) else: - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -652,7 +652,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/pqn_vdn_cnn_overcooked.py b/baselines/QLearning/pqn_vdn_cnn_overcooked.py index b64e7b42..cffa656e 100644 --- a/baselines/QLearning/pqn_vdn_cnn_overcooked.py +++ b/baselines/QLearning/pqn_vdn_cnn_overcooked.py @@ -261,7 +261,7 @@ def _step_env(carry, _): # add shaped reward shaped_reward = info.pop("shaped_reward") 
shaped_reward["__all__"] = batchify(shaped_reward).sum(axis=0) - reward = jax.tree_map( + reward = jax.tree.map( lambda x, y: x + y * rew_shaping_anneal(train_state.timesteps), reward, shaped_reward, @@ -337,7 +337,7 @@ def _get_target(lambda_returns_and_next_q, rew_q_done): _, targets = jax.lax.scan( _get_target, (lambda_returns, last_q), - jax.tree_map(lambda x: x[:-1], (reward, q_vals, done)), + jax.tree.map(lambda x: x[:-1], (reward, q_vals, done)), reverse=True, ) targets = jnp.concatenate((targets, lambda_returns[np.newaxis])) @@ -430,7 +430,7 @@ def preprocess_transition(x, rng): return x rng, _rng = jax.random.split(rng) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: preprocess_transition(x, _rng), transitions, ) # num_minibatches, num_agents, num_envs/num_minbatches ... @@ -457,7 +457,7 @@ def preprocess_transition(x, rng): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -624,7 +624,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/pqn_vdn_ff.py b/baselines/QLearning/pqn_vdn_ff.py index cacbe86a..19267a6e 100644 --- a/baselines/QLearning/pqn_vdn_ff.py +++ b/baselines/QLearning/pqn_vdn_ff.py @@ -291,7 +291,7 @@ def _get_target(lambda_returns_and_next_q, rew_q_done): _, targets = jax.lax.scan( _get_target, (lambda_returns, last_q), - jax.tree_map(lambda x: x[:-1], (reward, q_vals, done)), + jax.tree.map(lambda x: x[:-1], (reward, q_vals, done)), reverse=True, ) targets = jnp.concatenate((targets, lambda_returns[np.newaxis])) @@ -383,7 +383,7 @@ def preprocess_transition(x, rng): return x rng, _rng = jax.random.split(rng) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: preprocess_transition(x, _rng), transitions, ) # num_minibatches, num_agents, num_envs/num_minbatches ... 
@@ -410,7 +410,7 @@ def preprocess_transition(x, rng): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -485,7 +485,7 @@ def _greedy_env_step(step_state, unused): step_state, (rewards, dones, infos) = jax.lax.scan( _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"] ) - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -588,7 +588,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/pqn_vdn_rnn.py b/baselines/QLearning/pqn_vdn_rnn.py index c85cd4e2..76d92e2d 100644 --- a/baselines/QLearning/pqn_vdn_rnn.py +++ b/baselines/QLearning/pqn_vdn_rnn.py @@ -312,7 +312,7 @@ def _step_env(carry, _): ) # update timesteps count # insert the transitions into the memory - memory_transitions = jax.tree_map( + memory_transitions = jax.tree.map( lambda x, y: jnp.concatenate([x[config["NUM_STEPS"] :], y], axis=0), memory_transitions, transitions, @@ -336,7 +336,7 @@ def _learn_phase(carry, minibatch): minibatch.last_done, ) # batchify the agent input: num_agents*batch_size - agent_in = jax.tree_util.tree_map( + agent_in = jax.tree.map( lambda x: x.reshape(x.shape[0], -1, *x.shape[3:]), agent_in ) # (num_steps, num_agents*batch_size, ...) @@ -363,7 +363,7 @@ def _get_target(lambda_returns_and_next_q, rew_q_done): _, targets = jax.lax.scan( _get_target, (lambda_returns, last_q), - jax.tree_map(lambda x: x[:-1], (reward, q_vals, done)), + jax.tree.map(lambda x: x[:-1], (reward, q_vals, done)), reverse=True, ) targets = jnp.concatenate([targets, lambda_returns[np.newaxis]]) @@ -442,7 +442,7 @@ def preprocess_transition(x, rng): return x rng, _rng = jax.random.split(rng) - minibatches = jax.tree_util.tree_map( + minibatches = jax.tree.map( lambda x: preprocess_transition(x, _rng), memory_transitions, ) # num_minibatches, num_steps+memory_window, num_agents, batch_size/num_minbatches, num_agents, ... 
@@ -467,7 +467,7 @@ def preprocess_transition(x, rng): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -558,7 +558,7 @@ def _greedy_env_step(step_state, unused): step_state, (rewards, dones, infos) = jax.lax.scan( _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"] ) - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -716,7 +716,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/qmix_rnn.py b/baselines/QLearning/qmix_rnn.py index a9ebef54..7cd37cfc 100644 --- a/baselines/QLearning/qmix_rnn.py +++ b/baselines/QLearning/qmix_rnn.py @@ -260,7 +260,7 @@ def _env_sample_step(env_state, unused): _, sample_traj = jax.lax.scan( _env_sample_step, _env_state, None, config["NUM_STEPS"] ) - sample_traj_unbatched = jax.tree_map( + sample_traj_unbatched = jax.tree.map( lambda x: x[:, 0], sample_traj ) # remove the NUM_ENV dim @@ -376,7 +376,7 @@ def _step_env(carry, _): timestep = Timestep( obs=last_obs, actions=actions, - rewards=jax.tree_map(lambda x:config.get("REW_SCALE", 1)*x, rewards), + rewards=jax.tree.map(lambda x:config.get("REW_SCALE", 1)*x, rewards), dones=dones, avail_actions=avail_actions, ) @@ -407,7 +407,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - buffer_traj_batch = jax.tree_util.tree_map( + buffer_traj_batch = jax.tree.map( lambda x: jnp.swapaxes(x, 0, 1)[ :, np.newaxis ], # put the batch dim first and add a dummy sequence dim @@ -421,7 +421,7 @@ def _learn_phase(carry, _): train_state, rng = carry rng, _rng = jax.random.split(rng) minibatch = buffer.sample(buffer_state, _rng).experience - minibatch = jax.tree_map( + minibatch = jax.tree.map( lambda x: jnp.swapaxes( x[:, 0], 0, 1 ), # remove the dummy sequence dim (1) and swap batch and temporal dims @@ -544,7 +544,7 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) # update the test metrics if config.get("TEST_DURING_TRAINING", True): @@ -623,7 +623,7 @@ def _greedy_env_step(step_state, unused): step_state, (rewards, dones, infos) = jax.lax.scan( _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"] ) - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -718,7 +718,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/shaq.py b/baselines/QLearning/shaq.py index 0c260fa7..2cbb2f95 100644 --- a/baselines/QLearning/shaq.py +++ b/baselines/QLearning/shaq.py @@ -291,7 +291,7 @@ def explore(q, eps, key): eps = self.get_epsilon(t) keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals)))) # get a key for each agent - chosen_actions = jax.tree_map(lambda q, k: explore(q, eps, k), q_vals, keys) + chosen_actions = 
jax.tree.map(lambda q, k: explore(q, eps, k), q_vals, keys) return chosen_actions class Transition(NamedTuple): @@ -329,7 +329,7 @@ def _env_sample_step(env_state, unused): _, sample_traj = jax.lax.scan( _env_sample_step, env_state, None, config["NUM_STEPS"] ) - sample_traj_unbatched = jax.tree_map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim + sample_traj_unbatched = jax.tree.map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim buffer = fbx.make_trajectory_buffer( max_length_time_axis=config['BUFFER_SIZE']//config['NUM_ENVS'], min_length_time_axis=config['BUFFER_BATCH_SIZE'], @@ -395,8 +395,8 @@ def linear_schedule(count): ) # target network params - target_network_params_agent = jax.tree_map(lambda x: jnp.copy(x), train_state_agent.params) - target_network_params_mixer = jax.tree_map(lambda x: jnp.copy(x), train_state_mixer.params) + target_network_params_agent = jax.tree.map(lambda x: jnp.copy(x), train_state_agent.params) + target_network_params_mixer = jax.tree.map(lambda x: jnp.copy(x), train_state_mixer.params) # INIT EXPLORATION STRATEGY explorer = EpsilonGreedy( @@ -449,12 +449,12 @@ def _env_step(step_state, unused): # SELECT ACTION # add a dummy time_step dimension to the agent input obs_ = {a:last_obs[a] for a in env.agents} # ensure to not pass the global state (obs["__all__"]) to the network - obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_) - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones) + obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_) + dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones) # get the q_values from the agent network hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_) # remove the dummy time_step dimension and index qs by the valid actions of each agent - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions) + valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions) # explore with epsilon greedy_exploration actions = explorer.choose_actions(valid_q_vals, t, key_a) @@ -488,7 +488,7 @@ def _env_step(step_state, unused): ) # BUFFER UPDATE: save the collected trajectory in the buffer - buffer_traj_batch = jax.tree_util.tree_map( + buffer_traj_batch = jax.tree.map( lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim traj_batch ) # (num_envs, 1, time_steps, ...) 
@@ -512,22 +512,22 @@ def _loss_fn(params_agent, params_mixer, target_network_params_agent, target_net _, target_q_vals = homogeneous_pass(target_network_params_agent, init_hstate, obs_, learn_traj.dones) # get the q_vals of the taken actions (with exploration) for each agent - chosen_action_qvals = jax.tree_map( + chosen_action_qvals = jax.tree.map( lambda q, u: q_of_action(q, u)[:-1], # avoid last timestep q_vals, learn_traj.actions ) # get the target q value of the greedy actions for each agent - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions) - target_max_qvals = jax.tree_map( + valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions) + target_max_qvals = jax.tree.map( lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep target_q_vals, jax.lax.stop_gradient(valid_q_vals) ) # get the max_filters - max_filters = jax.tree_map( + max_filters = jax.tree.map( lambda q, u: get_max_filter(q, u)[:-1], q_vals, learn_traj.actions @@ -588,7 +588,7 @@ def _td_lambda_target(ret, values): # sample a batched trajectory from the buffer and set the time step dim in first axis rng, _rng = jax.random.split(rng) learn_traj = buffer.sample(buffer_state, _rng).experience # (batch_size, 1, max_time_steps, ...) - learn_traj = jax.tree_map( + learn_traj = jax.tree.map( lambda x: jnp.swapaxes(x[:, 0], 0, 1), # remove the dummy sequence dim (1) and swap batch and temporal dims learn_traj ) # (max_time_steps, batch_size, ...) @@ -617,13 +617,13 @@ def _td_lambda_target(ret, values): # update the target network if necessary target_network_params_agent = jax.lax.cond( time_state['updates'] % config['TARGET_UPDATE_INTERVAL'] == 0, - lambda _: jax.tree_map(lambda x: jnp.copy(x), train_state_agent.params), + lambda _: jax.tree.map(lambda x: jnp.copy(x), train_state_agent.params), lambda _: target_network_params_agent, operand=None ) target_network_params_mixer = jax.lax.cond( time_state['updates'] % config['TARGET_UPDATE_INTERVAL'] == 0, - lambda _: jax.tree_map(lambda x: jnp.copy(x), train_state_mixer.params), + lambda _: jax.tree.map(lambda x: jnp.copy(x), train_state_mixer.params), lambda _: target_network_params_mixer, operand=None ) @@ -642,7 +642,7 @@ def _td_lambda_target(ret, values): 'timesteps': time_state['timesteps']*config['NUM_ENVS'], 'updates' : time_state['updates'], 'loss': loss, - 'rewards': jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards), + 'rewards': jax.tree.map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards), 'eps': explorer.get_epsilon(time_state['timesteps']) } metrics['test_metrics'] = test_metrics # add the test metrics dictionary @@ -688,10 +688,10 @@ def _greedy_env_step(step_state, unused): params, env_state, last_obs, last_dones, hstate, rng = step_state rng, key_s = jax.random.split(rng) obs_ = {a:last_obs[a] for a in env.agents} - obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_) - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones) + obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_) + dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones) hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_) - actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions) + actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions) obs, env_state, 
rewards, dones, infos = test_env.batch_step(key_s, env_state, actions) step_state = (params, env_state, obs, dones, hstate, rng) return step_state, (rewards, dones, infos) @@ -720,8 +720,8 @@ def first_episode_returns(rewards, dones): first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False) return jnp.where(first_episode_mask, rewards, 0.).sum() all_dones = dones['__all__'] - first_returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards) - first_infos = jax.tree_map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos) + first_returns = jax.tree.map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards) + first_infos = jax.tree.map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos) metrics = { 'test_returns': first_returns['__all__'],# episode returns **{'test_'+k:v for k,v in first_infos.items()} @@ -811,7 +811,7 @@ def save_params(params: Dict, filename: Union[str, os.PathLike]) -> None: save_file(flattened_dict, filename) model_state = outs['runner_state'][0] - params = jax.tree_map(lambda x: x[0], model_state.params) # save only params of the firt run + params = jax.tree.map(lambda x: x[0], model_state.params) # save only params of the firt run save_dir = os.path.join(config['SAVE_PATH'], env_name) os.makedirs(save_dir, exist_ok=True) save_params(params, f'{save_dir}/{alg_name}.safetensors') diff --git a/baselines/QLearning/transf_qmix.py b/baselines/QLearning/transf_qmix.py index a34dda33..97bc9710 100644 --- a/baselines/QLearning/transf_qmix.py +++ b/baselines/QLearning/transf_qmix.py @@ -370,7 +370,7 @@ def explore(q, eps, key): eps = self.get_epsilon(t) keys = dict(zip(q_vals.keys(), jax.random.split(rng, len(q_vals)))) # get a key for each agent - chosen_actions = jax.tree_map(lambda q, k: explore(q, eps, k), q_vals, keys) + chosen_actions = jax.tree.map(lambda q, k: explore(q, eps, k), q_vals, keys) return chosen_actions @@ -384,7 +384,7 @@ class Transition(NamedTuple): def tree_mean(tree): return jnp.array( - jax.tree_util.tree_leaves(jax.tree_map(lambda x: x.mean(), tree)) + jax.tree_leaves(jax.tree.map(lambda x: x.mean(), tree)) ).mean() @@ -415,7 +415,7 @@ def _env_sample_step(env_state, unused): _, sample_traj = jax.lax.scan( _env_sample_step, env_state, None, config["NUM_STEPS"] ) - sample_traj_unbatched = jax.tree_map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim + sample_traj_unbatched = jax.tree.map(lambda x: x[:, 0], sample_traj) # remove the NUM_ENV dim buffer = fbx.make_trajectory_buffer( max_length_time_axis=config['BUFFER_SIZE']//config['NUM_ENVS'], min_length_time_axis=config['BUFFER_BATCH_SIZE'], @@ -487,8 +487,8 @@ def _env_sample_step(env_state, unused): network_stats = {'agent':agent_params['batch_stats'],'mixer':mixer_params['batch_stats']} # print number of params - agent_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['agent'])) - mixer_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['mixer'])) + agent_params = sum(x.size for x in jax.tree_leaves(network_params['agent'])) + mixer_params = sum(x.size for x in jax.tree_leaves(network_params['mixer'])) jax.debug.print("Number of agent params: {x}", x=agent_params) jax.debug.print("Number of mixer params: {x}", x=mixer_params) @@ -532,7 +532,7 @@ class TrainState_(TrainState): tx=tx, ) # target network params - copy_tree = lambda tree: jax.tree_map(lambda x: jnp.copy(x), tree) + copy_tree = lambda tree: 
jax.tree.map(lambda x: jnp.copy(x), tree) target_network_state = {'params':copy_tree(train_state.params), 'batch_stats':copy_tree(train_state.batch_stats)} # INIT EXPLORATION STRATEGY @@ -604,19 +604,19 @@ def _env_step(step_state, unused): # SELECT ACTION # add a dummy time_step dimension to the agent input obs_ = {a:last_obs[a] for a in env.agents} # ensure to not pass the global state (obs["__all__"]) to the network - obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_) - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones) + obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_) + dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones) # get the q_values from the agent netwoek _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False) # remove the dummy time_step dimension and index qs by the valid actions of each agent - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions) + valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions) # explore with epsilon greedy_exploration actions = explorer.choose_actions(valid_q_vals, t, key_a) # STEP ENV obs, env_state, rewards, dones, infos = wrapped_env.batch_step(key_s, env_state, actions) # reward scaling - rewards = jax.tree_map(lambda x:config.get("REW_SCALE", 1)*x, rewards) + rewards = jax.tree.map(lambda x:config.get("REW_SCALE", 1)*x, rewards) transition = Transition(last_obs, actions, rewards, dones, infos) step_state = (env_state, obs, dones, hstate, rng, t+1) @@ -641,7 +641,7 @@ def _env_step(step_state, unused): ) # BUFFER UPDATE: save the collected trajectory in the buffer - buffer_traj_batch = jax.tree_util.tree_map( + buffer_traj_batch = jax.tree.map( lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim traj_batch ) # (num_envs, 1, time_steps, ...) @@ -661,7 +661,7 @@ def _network_update(carry, unused): # sample a batched trajectory from the buffer and set the time step dim in first axis rng, _rng = jax.random.split(rng) learn_traj = buffer.sample(buffer_state, _rng).experience # (batch_size, 1, max_time_steps, ...) - learn_traj = jax.tree_map( + learn_traj = jax.tree.map( lambda x: jnp.swapaxes(x[:, 0], 0, 1), # remove the dummy sequence dim (1) and swap batch and temporal dims learn_traj ) # (max_time_steps, batch_size, ...) 
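The `learn_traj` reshaping in the hunk above follows the same recipe used across the Q-learning baselines: a single `jax.tree.map` call drops the dummy sequence dimension added when the trajectory was inserted into the buffer and swaps the batch and time axes of every leaf, so the result can be scanned over time. A small stand-alone sketch of that transform, with illustrative leaf names and shapes:

```python
import jax
import jax.numpy as jnp

# Stand-in for buffer.sample(...).experience: leaves shaped
# (batch_size, 1, time_steps, ...), where axis 1 is a dummy sequence dim.
batch_size, time_steps, obs_dim = 32, 64, 10
sampled = {
    "obs": jnp.zeros((batch_size, 1, time_steps, obs_dim)),
    "dones": jnp.zeros((batch_size, 1, time_steps), dtype=bool),
}

# Drop the dummy sequence dim and put time first, giving the
# (time_steps, batch_size, ...) layout the scanned learning step expects.
learn_traj = jax.tree.map(lambda x: jnp.swapaxes(x[:, 0], 0, 1), sampled)

print(learn_traj["obs"].shape)    # (64, 32, 10)
print(learn_traj["dones"].shape)  # (64, 32)
```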
@@ -695,15 +695,15 @@ def _loss_fn(params, init_hs, learn_traj): hs_target_agents = jax.lax.stop_gradient(hs_target_agents) # get the q_vals of the taken actions (with exploration) for each agent - chosen_action_qvals = jax.tree_map( + chosen_action_qvals = jax.tree.map( lambda q, u: q_of_action(q, u)[:-1], # avoid last timestep q_vals, learn_traj.actions ) # get the target q value of the greedy actions for each agent - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions) - target_max_qvals = jax.tree_map( + valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions) + target_max_qvals = jax.tree.map( lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep target_q_vals, jax.lax.stop_gradient(valid_q_vals) @@ -864,10 +864,10 @@ def _greedy_env_step(step_state, unused): env_state, last_obs, last_dones, hstate, rng = step_state rng, key_s = jax.random.split(rng) obs_ = {a:last_obs[a] for a in env.agents} - obs_ = jax.tree_map(lambda x: x[np.newaxis, :], obs_) - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], last_dones) + obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_) + dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones) _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False) - actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions) + actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions) obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions) step_state = (env_state, obs, dones, hstate, rng) return step_state, (rewards, dones, infos) @@ -892,8 +892,8 @@ def first_episode_returns(rewards, dones): first_episode_mask = jnp.where(jnp.arange(dones.size) <= first_done, True, False) return jnp.where(first_episode_mask, rewards, 0.).sum() all_dones = dones['__all__'] - first_returns = jax.tree_map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards) - first_infos = jax.tree_map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos) + first_returns = jax.tree.map(lambda r: jax.vmap(first_episode_returns, in_axes=1)(r, all_dones), rewards) + first_infos = jax.tree.map(lambda i: jax.vmap(first_episode_returns, in_axes=1)(i[..., 0], all_dones), infos) metrics = { 'test_returns': first_returns['__all__'],# episode returns **{'test_'+k:v for k,v in first_infos.items()} @@ -989,7 +989,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/vdn_cnn_overcooked.py b/baselines/QLearning/vdn_cnn_overcooked.py index 96e09511..7efe3832 100644 --- a/baselines/QLearning/vdn_cnn_overcooked.py +++ b/baselines/QLearning/vdn_cnn_overcooked.py @@ -224,7 +224,7 @@ def create_agent(rng): rewards=_rewards, dones=_dones, ) - _tiemstep_unbatched = jax.tree_map( + _tiemstep_unbatched = jax.tree.map( lambda x: x[0], _timestep ) # remove the NUM_ENV dim buffer_state = buffer.init(_tiemstep_unbatched) @@ -260,7 +260,7 @@ def _step_env(carry, _): # add shaped reward shaped_reward = infos.pop("shaped_reward") shaped_reward["__all__"] = 
batchify(shaped_reward).sum(axis=0) - rewards = jax.tree_map( + rewards = jax.tree.map( lambda x, y: x + y * rew_shaping_anneal(train_state.timesteps), rewards, shaped_reward, @@ -291,7 +291,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - timesteps = jax.tree_util.tree_map( + timesteps = jax.tree.map( lambda x: x.reshape(-1, *x.shape[2:]), timesteps ) # (num_envs*num_steps, ...) buffer_state = buffer.add(buffer_state, timesteps) @@ -385,7 +385,7 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -545,7 +545,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/vdn_ff.py b/baselines/QLearning/vdn_ff.py index a789cbf3..2494c8b7 100644 --- a/baselines/QLearning/vdn_ff.py +++ b/baselines/QLearning/vdn_ff.py @@ -182,7 +182,7 @@ def create_agent(rng): rewards=_rewards, dones=_dones, ) - _tiemstep_unbatched = jax.tree_map( + _tiemstep_unbatched = jax.tree.map( lambda x: x[0], _timestep ) # remove the NUM_ENV dim buffer_state = buffer.init(_tiemstep_unbatched) @@ -240,7 +240,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - timesteps = jax.tree_util.tree_map( + timesteps = jax.tree.map( lambda x: x.reshape(-1, *x.shape[2:]), timesteps ) # (num_envs*num_steps, ...) buffer_state = buffer.add(buffer_state, timesteps) @@ -336,7 +336,7 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -403,7 +403,7 @@ def _greedy_env_step(step_state, unused): step_state, (rewards, dones, infos) = jax.lax.scan( _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"] ) - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -498,7 +498,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/baselines/QLearning/vdn_rnn.py b/baselines/QLearning/vdn_rnn.py index 6eba8381..59a0a3ef 100644 --- a/baselines/QLearning/vdn_rnn.py +++ b/baselines/QLearning/vdn_rnn.py @@ -233,7 +233,7 @@ def _env_sample_step(env_state, unused): _, sample_traj = jax.lax.scan( _env_sample_step, _env_state, None, config["NUM_STEPS"] ) - sample_traj_unbatched = jax.tree_map( + sample_traj_unbatched = jax.tree.map( lambda x: x[:, 0], sample_traj ) # remove the NUM_ENV dim buffer = fbx.make_trajectory_buffer( @@ -288,7 +288,7 @@ def _step_env(carry, _): timestep = Timestep( obs=last_obs, actions=actions, - rewards=jax.tree_map(lambda x:config.get("REW_SCALE", 1)*x, rewards), + rewards=jax.tree.map(lambda x:config.get("REW_SCALE", 1)*x, rewards), dones=dones, avail_actions=avail_actions, ) @@ -319,7 +319,7 @@ def _step_env(carry, _): ) # update timesteps count # BUFFER UPDATE - buffer_traj_batch = 
jax.tree_util.tree_map( + buffer_traj_batch = jax.tree.map( lambda x: jnp.swapaxes(x, 0, 1)[ :, np.newaxis ], # put the batch dim first and add a dummy sequence dim @@ -333,7 +333,7 @@ def _learn_phase(carry, _): train_state, rng = carry rng, _rng = jax.random.split(rng) minibatch = buffer.sample(buffer_state, _rng).experience - minibatch = jax.tree_map( + minibatch = jax.tree.map( lambda x: jnp.swapaxes( x[:, 0], 0, 1 ), # remove the dummy sequence dim (1) and swap batch and temporal dims @@ -456,7 +456,7 @@ def _loss_fn(params): "loss": loss.mean(), "qvals": qvals.mean(), } - metrics.update(jax.tree_map(lambda x: x.mean(), infos)) + metrics.update(jax.tree.map(lambda x: x.mean(), infos)) if config.get("TEST_DURING_TRAINING", True): rng, _rng = jax.random.split(rng) @@ -534,7 +534,7 @@ def _greedy_env_step(step_state, unused): step_state, (rewards, dones, infos) = jax.lax.scan( _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"] ) - metrics = jax.tree_map( + metrics = jax.tree.map( lambda x: jnp.nanmean( jnp.where( infos["returned_episode"], @@ -629,7 +629,7 @@ def single_run(config): ) for i, rng in enumerate(rngs): - params = jax.tree_map(lambda x: x[i], model_state.params) + params = jax.tree.map(lambda x: x[i], model_state.params) save_path = os.path.join( save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors', diff --git a/jaxmarl/__init__.py b/jaxmarl/__init__.py index 962e720b..e29b9baf 100644 --- a/jaxmarl/__init__.py +++ b/jaxmarl/__init__.py @@ -1,4 +1,4 @@ from .registration import make, registered_envs __all__ = ["make", "registered_envs"] -__version__ = "0.0.6" +__version__ = "0.0.7" diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py index bc47d88c..d2a4d68d 100644 --- a/jaxmarl/environments/hanabi/hanabi.py +++ b/jaxmarl/environments/hanabi/hanabi.py @@ -532,7 +532,7 @@ def _binarize_discard_pile(self, discard_pile: chex.Array): """Binarize the discard pile to reduce dimensionality.""" def binarize_ranks(n_ranks): - tree = jax.tree_util.tree_map( + tree = jax.tree.map( lambda n_rank_present, max_ranks: jnp.where( jnp.arange(max_ranks) >= n_rank_present, jnp.zeros(max_ranks), diff --git a/jaxmarl/environments/hanabi/pretrained/obl_r2d2_agent.py b/jaxmarl/environments/hanabi/pretrained/obl_r2d2_agent.py index 278e5eee..1e108f4e 100644 --- a/jaxmarl/environments/hanabi/pretrained/obl_r2d2_agent.py +++ b/jaxmarl/environments/hanabi/pretrained/obl_r2d2_agent.py @@ -26,7 +26,7 @@ def __call__(self, carry, inputs): new_cs = [] for l in range(self.num_layers): new_carry, y = nn.LSTMCell(self.features, name=f"l{l}")( - jax.tree_map(lambda x: x[l], carry), inputs + jax.tree.map(lambda x: x[l], carry), inputs ) new_cs.append(new_carry[0]) new_hs.append(new_carry[1]) diff --git a/jaxmarl/environments/jaxnav/jaxnav_env.py b/jaxmarl/environments/jaxnav/jaxnav_env.py index bbcb6232..a48bf947 100644 --- a/jaxmarl/environments/jaxnav/jaxnav_env.py +++ b/jaxmarl/environments/jaxnav/jaxnav_env.py @@ -494,7 +494,7 @@ def update_state(self, pos: chex.Array, theta: float, speed: chex.Array, action: theta = wrap(theta + w*self.dt) out = (pos, theta, jnp.array([v, w], dtype=jnp.float32)) - return jax.tree_map(lambda x, y: jax.lax.select(done, x, y), out_done, out) + return jax.tree.map(lambda x, y: jax.lax.select(done, x, y), out_done, out) @partial(jax.vmap, in_axes=(None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) def compute_reward( @@ -598,10 +598,10 @@ def step_plr( key, state, actions ) obs_re, state_re = self.reset_to_level(level) # 
todo maybe should be set state depending on PLR code - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jax.lax.select(state_st.ep_done, x, y), state_re, state_st ) - obs = jax.tree_map( + obs = jax.tree.map( lambda x, y: jax.lax.select(state_st.ep_done, x, y), obs_re, obs_st ) #obs = jax.lax.select(state_st.ep_done, obs_re, obs_st) diff --git a/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py b/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py index 0c4fb6c1..e5633532 100644 --- a/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py +++ b/jaxmarl/environments/jaxnav/jaxnav_ued_utils.py @@ -40,11 +40,11 @@ def _apply(rng, state): is_flip_wall = jnp.equal(mutation, Mutations.FLIP_WALL.value) mutated_state = flip_wall(arng, map, state) - next_state = jax.tree_map(lambda x,y: jax.lax.select(is_flip_wall, x, y), mutated_state, state) + next_state = jax.tree.map(lambda x,y: jax.lax.select(is_flip_wall, x, y), mutated_state, state) is_move_goal = jnp.equal(mutation, Mutations.MOVE_GOAL.value) mutated_state = move_goal(brng, map, state) - next_state = jax.tree_map(lambda x,y: jax.lax.select(is_move_goal, x, y), mutated_state, next_state) + next_state = jax.tree.map(lambda x,y: jax.lax.select(is_move_goal, x, y), mutated_state, next_state) return next_state diff --git a/jaxmarl/environments/jaxnav/maps/grid_map.py b/jaxmarl/environments/jaxnav/maps/grid_map.py index 0e12fccb..9d3e35f1 100644 --- a/jaxmarl/environments/jaxnav/maps/grid_map.py +++ b/jaxmarl/environments/jaxnav/maps/grid_map.py @@ -981,7 +981,7 @@ def __init__(self, with open(filepath, "rb") as f: tc = pickle.load(f) print('tc c', tc) - test_cases = jax.tree_map(lambda x, y: jnp.concatenate((x, y), axis=0), test_cases, tc) + test_cases = jax.tree.map(lambda x, y: jnp.concatenate((x, y), axis=0), test_cases, tc) self.test_cases = test_cases self.num_test_cases = test_cases[0].shape[0] print('test cases', test_cases) @@ -993,7 +993,7 @@ def __init__(self, def sample_scenario(self, key): print('-- sampling scenarios -- ') idx = jax.random.randint(key, (1,), minval=0, maxval=self.num_test_cases)[0] - tc = jax.tree_map(lambda x: x[idx], self.test_cases) + tc = jax.tree.map(lambda x: x[idx], self.test_cases) print('tc ', tc) map_data = tc[0] print('map data', map_data.shape) diff --git a/jaxmarl/environments/multi_agent_env.py b/jaxmarl/environments/multi_agent_env.py index e7bc19be..75ca3c2f 100644 --- a/jaxmarl/environments/multi_agent_env.py +++ b/jaxmarl/environments/multi_agent_env.py @@ -59,10 +59,10 @@ def step( obs_re = self.get_obs(states_re) # Auto-reset environment based on termination - states = jax.tree_map( + states = jax.tree.map( lambda x, y: jax.lax.select(dones["__all__"], x, y), states_re, states_st ) - obs = jax.tree_map( + obs = jax.tree.map( lambda x, y: jax.lax.select(dones["__all__"], x, y), obs_re, obs_st ) return obs, states, rewards, dones, infos diff --git a/jaxmarl/environments/smax/heuristic_enemy_smax_env.py b/jaxmarl/environments/smax/heuristic_enemy_smax_env.py index c0342865..de4a54f3 100644 --- a/jaxmarl/environments/smax/heuristic_enemy_smax_env.py +++ b/jaxmarl/environments/smax/heuristic_enemy_smax_env.py @@ -186,7 +186,7 @@ def __init__(self, enemy_shoots=True, attack_mode="closest", **env_kwargs): ) def get_enemy_policy_initial_state(self, key): - return jax.tree_map( + return jax.tree.map( lambda *xs: jnp.stack(xs), *([get_heuristic_policy_initial_state()] * self.num_enemies), ) diff --git a/jaxmarl/environments/smax/speed.py b/jaxmarl/environments/smax/speed.py index c4dd2295..6d46fa34 100644 --- 
a/jaxmarl/environments/smax/speed.py +++ b/jaxmarl/environments/smax/speed.py @@ -142,7 +142,7 @@ def env_step(runner_state, unused): obsv, env_state, _, _, info = jax.vmap(env.step)( rng_step, env_state, env_act ) - info = jax.tree_map(lambda x: x.reshape((config["NUM_ACTORS"])), info) + info = jax.tree.map(lambda x: x.reshape((config["NUM_ACTORS"])), info) runner_state = (params, env_state, obsv, rng) return runner_state, None diff --git a/jaxmarl/environments/storm/storm.py b/jaxmarl/environments/storm/storm.py index ff676067..fdf36c97 100644 --- a/jaxmarl/environments/storm/storm.py +++ b/jaxmarl/environments/storm/storm.py @@ -1569,7 +1569,7 @@ def coin_matcher(p: jnp.ndarray) -> jnp.ndarray: # soft reset for anyone who has just been mobilised state_sft_re = _soft_reset_state(key, state, old_freeze) - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jax.lax.select( jnp.any( jnp.logical_and( @@ -1604,7 +1604,7 @@ def coin_matcher(p: jnp.ndarray) -> jnp.ndarray: state_re = _reset_state(key) state_re = state_re.replace(outer_t=outer_t + 1) - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jnp.where(reset_inner, x, y), state_re, state_nxt, diff --git a/jaxmarl/environments/storm/storm_2p.py b/jaxmarl/environments/storm/storm_2p.py index c2615156..e7ccd227 100644 --- a/jaxmarl/environments/storm/storm_2p.py +++ b/jaxmarl/environments/storm/storm_2p.py @@ -727,7 +727,7 @@ def _step( state.freeze > 0, state.freeze - 1, state.freeze )) state_sft_re = _soft_reset_state(key, state) - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jnp.where(state.freeze == 0, x, y), state_sft_re, state, @@ -755,7 +755,7 @@ def _step( # if inner episode is done, return start state for next game state_re = _reset_state(key) state_re = state_re.replace(outer_t=outer_t + 1) - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jax.lax.select(reset_inner, x, y), state_re, state_nxt, diff --git a/jaxmarl/environments/storm/storm_env.py b/jaxmarl/environments/storm/storm_env.py index 26a93325..79b04f31 100644 --- a/jaxmarl/environments/storm/storm_env.py +++ b/jaxmarl/environments/storm/storm_env.py @@ -891,7 +891,7 @@ def update_timers(coop_coin_timer, defect_coin_timer, new_coop_coin_timer, new_d # # # if inner episode is done, return start state for next game state_re = _reset_state(key) state_re = state_re.replace(outer_t=outer_t + 1) - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jax.lax.select(reset_inner, x, y), state_re, state_nxt, diff --git a/jaxmarl/gridworld/env.py b/jaxmarl/gridworld/env.py index b4750730..7f260414 100644 --- a/jaxmarl/gridworld/env.py +++ b/jaxmarl/gridworld/env.py @@ -46,10 +46,10 @@ def step( obs_re, state_re = self.reset_env(key_reset) # Auto-reset environment based on termination - state = jax.tree_map( + state = jax.tree.map( lambda x, y: jax.lax.select(done, x, y), state_re, state_st ) - obs = jax.tree_map( + obs = jax.tree.map( lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st ) diff --git a/jaxmarl/gridworld/tabular_q.py b/jaxmarl/gridworld/tabular_q.py index be689ee3..24cf725d 100644 --- a/jaxmarl/gridworld/tabular_q.py +++ b/jaxmarl/gridworld/tabular_q.py @@ -74,7 +74,7 @@ def flatten_obs(obs): return jnp.concatenate((img, jnp.expand_dims(obs['agent_dir'], 0)), axis=-1) def extract_obs(obs, idx): - return jax.tree_map(lambda x: x[idx], obs) + return jax.tree.map(lambda x: x[idx], obs) def sample_random_action(key): return jax.random.randint(key, shape=(), minval=0, maxval=n_actions) diff --git 
a/jaxmarl/wrappers/baselines.py b/jaxmarl/wrappers/baselines.py index c94f35bf..c2779aa3 100644 --- a/jaxmarl/wrappers/baselines.py +++ b/jaxmarl/wrappers/baselines.py @@ -115,7 +115,7 @@ def step( obs, env_state, reward, done, info = self._env.step( key, state.env_state, action ) - rewardlog = jax.tree_map(lambda x: x*self._env.num_agents, reward) # As per on-policy codebase + rewardlog = jax.tree.map(lambda x: x*self._env.num_agents, reward) # As per on-policy codebase ep_done = done["__all__"] new_episode_return = state.episode_returns + self._batchify_floats(rewardlog) new_episode_length = state.episode_lengths + 1 @@ -285,7 +285,7 @@ def batch_step(self, key, states, actions): def wrapped_reset(self, key): obs_, state = self._env.reset(key) if self.preprocess_obs: - obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot) + obs = jax.tree.map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot) else: obs = obs_ obs["__all__"] = self.global_state(obs_, state) @@ -295,8 +295,8 @@ def wrapped_reset(self, key): def wrapped_step(self, key, state, actions): obs_, state, reward, done, infos = self._env.step(key, state, actions) if self.preprocess_obs: - obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot) - obs = jax.tree_util.tree_map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents + obs = jax.tree.map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot) + obs = jax.tree.map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents else: obs = obs_ obs["__all__"] = self.global_state(obs_, state) diff --git a/pyproject.toml b/pyproject.toml index a7c5a865..fe02f950 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,8 +31,8 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", ] dependencies = [ - "jax>=0.4.16.0,<=0.4.25", - "jaxlib>=0.4.16.0,<=0.4.25", + "jax", + "jaxlib", "flax", "safetensors", "chex", diff --git a/tests/brax/test_brax_rand_acts.py b/tests/brax/test_brax_rand_acts.py index 43671532..71059dda 100644 --- a/tests/brax/test_brax_rand_acts.py +++ b/tests/brax/test_brax_rand_acts.py @@ -23,3 +23,7 @@ def test_random_rollout(): rng_act = jax.random.split(rng_act, env.num_agents) actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)} _, state, _, _, _ = env.step(rng, state, actions) + + +if __name__ == "__main__": + test_random_rollout() \ No newline at end of file diff --git a/tests/mpe/mpe_policy_transfer.py b/tests/mpe/mpe_policy_transfer.py index d7dda941..56c1ef8b 100644 --- a/tests/mpe/mpe_policy_transfer.py +++ b/tests/mpe/mpe_policy_transfer.py @@ -118,19 +118,19 @@ def _preprocess_obs(arr, extra_features): def obs_to_act(obs, dones, params=params): - obs = jax.tree_util.tree_map(_preprocess_obs, obs, agents_one_hot) + obs = jax.tree.map(_preprocess_obs, obs, agents_one_hot) # add a dummy temporal dimension - obs_ = jax.tree_map(lambda x: x[np.newaxis, np.newaxis, :], obs) # add also a dummy batch dim to obs - dones_ = jax.tree_map(lambda x: x[np.newaxis, :], dones) + obs_ = jax.tree.map(lambda x: x[np.newaxis, np.newaxis, :], obs) # add also a dummy batch dim to obs + dones_ = jax.tree.map(lambda x: x[np.newaxis, :], dones) # pass in one with homogeneous pass hstate = 
ScannedRNN.initialize_carry(agent_hidden_dim, len(env_jax.agents)) hstate, q_vals = agent.homogeneous_pass(params, hstate, obs_, dones_) # get actions from q vals - valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, valid_actions) - actions = jax.tree_util.tree_map(lambda q: jnp.argmax(q, axis=-1).squeeze(0), valid_q_vals) + valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, valid_actions) + actions = jax.tree.map(lambda q: jnp.argmax(q, axis=-1).squeeze(0), valid_q_vals) return actions @@ -184,7 +184,7 @@ def obs_to_act(obs, dones, params=params): acts = obs_to_act(obs_jax, done_jax) #print('acts', acts) obs_jax, state, rew_jax, done_jax, _ = env_jax.step(key_s, state, acts) - done_jax = jax.tree_map(lambda x: x[None], done_jax) + done_jax = jax.tree.map(lambda x: x[None], done_jax) rew_batch = np.array([rew_jax[a] for a in env_jax.agents]) rew_tallys_jax[j] = rew_batch diff --git a/tests/test_jaxmarl_api.py b/tests/test_jaxmarl_api.py index c19d8105..22c6b577 100644 --- a/tests/test_jaxmarl_api.py +++ b/tests/test_jaxmarl_api.py @@ -35,4 +35,4 @@ def _test_leaf(x, y, outcome=True): actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)} _, next_state, _, dones, _ = env.step(rng, state1, actions, reset_state=state2) assert dones["__all__"] - jax.tree_map(_test_leaf, state2, next_state) + jax.tree.map(_test_leaf, state2, next_state)
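
Note on the recurring change above: newer JAX releases deprecate (and eventually remove) the `jax.tree_map` and `jax.tree_util.tree_map` aliases in favour of the `jax.tree.map` namespace, which is what most hunks in this patch mechanically apply; presumably this is also why the `jax`/`jaxlib` upper pins are dropped from pyproject.toml. Below is a minimal sketch, not part of the patch, showing that the two spellings are interchangeable; the variable names and values are illustrative stand-ins for the reward-shaping and auto-reset call sites touched above.

# Minimal, self-contained sketch of the jax.tree_map -> jax.tree.map migration.
# All data below is made up for illustration; only the API usage mirrors the patch.
import jax
import jax.numpy as jnp

# Per-agent pytrees, loosely modelled on the Overcooked reward-shaping call site.
rewards = {"agent_0": jnp.array([1.0, 0.0]), "agent_1": jnp.array([0.5, 0.5])}
shaped  = {"agent_0": jnp.array([0.1, 0.2]), "agent_1": jnp.array([0.3, 0.4])}
anneal  = 0.5  # stand-in for rew_shaping_anneal(current_timestep)

# Old spelling (deprecated):  jax.tree_map(lambda r, s: r + s * anneal, rewards, shaped)
# New spelling, same result:
combined = jax.tree.map(lambda r, s: r + s * anneal, rewards, shaped)

# jax.tree.map also replaces jax.tree_util.tree_map call sites, e.g. the
# auto-reset pattern in MultiAgentEnv.step (scalar toy state used here):
state_re = {"t": jnp.array(0)}  # freshly reset state
state_st = {"t": jnp.array(7)}  # stepped state
done = jnp.array(True)
state = jax.tree.map(lambda x, y: jax.lax.select(done, x, y), state_re, state_st)

With these toy inputs, `combined["agent_0"]` evaluates to [1.05, 0.1] and `state["t"]` to 0, exactly what the deprecated aliases produced; the migration is purely a rename, so no numerical behaviour in the baselines should change.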