
Commit

Merge pull request #1134 from instadeepai/feat/update_juamnji
Feat: Support latest Jumanji version
WiemKhlifi authored Dec 6, 2024
2 parents 80554ef + 32a4e1a commit ae736ff
Showing 39 changed files with 184 additions and 138 deletions.
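Most of the changes below follow a single mechanical migration: in the latest Jumanji, observation_spec and action_spec are exposed as properties rather than methods, so every env.observation_spec() / env.action_spec() call site loses its parentheses. A minimal sketch of the new pattern, assuming env is any Jumanji-based Mava environment (the helper function below is illustrative, not part of this PR):

from typing import Any, Tuple

import jax.numpy as jnp
from jax import tree_util


def make_dummy_network_inputs(env: Any) -> Tuple[Any, Any]:
    """Build dummy per-agent inputs for network initialisation.

    Assumes the latest Jumanji API, where observation_spec and action_spec
    are properties (previously: env.observation_spec().generate_value()).
    """
    obs = env.observation_spec.generate_value()  # dummy observation, shape (num_agents, ...)
    acts = env.action_spec.generate_value()  # dummy actions, shape (num_agents,)
    # Add a leading batch dimension, as done throughout the systems in this diff.
    init_x = tree_util.tree_map(lambda x: x[jnp.newaxis, ...], obs)
    return init_x, acts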
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -6,7 +6,7 @@ jobs:
linters:
name: "Python ${{ matrix.python-version }} on ubuntu-latest"
runs-on: ubuntu-latest
timeout-minutes: 5
timeout-minutes: 10

strategy:
matrix:
14 changes: 5 additions & 9 deletions examples/Quickstart.ipynb
@@ -535,7 +535,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"id": "eWjNSGvZ7ALw"
},
@@ -571,7 +571,7 @@
" )\n",
"\n",
" # Initialise observation with obs of all agents.\n",
" obs = env.observation_spec().generate_value()\n",
" obs = env.observation_spec.generate_value()\n",
" init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)\n",
"\n",
" # Initialise actor params and optimiser state.\n",
@@ -1111,7 +1111,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "mava",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -1124,12 +1125,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
"version": "3.12.4"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion examples/advanced_usage/README.md
@@ -12,7 +12,7 @@ dummy_flashbax_transition = {
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
6 changes: 3 additions & 3 deletions examples/advanced_usage/ff_ippo_store_experience.py
@@ -360,7 +360,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -377,7 +377,7 @@
)

# Initialise observation with obs of all agents.
obs = env.observation_spec().generate_value()
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
@@ -507,7 +507,7 @@ def run_experiment(_config: DictConfig) -> None:
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
env.observation_spec.agents_view.shape[1],
),
dtype=jnp.float32,
),
2 changes: 1 addition & 1 deletion mava/configs/default/ff_ippo.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: ppo/ff_ippo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/ff_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/ff_mappo
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: mlp # [mlp, cnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/ff_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/ff_sable
- network: ff_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/mat.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: mat/mat
- network: transformer
- env: rware # [gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_ippo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_ippo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_iql.yaml
@@ -3,8 +3,8 @@ defaults:
- logger: logger
- arch: anakin
- system: q_learning/rec_iql
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, gigastep, lbf, matrax, rware, smax]
- network: rnn # [rnn, rcnn]
- env: smax # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax]

hydra:
searchpath:
4 changes: 2 additions & 2 deletions mava/configs/default/rec_mappo.yaml
@@ -2,8 +2,8 @@ defaults:
- logger: logger
- arch: anakin
- system: ppo/rec_mappo
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- network: rnn # [rnn, rcnn]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, mabrax, matrax, rware, smax, mpe]
- _self_

hydra:
2 changes: 1 addition & 1 deletion mava/configs/default/rec_sable.yaml
@@ -3,7 +3,7 @@ defaults:
- arch: anakin
- system: sable/rec_sable
- network: rec_retention
- env: rware # [cleaner, connector, gigastep, lbf, rware, smax, mabrax, mpe]
- env: rware # [cleaner, connector, vector-connector, gigastep, lbf, matrax, rware, smax, mpe]
- _self_

hydra:
5 changes: 4 additions & 1 deletion mava/configs/env/connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: MaConnector # Used for logging purposes.
env_name: Connector # Used for logging purposes.

# Choose whether to aggregate individual rewards into a shared team reward or not.
aggregate_rewards: True

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
4 changes: 2 additions & 2 deletions mava/configs/env/lbf.yaml
@@ -6,8 +6,8 @@ defaults:

env_name: LevelBasedForaging # Used for logging purposes.

# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
use_individual_rewards: False # If True, use the list of individual rewards.
# Choose whether to aggregate individual rewards into a shared team reward or not.
aggregate_rewards: True

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
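Several env configs in this diff gain an aggregate_rewards flag defaulting to True (in lbf.yaml it replaces the old use_individual_rewards option). A rough sketch of what the flag controls, assuming aggregation means summing per-agent rewards into a shared team reward (the exact aggregation used by Mava's wrappers is not shown in this diff):

import jax.numpy as jnp


def maybe_aggregate_rewards(rewards: jnp.ndarray, aggregate_rewards: bool) -> jnp.ndarray:
    """Illustrative only: rewards has shape (num_agents,)."""
    if aggregate_rewards:
        team_reward = jnp.sum(rewards)  # assumption: aggregation = sum over agents
        return jnp.full_like(rewards, team_reward)  # every agent receives the team reward
    return rewards  # keep individual per-agent rewards

With aggregate_rewards: True, all agents therefore receive the same scalar reward signal at each step.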
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-10x10x10a.yaml
@@ -1,5 +1,5 @@
# The config of the 10x10x10a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-10x10x10a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-15x15x23a.yaml
@@ -1,5 +1,5 @@
# The config of the 15x15x23a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-15x15x23a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-5x5x3a.yaml
@@ -1,5 +1,5 @@
# The config of the 5x5x3a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-5x5x3a

task_config:
2 changes: 1 addition & 1 deletion mava/configs/env/scenario/con-7x7x5a.yaml
@@ -1,5 +1,5 @@
# The config of the 7x7x5a scenario
name: MaConnector-v2
name: Connector-v2
task_name: con-7x7x5a

task_config:
5 changes: 4 additions & 1 deletion mava/configs/env/vector-connector.yaml
@@ -4,7 +4,10 @@ defaults:
- scenario: con-5x5x3a # [con-5x5x3a, con-7x7x5a, con-10x10x10a, con-15x15x23a]
# Further environment config details in "con-10x10x5a" file.

env_name: VectorMaConnector # Used for logging purposes.
env_name: VectorConnector # Used for logging purposes.

# Choose whether to aggregate individual rewards into a shared team reward or not.
aggregate_rewards: True

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
6 changes: 3 additions & 3 deletions mava/systems/mat/anakin/mat.py
@@ -319,11 +319,11 @@ def learner_setup(
# PRNG keys.
key, actor_net_key = keys

# Get mock inputs to initialise network.
init_x = env.observation_spec().generate_value()
# Initialise observation: Obs for all agents.
init_x = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

_, action_space_type = get_action_head(env.action_spec())
_, action_space_type = get_action_head(env.action_spec)

if action_space_type == "discrete":
init_action = jnp.zeros((1, config.system.num_agents), dtype=jnp.int32)
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -332,7 +332,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -351,8 +351,8 @@
optax.adam(critic_lr, eps=1e-5),
)

# Get mock inputs to initialise network.
obs = env.observation_spec().generate_value()
# Initialise observation with obs of all agents.
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -334,7 +334,7 @@ def learner_setup(

# Define network and optimiser.
actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)

@@ -353,8 +353,8 @@
optax.adam(critic_lr, eps=1e-5),
)

# Get mock inputs to initialise network.
obs = env.observation_spec().generate_value()
# Initialise observation with obs of all agents.
obs = env.observation_spec.generate_value()
init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs)

# Initialise actor params and optimiser state.
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -428,7 +428,7 @@ def learner_setup(
# Define network and optimisers.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
@@ -457,8 +457,8 @@
optax.adam(critic_lr, eps=1e-5),
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
# Initialise observation with obs of all agents.
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
6 changes: 3 additions & 3 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -430,7 +430,7 @@ def learner_setup(
# Define network and optimiser.
actor_pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
actor_post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)
action_head, _ = get_action_head(env.action_spec())
action_head, _ = get_action_head(env.action_spec)
actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim)
critic_pre_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso)
critic_post_torso = hydra.utils.instantiate(config.network.critic_network.post_torso)
@@ -460,8 +460,8 @@
optax.adam(critic_lr, eps=1e-5),
)

# Get mock inputs to initialise network.
init_obs = env.observation_spec().generate_value()
# Initialise observation with obs of all agents.
init_obs = env.observation_spec.generate_value()
init_obs = tree.map(
lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0),
init_obs,
4 changes: 2 additions & 2 deletions mava/systems/q_learning/anakin/rec_iql.py
@@ -92,7 +92,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -130,7 +130,7 @@ def replicate(x: Any) -> Any:
init_hidden_state = replicate(init_hidden_state)

# Create dummy transition
init_acts = env.action_spec().generate_value() # (N,)
init_acts = env.action_spec.generate_value() # (N,)
init_transition = Transition(
obs=init_obs, # (N, ...)
action=init_acts,
6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -94,7 +94,7 @@ def replicate(x: Any) -> Any:
# N: Agent

# Make dummy inputs to init recurrent Q network -> need shape (T, B, N, ...)
init_obs = env.observation_spec().generate_value() # (N, ...)
init_obs = env.observation_spec.generate_value() # (N, ...)
# (B, T, N, ...)
init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs)
init_term_or_trunc = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1)
@@ -126,7 +126,7 @@ def replicate(x: Any) -> Any:
dtype=float,
)
global_env_state_shape = (
env.observation_spec().generate_value().global_state[0, :].shape
env.observation_spec.generate_value().global_state[0, :].shape
) # NOTE: Env wrapper currently duplicates env state for each agent
dummy_global_env_state = jnp.zeros(
(
@@ -159,7 +159,7 @@ def replicate(x: Any) -> Any:
opt_state = replicate(opt_state)
init_hidden_state = replicate(init_hidden_state)

init_acts = env.action_spec().generate_value()
init_acts = env.action_spec.generate_value()

# NOTE: term_or_trunc refers to the the joint done, ie. when all agents are done or when the
# episode horizon has been reached. We use this exclusively in QMIX.
(Diffs for the remaining changed files are not shown here.)
