diff --git a/baselines/IPPO/ippo_cnn_overcooked.py b/baselines/IPPO/ippo_cnn_overcooked.py
index 60cbb76a..6bde1f49 100644
--- a/baselines/IPPO/ippo_cnn_overcooked.py
+++ b/baselines/IPPO/ippo_cnn_overcooked.py
@@ -345,13 +345,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_ff_hanabi.py b/baselines/IPPO/ippo_ff_hanabi.py
index 1f798c03..ea608bb3 100644
--- a/baselines/IPPO/ippo_ff_hanabi.py
+++ b/baselines/IPPO/ippo_ff_hanabi.py
@@ -258,11 +258,11 @@ def _loss_fn(params, traj_batch, gae, targets):
                 batch = (traj_batch, advantages.squeeze(), targets.squeeze())
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/IPPO/ippo_ff_mabrax.py b/baselines/IPPO/ippo_ff_mabrax.py
index 24fbdb34..e731298d 100644
--- a/baselines/IPPO/ippo_ff_mabrax.py
+++ b/baselines/IPPO/ippo_ff_mabrax.py
@@ -258,13 +258,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_ff_mpe.py b/baselines/IPPO/ippo_ff_mpe.py
index 6df56c30..379c0c85 100644
--- a/baselines/IPPO/ippo_ff_mpe.py
+++ b/baselines/IPPO/ippo_ff_mpe.py
@@ -255,13 +255,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_ff_mpe_facmac.py b/baselines/IPPO/ippo_ff_mpe_facmac.py
index 0b37f7f5..5cfc48a0 100644
--- a/baselines/IPPO/ippo_ff_mpe_facmac.py
+++ b/baselines/IPPO/ippo_ff_mpe_facmac.py
@@ -252,13 +252,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_ff_overcooked.py b/baselines/IPPO/ippo_ff_overcooked.py
index d0ae5afb..33bef493 100644
--- a/baselines/IPPO/ippo_ff_overcooked.py
+++ b/baselines/IPPO/ippo_ff_overcooked.py
@@ -318,13 +318,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_ff_switch_riddle.py b/baselines/IPPO/ippo_ff_switch_riddle.py
index 9804b61f..2a4fb97c 100644
--- a/baselines/IPPO/ippo_ff_switch_riddle.py
+++ b/baselines/IPPO/ippo_ff_switch_riddle.py
@@ -247,13 +247,13 @@ def _loss_fn(params, traj_batch, gae, targets):
                 ), "batch size must be equal to number of steps * number of actors"
                 permutation = jax.random.permutation(_rng, batch_size)
                 batch = (traj_batch, advantages, targets)
-                batch = jax.tree_util.tree_map(
+                batch = jax.tree.map(
                     lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                 )
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=0), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.reshape(
                         x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                     ),
diff --git a/baselines/IPPO/ippo_rnn_hanabi.py b/baselines/IPPO/ippo_rnn_hanabi.py
index ec9cbb12..d59f0197 100644
--- a/baselines/IPPO/ippo_rnn_hanabi.py
+++ b/baselines/IPPO/ippo_rnn_hanabi.py
@@ -312,11 +312,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                 batch = (init_hstate, traj_batch, advantages.squeeze(), targets.squeeze())
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/IPPO/ippo_rnn_mpe.py b/baselines/IPPO/ippo_rnn_mpe.py
index fba387a0..a8ea020a 100644
--- a/baselines/IPPO/ippo_rnn_mpe.py
+++ b/baselines/IPPO/ippo_rnn_mpe.py
@@ -334,11 +334,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/IPPO/ippo_rnn_smax.py b/baselines/IPPO/ippo_rnn_smax.py
index eb1180cc..14ffd1b8 100644
--- a/baselines/IPPO/ippo_rnn_smax.py
+++ b/baselines/IPPO/ippo_rnn_smax.py
@@ -346,11 +346,11 @@ def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/MAPPO/mappo_ff_hanabi.py b/baselines/MAPPO/mappo_ff_hanabi.py
index 8a6e68ac..8c79d90f 100644
--- a/baselines/MAPPO/mappo_ff_hanabi.py
+++ b/baselines/MAPPO/mappo_ff_hanabi.py
@@ -406,11 +406,11 @@ def _critic_loss_fn(critic_params, traj_batch, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/MAPPO/mappo_rnn_hanabi.py b/baselines/MAPPO/mappo_rnn_hanabi.py
index ed38fe4b..63b106f4 100644
--- a/baselines/MAPPO/mappo_rnn_hanabi.py
+++ b/baselines/MAPPO/mappo_rnn_hanabi.py
@@ -460,11 +460,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/MAPPO/mappo_rnn_mpe.py b/baselines/MAPPO/mappo_rnn_mpe.py
index 6a5d3bd5..f7a2b337 100644
--- a/baselines/MAPPO/mappo_rnn_mpe.py
+++ b/baselines/MAPPO/mappo_rnn_mpe.py
@@ -448,11 +448,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/MAPPO/mappo_rnn_smax.py b/baselines/MAPPO/mappo_rnn_smax.py
index c8f7695d..e9210e6d 100644
--- a/baselines/MAPPO/mappo_rnn_smax.py
+++ b/baselines/MAPPO/mappo_rnn_smax.py
@@ -479,11 +479,11 @@ def _critic_loss_fn(critic_params, init_hstate, traj_batch, targets):
                 )
                 permutation = jax.random.permutation(_rng, config["NUM_ACTORS"])
 
-                shuffled_batch = jax.tree_util.tree_map(
+                shuffled_batch = jax.tree.map(
                     lambda x: jnp.take(x, permutation, axis=1), batch
                 )
-                minibatches = jax.tree_util.tree_map(
+                minibatches = jax.tree.map(
                     lambda x: jnp.swapaxes(
                         jnp.reshape(
                             x,
diff --git a/baselines/QLearning/iql_cnn_overcooked.py b/baselines/QLearning/iql_cnn_overcooked.py
index e7e4315f..38011c14 100644
--- a/baselines/QLearning/iql_cnn_overcooked.py
+++ b/baselines/QLearning/iql_cnn_overcooked.py
@@ -291,7 +291,7 @@ def _step_env(carry, _):
             ) # update timesteps count
 
             # BUFFER UPDATE
-            timesteps = jax.tree_util.tree_map(
+            timesteps = jax.tree.map(
                 lambda x: x.reshape(-1, *x.shape[2:]), timesteps
             ) # (num_envs*num_steps, ...)
             buffer_state = buffer.add(buffer_state, timesteps)
diff --git a/baselines/QLearning/iql_rnn.py b/baselines/QLearning/iql_rnn.py
index c28ec1d5..f8f2facf 100644
--- a/baselines/QLearning/iql_rnn.py
+++ b/baselines/QLearning/iql_rnn.py
@@ -319,7 +319,7 @@ def _step_env(carry, _):
            ) # update timesteps count
 
            # BUFFER UPDATE
-           buffer_traj_batch = jax.tree_util.tree_map(
+           buffer_traj_batch = jax.tree.map(
                lambda x: jnp.swapaxes(x, 0, 1)[
                    :, np.newaxis
                ], # put the batch dim first and add a dummy sequence dim
diff --git a/baselines/QLearning/pqn_vdn_cnn_overcooked.py b/baselines/QLearning/pqn_vdn_cnn_overcooked.py
index 08c241bb..cffa656e 100644
--- a/baselines/QLearning/pqn_vdn_cnn_overcooked.py
+++ b/baselines/QLearning/pqn_vdn_cnn_overcooked.py
@@ -430,7 +430,7 @@ def preprocess_transition(x, rng):
                 return x
 
             rng, _rng = jax.random.split(rng)
-            minibatches = jax.tree_util.tree_map(
+            minibatches = jax.tree.map(
                 lambda x: preprocess_transition(x, _rng),
                 transitions,
             ) # num_minibatches, num_agents, num_envs/num_minbatches ...
diff --git a/baselines/QLearning/pqn_vdn_ff.py b/baselines/QLearning/pqn_vdn_ff.py
index ff04f8f7..19267a6e 100644
--- a/baselines/QLearning/pqn_vdn_ff.py
+++ b/baselines/QLearning/pqn_vdn_ff.py
@@ -383,7 +383,7 @@ def preprocess_transition(x, rng):
                 return x
 
             rng, _rng = jax.random.split(rng)
-            minibatches = jax.tree_util.tree_map(
+            minibatches = jax.tree.map(
                 lambda x: preprocess_transition(x, _rng),
                 transitions,
             ) # num_minibatches, num_agents, num_envs/num_minbatches ...
diff --git a/baselines/QLearning/pqn_vdn_rnn.py b/baselines/QLearning/pqn_vdn_rnn.py
index f96601dd..76d92e2d 100644
--- a/baselines/QLearning/pqn_vdn_rnn.py
+++ b/baselines/QLearning/pqn_vdn_rnn.py
@@ -336,7 +336,7 @@ def _learn_phase(carry, minibatch):
                     minibatch.last_done,
                 )
 
                 # batchify the agent input: num_agents*batch_size
-                agent_in = jax.tree_util.tree_map(
+                agent_in = jax.tree.map(
                     lambda x: x.reshape(x.shape[0], -1, *x.shape[3:]), agent_in
                 ) # (num_steps, num_agents*batch_size, ...)
@@ -442,7 +442,7 @@ def preprocess_transition(x, rng):
                 return x
 
             rng, _rng = jax.random.split(rng)
-            minibatches = jax.tree_util.tree_map(
+            minibatches = jax.tree.map(
                 lambda x: preprocess_transition(x, _rng),
                 memory_transitions,
             ) # num_minibatches, num_steps+memory_window, num_agents, batch_size/num_minbatches, num_agents, ...
diff --git a/baselines/QLearning/qmix_rnn.py b/baselines/QLearning/qmix_rnn.py
index 0561b725..7cd37cfc 100644
--- a/baselines/QLearning/qmix_rnn.py
+++ b/baselines/QLearning/qmix_rnn.py
@@ -407,7 +407,7 @@ def _step_env(carry, _):
            ) # update timesteps count
 
            # BUFFER UPDATE
-           buffer_traj_batch = jax.tree_util.tree_map(
+           buffer_traj_batch = jax.tree.map(
                lambda x: jnp.swapaxes(x, 0, 1)[
                    :, np.newaxis
                ], # put the batch dim first and add a dummy sequence dim
diff --git a/baselines/QLearning/shaq.py b/baselines/QLearning/shaq.py
index f34ff17b..2cbb2f95 100644
--- a/baselines/QLearning/shaq.py
+++ b/baselines/QLearning/shaq.py
@@ -454,7 +454,7 @@ def _env_step(step_state, unused):
                # get the q_values from the agent network
                hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
                # remove the dummy time_step dimension and index qs by the valid actions of each agent
-               valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
+               valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
 
                # explore with epsilon greedy_exploration
                actions = explorer.choose_actions(valid_q_vals, t, key_a)
@@ -488,7 +488,7 @@ def _env_step(step_state, unused):
            )
 
            # BUFFER UPDATE: save the collected trajectory in the buffer
-           buffer_traj_batch = jax.tree_util.tree_map(
+           buffer_traj_batch = jax.tree.map(
                lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
                traj_batch
            ) # (num_envs, 1, time_steps, ...)
@@ -519,7 +519,7 @@ def _loss_fn(params_agent, params_mixer, target_network_params_agent, target_net
                )
 
                # get the target q value of the greedy actions for each agent
-               valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
+               valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
                target_max_qvals = jax.tree.map(
                    lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep
                    target_q_vals,
@@ -642,7 +642,7 @@ def _td_lambda_target(ret, values):
                'timesteps': time_state['timesteps']*config['NUM_ENVS'],
                'updates' : time_state['updates'],
                'loss': loss,
-               'rewards': jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
+               'rewards': jax.tree.map(lambda x: jnp.sum(x, axis=0).mean(), traj_batch.rewards),
                'eps': explorer.get_epsilon(time_state['timesteps'])
            }
            metrics['test_metrics'] = test_metrics # add the test metrics dictionary
@@ -691,7 +691,7 @@ def _greedy_env_step(step_state, unused):
                obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
                dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
                hstate, q_vals = homogeneous_pass(params, hstate, obs_, dones_)
-               actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
+               actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
                obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
                step_state = (params, env_state, obs, dones, hstate, rng)
                return step_state, (rewards, dones, infos)
diff --git a/baselines/QLearning/transf_qmix.py b/baselines/QLearning/transf_qmix.py
index 1ab6e30e..97bc9710 100644
--- a/baselines/QLearning/transf_qmix.py
+++ b/baselines/QLearning/transf_qmix.py
@@ -384,7 +384,7 @@ class Transition(NamedTuple):
 
 def tree_mean(tree):
     return jnp.array(
-        jax.tree_util.tree_leaves(jax.tree.map(lambda x: x.mean(), tree))
+        jax.tree_leaves(jax.tree.map(lambda x: x.mean(), tree))
     ).mean()
 
 
@@ -487,8 +487,8 @@ def _env_sample_step(env_state, unused):
        network_stats = {'agent':agent_params['batch_stats'],'mixer':mixer_params['batch_stats']}
 
        # print number of params
-       agent_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['agent']))
-       mixer_params = sum(x.size for x in jax.tree_util.tree_leaves(network_params['mixer']))
+       agent_params = sum(x.size for x in jax.tree_leaves(network_params['agent']))
+       mixer_params = sum(x.size for x in jax.tree_leaves(network_params['mixer']))
        jax.debug.print("Number of agent params: {x}", x=agent_params)
        jax.debug.print("Number of mixer params: {x}", x=mixer_params)
 
@@ -609,7 +609,7 @@ def _env_step(step_state, unused):
                # get the q_values from the agent netwoek
                _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
                # remove the dummy time_step dimension and index qs by the valid actions of each agent
-               valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
+               valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, wrapped_env.valid_actions)
 
                # explore with epsilon greedy_exploration
                actions = explorer.choose_actions(valid_q_vals, t, key_a)
@@ -641,7 +641,7 @@ def _env_step(step_state, unused):
            )
 
            # BUFFER UPDATE: save the collected trajectory in the buffer
-           buffer_traj_batch = jax.tree_util.tree_map(
+           buffer_traj_batch = jax.tree.map(
                lambda x:jnp.swapaxes(x, 0, 1)[:, np.newaxis], # put the batch dim first and add a dummy sequence dim
                traj_batch
            ) # (num_envs, 1, time_steps, ...)
@@ -702,7 +702,7 @@ def _loss_fn(params, init_hs, learn_traj):
                )
 
                # get the target q value of the greedy actions for each agent
-               valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
+               valid_q_vals = jax.tree.map(lambda q, valid_idx: q[..., valid_idx], q_vals, wrapped_env.valid_actions)
                target_max_qvals = jax.tree.map(
                    lambda t_q, q: q_of_action(t_q, jnp.argmax(q, axis=-1))[1:], # avoid first timestep
                    target_q_vals,
@@ -867,7 +867,7 @@ def _greedy_env_step(step_state, unused):
                obs_ = jax.tree.map(lambda x: x[np.newaxis, :], obs_)
                dones_ = jax.tree.map(lambda x: x[np.newaxis, :], last_dones)
                _, hstate, q_vals = homogeneous_pass(env_params, env_batch_norm, hstate, obs_, dones_, train=False)
-               actions = jax.tree_util.tree_map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
+               actions = jax.tree.map(lambda q, valid_idx: jnp.argmax(q.squeeze(0)[..., valid_idx], axis=-1), q_vals, test_env.valid_actions)
                obs, env_state, rewards, dones, infos = test_env.batch_step(key_s, env_state, actions)
                step_state = (env_state, obs, dones, hstate, rng)
                return step_state, (rewards, dones, infos)
diff --git a/baselines/QLearning/vdn_cnn_overcooked.py b/baselines/QLearning/vdn_cnn_overcooked.py
index af151baa..7efe3832 100644
--- a/baselines/QLearning/vdn_cnn_overcooked.py
+++ b/baselines/QLearning/vdn_cnn_overcooked.py
@@ -291,7 +291,7 @@ def _step_env(carry, _):
            ) # update timesteps count
 
            # BUFFER UPDATE
-           timesteps = jax.tree_util.tree_map(
+           timesteps = jax.tree.map(
                lambda x: x.reshape(-1, *x.shape[2:]), timesteps
            ) # (num_envs*num_steps, ...)
            buffer_state = buffer.add(buffer_state, timesteps)
diff --git a/baselines/QLearning/vdn_ff.py b/baselines/QLearning/vdn_ff.py
index 41bfec7d..2494c8b7 100644
--- a/baselines/QLearning/vdn_ff.py
+++ b/baselines/QLearning/vdn_ff.py
@@ -240,7 +240,7 @@ def _step_env(carry, _):
            ) # update timesteps count
 
            # BUFFER UPDATE
-           timesteps = jax.tree_util.tree_map(
+           timesteps = jax.tree.map(
                lambda x: x.reshape(-1, *x.shape[2:]), timesteps
            ) # (num_envs*num_steps, ...)
            buffer_state = buffer.add(buffer_state, timesteps)
diff --git a/baselines/QLearning/vdn_rnn.py b/baselines/QLearning/vdn_rnn.py
index 73bec893..59a0a3ef 100644
--- a/baselines/QLearning/vdn_rnn.py
+++ b/baselines/QLearning/vdn_rnn.py
@@ -319,7 +319,7 @@ def _step_env(carry, _):
            ) # update timesteps count
 
            # BUFFER UPDATE
-           buffer_traj_batch = jax.tree_util.tree_map(
+           buffer_traj_batch = jax.tree.map(
                lambda x: jnp.swapaxes(x, 0, 1)[
                    :, np.newaxis
                ], # put the batch dim first and add a dummy sequence dim
diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py
index 21dd98ca..d2a4d68d 100644
--- a/jaxmarl/environments/hanabi/hanabi.py
+++ b/jaxmarl/environments/hanabi/hanabi.py
@@ -532,7 +532,7 @@ def _binarize_discard_pile(self, discard_pile: chex.Array):
        """Binarize the discard pile to reduce dimensionality."""
 
        def binarize_ranks(n_ranks):
-           tree = jax.tree_util.tree_map(
+           tree = jax.tree.map(
                lambda n_rank_present, max_ranks: jnp.where(
                    jnp.arange(max_ranks) >= n_rank_present,
                    jnp.zeros(max_ranks),
diff --git a/jaxmarl/wrappers/baselines.py b/jaxmarl/wrappers/baselines.py
index 0614b677..c2779aa3 100644
--- a/jaxmarl/wrappers/baselines.py
+++ b/jaxmarl/wrappers/baselines.py
@@ -285,7 +285,7 @@ def batch_step(self, key, states, actions):
    def wrapped_reset(self, key):
        obs_, state = self._env.reset(key)
        if self.preprocess_obs:
-           obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
+           obs = jax.tree.map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
        else:
            obs = obs_
        obs["__all__"] = self.global_state(obs_, state)
@@ -295,8 +295,8 @@ def wrapped_step(self, key, state, actions):
 
        obs_, state, reward, done, infos = self._env.step(key, state, actions)
        if self.preprocess_obs:
-           obs = jax.tree_util.tree_map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
-           obs = jax.tree_util.tree_map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents
+           obs = jax.tree.map(self._preprocess_obs, {agent:obs_[agent] for agent in self.agents}, self.agents_one_hot)
+           obs = jax.tree.map(lambda d, o: jnp.where(d, 0., o), {agent:done[agent] for agent in self.agents}, obs) # ensure that the obs are 0s for done agents
        else:
            obs = obs_
        obs["__all__"] = self.global_state(obs_, state)
diff --git a/tests/brax/test_brax_rand_acts.py b/tests/brax/test_brax_rand_acts.py
index 43671532..71059dda 100644
--- a/tests/brax/test_brax_rand_acts.py
+++ b/tests/brax/test_brax_rand_acts.py
@@ -23,3 +23,7 @@ def test_random_rollout():
        rng_act = jax.random.split(rng_act, env.num_agents)
        actions = {a: env.action_space(a).sample(rng_act[i]) for i, a in enumerate(env.agents)}
        _, state, _, _, _ = env.step(rng, state, actions)
+
+
+if __name__ == "__main__":
+    test_random_rollout()
\ No newline at end of file
diff --git a/tests/mpe/mpe_policy_transfer.py b/tests/mpe/mpe_policy_transfer.py
index a8dcea97..56c1ef8b 100644
--- a/tests/mpe/mpe_policy_transfer.py
+++ b/tests/mpe/mpe_policy_transfer.py
@@ -118,7 +118,7 @@ def _preprocess_obs(arr, extra_features):
 
 
 def obs_to_act(obs, dones, params=params):
-    obs = jax.tree_util.tree_map(_preprocess_obs, obs, agents_one_hot)
+    obs = jax.tree.map(_preprocess_obs, obs, agents_one_hot)
     # add a dummy temporal dimension
     obs_ = jax.tree.map(lambda x: x[np.newaxis, np.newaxis, :], obs)
     # add also a dummy batch dim to obs
@@ -129,8 +129,8 @@ def obs_to_act(obs, dones, params=params):
     hstate, q_vals = agent.homogeneous_pass(params, hstate, obs_, dones_)
 
     # get actions from q vals
-    valid_q_vals = jax.tree_util.tree_map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, valid_actions)
-    actions = jax.tree_util.tree_map(lambda q: jnp.argmax(q, axis=-1).squeeze(0), valid_q_vals)
+    valid_q_vals = jax.tree.map(lambda q, valid_idx: q.squeeze(0)[..., valid_idx], q_vals, valid_actions)
+    actions = jax.tree.map(lambda q: jnp.argmax(q, axis=-1).squeeze(0), valid_q_vals)
     return actions
 
 
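
Note (not part of the diff): the change above is mechanical. jax.tree.map is the current spelling of the deprecated jax.tree_util.tree_map alias (the jax.tree namespace is available in roughly jax 0.4.25 and later); both apply a function leaf-wise over a pytree. A minimal sketch of the equivalence, using made-up array shapes, is below.

import jax
import jax.numpy as jnp

# a toy pytree standing in for a trajectory batch
batch = {"obs": jnp.ones((4, 3)), "reward": jnp.zeros((4,))}

# old spelling (still works, but deprecated)
old = jax.tree_util.tree_map(lambda x: x.reshape(x.shape[0], -1), batch)
# new spelling used throughout the diff
new = jax.tree.map(lambda x: x.reshape(x.shape[0], -1), batch)

# the two produce identical pytrees
assert all(jax.tree.leaves(jax.tree.map(jnp.array_equal, old, new)))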