TransfQMix release #58

Merged · 25 commits · Feb 14, 2024
1 change: 1 addition & 0 deletions README.md
@@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
| IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| TransfQMIX | [Paper](https://www.southampton.ac.uk/~eg/AAMAS2023/pdfs/p1679.pdf) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |

<h2 name="install" id="install">Installation 🧗 </h2>
55 changes: 10 additions & 45 deletions baselines/QLearning/README.md
@@ -5,6 +5,7 @@ Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* TransfQMix (Transformers for Leveraging the Graph Structure of MARL Problems)
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).
@@ -26,12 +27,12 @@ pip install -r requirements/requirements-qlearning.txt
- Hanabi
```

## 🔎 Implementation Details
## ⚙️ Implementation Details

General features:

- Agents are controlled by a single RNN architecture.
- You can choose whether to share parameters between agents or not.
- You can choose whether to share parameters between agents or not (not available in TransfQMix).
- Works also with non-homogeneous agents (different observation/action spaces).
- Experience replay is a simple buffer with uniform sampling.
- Uses Double Q-Learning with a target agent network (hard-updated).
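The Double Q-Learning target mentioned above can be sketched as follows. This is a minimal illustration, not JaxMARL's actual code: the function names and the hard-update helper are hypothetical.

```python
import jax
import jax.numpy as jnp

def double_q_targets(q_online_next, q_target_next, rewards, dones, gamma):
    """Double Q-learning targets.

    q_online_next, q_target_next: [batch, n_actions] Q-values for the next
    state from the online network and the (hard-updated) target network.
    """
    # The online network selects the greedy next action...
    best_actions = jnp.argmax(q_online_next, axis=-1)
    # ...but the target network evaluates it, reducing overestimation bias.
    next_values = jnp.take_along_axis(
        q_target_next, best_actions[:, None], axis=-1
    ).squeeze(-1)
    return rewards + gamma * (1.0 - dones) * next_values

def hard_update(online_params, target_params, step, interval):
    # Hard update: copy the online params wholesale every `interval` steps.
    return jax.lax.cond(
        step % interval == 0,
        lambda: online_params,
        lambda: target_params,
    )
```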
@@ -60,10 +61,12 @@ python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_speaker_listener
python baselines/QLearning/vdn.py +alg=vdn_mpe +env=mpe_spread
# QMix with SMAX
python baselines/QLearning/qmix.py +alg=qmix_smax +env=smax
# QMix with hanabi
python baselines/QLearning/qmix.py +alg=qmix_hanabi +env=hanabi
# VDN with hanabi
python baselines/QLearning/vdn.py +alg=qlearn_hanabi +env=hanabi
# QMix against pretrained agents
python baselines/QLearning/qmix_pretrained.py +alg=qmix_mpe +env=mpe_tag_pretrained
# TransfQMix
python baselines/QLearning/transf_qmix.py +alg=transf_qmix_smax +env=smax
```

Notice that with Hydra, you can modify parameters on the go in this way:
@@ -73,44 +76,6 @@ Notice that with Hydra, you can modify parameters on the go in this way:
python baselines/QLearning/iql.py +alg=iql_mpe +env=mpe_spread alg.PARAMETERS_SHARING=False
```

It is often useful to run these scripts manually in a notebook or in another script.

```python
import jax

from jaxmarl import make
from baselines.QLearning.qmix import make_train

env = make("MPE_simple_spread_v3")

config = {
"NUM_ENVS": 8,
"BUFFER_SIZE": 5000,
"BUFFER_BATCH_SIZE": 32,
"TOTAL_TIMESTEPS": 2050000,
"AGENT_HIDDEN_DIM": 64,
"AGENT_INIT_SCALE": 2.0,
"PARAMETERS_SHARING": True,
"EPSILON_START": 1.0,
"EPSILON_FINISH": 0.05,
"EPSILON_ANNEAL_TIME": 100000,
"MIXER_EMBEDDING_DIM": 32,
"MIXER_HYPERNET_HIDDEN_DIM": 64,
"MIXER_INIT_SCALE": 0.00001,
"MAX_GRAD_NORM": 25,
"TARGET_UPDATE_INTERVAL": 200,
"LR": 0.005,
"LR_LINEAR_DECAY": True,
"EPS_ADAM": 0.001,
"WEIGHT_DECAY_ADAM": 0.00001,
"TD_LAMBDA_LOSS": True,
"TD_LAMBDA": 0.6,
"GAMMA": 0.9,
"VERBOSE": False,
"WANDB_ONLINE_REPORT": False,
"NUM_TEST_EPISODES": 32,
"TEST_INTERVAL": 50000,
}

rng = jax.random.PRNGKey(42)
train_vjit = jax.jit(make_train(config, env))
outs = train_vjit(rng)
```
## 🎯 Hyperparameter tuning

Please refer to the ```tune``` function in the [transf_qmix.py](transf_qmix.py) script for an example of hyperparameter tuning using WANDB.
Binary file not shown.
Binary file not shown.
33 changes: 33 additions & 0 deletions baselines/QLearning/config/alg/transf_qmix_mpe.yaml
@@ -0,0 +1,33 @@
"NUM_ENVS": 8
"N_MINI_UPDATES": 1
"NUM_STEPS": 25
"BUFFER_SIZE": 5000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2050000
"AGENT_INIT_SCALE": 1.
"AGENT_HIDDEN_DIM": 32
"AGENT_TRANSF_NUM_LAYERS": 2
"AGENT_TRANSF_NUM_HEADS": 8
"AGENT_TRANSF_DIM_FF": 128
"MIXER_INIT_SCALE": 0.01
"MIXER_TRANSF_NUM_LAYERS": 2
"MIXER_TRANSF_NUM_HEADS": 8
"MIXER_TRANSF_DIM_FF": 128
"USE_FAST_ATTENTION": True
"SCALE_INPUTS": True
"EMBEDDER_USE_RELU": False
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.05
"EPSILON_ANNEAL_TIME": 100000
"MAX_GRAD_NORM": 10.
"TARGET_UPDATE_INTERVAL": 200
"LR": 0.001
"LR_DECAY_TYPE":
"EPS_ADAM": 0.000001
"TD_LAMBDA_LOSS": True
"TD_LAMBDA": 0.6
"GAMMA": 0.9
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
39 changes: 39 additions & 0 deletions baselines/QLearning/config/alg/transf_qmix_smax.yaml
@@ -0,0 +1,39 @@
# Stepping a transformer in an env is more expensive than stepping an RNN,
# so we use more parallel environments together with more network updates
# per episode to balance training. The total number of timesteps is decreased
# so that the total number of network updates matches QMIX's.
"NUM_ENVS": 16
"N_MINI_UPDATES": 4
"NUM_STEPS": 128
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 1.e+7
"AGENT_INIT_SCALE": 1.
"AGENT_HIDDEN_DIM": 32
"AGENT_TRANSF_NUM_LAYERS": 2
"AGENT_TRANSF_NUM_HEADS": 8
"AGENT_TRANSF_DIM_FF": 128
"MIXER_INIT_SCALE": 1.
"MIXER_TRANSF_NUM_LAYERS": 2
"MIXER_TRANSF_NUM_HEADS": 8
"MIXER_TRANSF_DIM_FF": 128
"USE_FAST_ATTENTION": True # assumes you have a fast_attention.py file accesible by the training script
"SCALE_INPUTS": True # applies batch normalization to the obs vectors
"EMBEDDER_USE_RELU": True # applies relu on the embeddings
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.05
"EPSILON_ANNEAL_TIME": 100000
"MAX_GRAD_NORM": 1.
"TARGET_UPDATE_INTERVAL": 10
"LR": 0.005
"LR_DECAY_TYPE": 'exp' # can be exp (exponential), cos (cosine), linear (linear) or None (static)
"LR_EXP_DECAY_RATE": 0.00002 # applies only to exponential decay
"LR_WARMUP": 10 # applies only to cosine decay
"EPS_ADAM": 0.0000000001
"TD_LAMBDA_LOSS": False
"TD_LAMBDA": 0.6
"GAMMA": 0.99
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 100000
4 changes: 2 additions & 2 deletions baselines/QLearning/config/config.yaml
@@ -1,10 +1,10 @@
# experiment params
"NUM_SEEDS": 2
"SEED": 30
"SEED": 0

# wandb params
"ENTITY": ""
"PROJECT": "jaxMARL"
"PROJECT": ""
"WANDB_MODE": "disabled"

# where to save the params (if None, will not save)