Merge branch 'main' into qlearning
mttga committed Dec 7, 2023
2 parents 703c043 + eb83d25 commit 90d5e69
Showing 24 changed files with 990 additions and 92 deletions.
35 changes: 35 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,35 @@
name: Tests
on: [push, pull_request]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: true
      max-parallel: 15
      matrix:
        # os: [ubuntu-latest, macos-latest, windows-latest, macos-13-xlarge]
        # For Apple Silicon: https://github.com/actions/runner-images/issues/8439
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version: ['3.9']
    defaults:
      run:
        shell: bash
    steps:
      - name: Check out repository
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest
          pip install -e '.[dev]'
      - name: Run pytest
        run: pytest tests
14 changes: 13 additions & 1 deletion CONTRIBUTING.md
@@ -7,9 +7,17 @@ Please help build JaxMARL into the best possible tool for the MARL community.
We actively welcome your contributions!
- If adding an environment or algorithm, check with us that it is the right fit for the repo.
- Fork the repo and create your branch from main.
- Add tests, or show proof that the environment/algorithm works.
- Add tests, or show proof that the environment/algorithm works. The exact requirements are listed below.
- Add a README explaining your environment/algorithm.

**Environment Requirements**
- Unit tests (in `pytest` format) demonstrating correctness. If applicable, show correspondence to existing implementations. If transitions match, write a unit test to demonstrate this ([example](https://github.com/FLAIROx/JaxMARL/blob/be9fe46e52a736f8dd766acf98b4e0803f199dd2/tests/mpe/test_mpe.py)).
- Training results for IPPO and MAPPO over 20 seeds, with configuration files saved to `baselines`.
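For illustration, a minimal transition-correspondence test in this style might look like the sketch below; the environment name, keys, and assertions are hypothetical placeholders, and a real test would compare against transitions recorded from an existing implementation:

```python
# Minimal sketch of a pytest-style transition check (hypothetical values; a real
# test would assert equality with reference transitions from an existing codebase).
import jax
import jax.numpy as jnp
import jaxmarl


def test_step_produces_consistent_transitions():
    env = jaxmarl.make("MPE_simple_spread_v3")  # placeholder environment name
    key = jax.random.PRNGKey(0)
    key_reset, key_act, key_step = jax.random.split(key, 3)

    obs, state = env.reset(key_reset)
    act_keys = jax.random.split(key_act, env.num_agents)
    actions = {
        agent: env.action_space(agent).sample(act_keys[i])
        for i, agent in enumerate(env.agents)
    }
    obs, state, reward, done, infos = env.step(key_step, state, actions)

    # Basic sanity checks; a correspondence test would instead compare
    # observations and rewards against the reference implementation.
    assert all(agent in obs for agent in env.agents)
    assert all(jnp.isfinite(reward[agent]) for agent in env.agents)
```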

**Algorithm Requirements**
- Performance results on at least 3 environments (e.g. SMAX, MABrax & Overcooked) with at least 20 seeds per result.
- If applicable, compare performance results to existing implementations to demonstrate correctness.

## Bug reports

We use Github's issues to track bugs, just open a new issue! Great Bug Reports tend to have:
@@ -24,3 +32,7 @@ We use Github's issues to track bugs, just open a new issue! Great Bug Reports t

All contributions will fall under the project's original license.

## Roadmap

Some improvements we would like to see implemented:
- [ ] Improved RNN implementations. In the current implementation the hidden size is dependent on "NUM_STEPS"; it should be made independent. Speed could also be improved with an S5 architecture.
4 changes: 3 additions & 1 deletion README.md
@@ -65,6 +65,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
| IQL | [Paper](https://arxiv.org/abs/1312.5602v1) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| VDN | [Paper](https://arxiv.org/abs/1706.05296) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| QMIX | [Paper](https://arxiv.org/abs/1803.11485) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |
| SHAQ | [Paper](https://arxiv.org/abs/2105.15013) | [Source](https://github.com/FLAIROx/JaxMARL/tree/main/baselines/QLearning) |

<h2 name="install" id="install">Installation 🧗 </h2>

@@ -124,7 +125,7 @@ obs, state, reward, done, infos = env.step(key_step, state, actions)
```

## Contributing 🔨
Please contribute! Please take a look at our [contributing guide](https://github.com/FLAIROx/JaxMARL/blob/main/CONTRIBUTING.md) for how to add an environment/algorithm or submit a bug report.
Please contribute! Please take a look at our [contributing guide](https://github.com/FLAIROx/JaxMARL/blob/main/CONTRIBUTING.md) for how to add an environment/algorithm or submit a bug report. Our roadmap also lives there.

<h2 name="cite" id="cite">Citing JaxMARL 📜 </h2>
If you use JaxMARL in your work, please cite us as follows:
@@ -151,3 +152,4 @@ JAX-native environments:
- [Jumanji](https://github.com/instadeepai/jumanji): A diverse set of environments ranging from simple games to NP-hard combinatorial problems.
- [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi.
- [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks.
- [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid.
20 changes: 10 additions & 10 deletions baselines/IPPO/config/ippo_rnn_smax.yaml
@@ -1,20 +1,20 @@
"LR": 4e-3
"NUM_ENVS": 64
"LR": 0.004
"NUM_ENVS": 128
"NUM_STEPS": 128
"TOTAL_TIMESTEPS": 5e5
"UPDATE_EPOCHS": 2
"NUM_MINIBATCHES": 2
"TOTAL_TIMESTEPS": 1e7
"UPDATE_EPOCHS": 4
"NUM_MINIBATCHES": 4
"GAMMA": 0.99
"GAE_LAMBDA": 0.95
"CLIP_EPS": 0.04
"CLIP_EPS": 0.05
"SCALE_CLIP_EPS": False
"ENT_COEF": 0.00
"ENT_COEF": 0.01
"VF_COEF": 0.5
"MAX_GRAD_NORM": 0.5
"MAX_GRAD_NORM": 0.25
"ACTIVATION": "relu"
"ENV_NAME": "HeuristicEnemySMAX"
"MAP_NAME": "2s3z"
"SEED": 30
"SEED": 0
"ENV_KWARGS":
"see_enemy_actions": True
"walls_cause_death": True
@@ -24,4 +24,4 @@
# WandB Params
"ENTITY": ""
"PROJECT": "jaxmarl-smax"
"WANDB_MODE" : "disabled"
"WANDB_MODE" : "disabled"
2 changes: 1 addition & 1 deletion baselines/IPPO/ippo_ff_hanabi.py
@@ -76,7 +76,7 @@ def make_train(config):
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ACTORS"]
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
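The NUM_ACTORS → NUM_ENVS change above recurs across the IPPO baselines below. A small arithmetic sketch of what the divisor changes, with illustrative values assumed for this note rather than taken from any config in the commit:

```python
# Illustrative numbers only (assumed, not from this commit).
TOTAL_TIMESTEPS = 1e7
NUM_STEPS = 128          # rollout length collected per environment per update
NUM_ENVS = 128           # parallel environments
NUM_AGENTS = 2           # agents per environment
NUM_ACTORS = NUM_AGENTS * NUM_ENVS

# One update advances every parallel environment NUM_STEPS times, i.e.
# NUM_STEPS * NUM_ENVS environment timesteps, regardless of agent count.
updates_per_env_timesteps = TOTAL_TIMESTEPS // NUM_STEPS // NUM_ENVS      # 610 updates
# Dividing by NUM_ACTORS instead counts agent-steps, so a 2-agent environment
# would stop after half as many updates for the same TOTAL_TIMESTEPS budget.
updates_per_agent_timesteps = TOTAL_TIMESTEPS // NUM_STEPS // NUM_ACTORS  # 305 updates
```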
18 changes: 9 additions & 9 deletions baselines/IPPO/ippo_ff_mabrax.py
@@ -77,7 +77,7 @@ def make_train(config):
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ACTORS"] # Q: NUM_ACTORS CORRECT?
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
@@ -301,22 +301,22 @@ def main(config):
train_jit = jax.jit(make_train(config), device=jax.devices()[config["DEVICE"]])
out = train_jit(rng)

updates_x = jnp.arange(out["metrics"]["returned_episode_returns"].squeeze().shape[0])
'''updates_x = jnp.arange(out["metrics"]["returned_episode_returns"].squeeze().shape[0])
print('updates x', updates_x.shape)
print('metrics shape', out["metrics"]["returned_episode_returns"].shape)
returns_table = jnp.stack([updates_x, out["metrics"]["returned_episode_returns"].mean(-1).squeeze()], axis=1)
returns_table = wandb.Table(data=returns_table.tolist(), columns=["updates", "returns"])
wandb.log({
"returns_plot": wandb.plot.line(returns_table, "updates", "returns", title="returns_vs_updates"),
"returns": out["metrics"]["returned_episode_returns"].mean()
})
})'''

# mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
# x = np.arange(len(mean_returns)) * config["NUM_ACTORS"]
# plt.plot(x, mean_returns)
# plt.xlabel("Timestep")
# plt.ylabel("Return")
# plt.savefig(f'mabrax_ippo_ret.png')'''
mean_returns = out["metrics"]["returned_episode_returns"].mean(-1).reshape(-1)
x = np.arange(len(mean_returns)) * config["NUM_ACTORS"]
plt.plot(x, mean_returns)
plt.xlabel("Timestep")
plt.ylabel("Return")
plt.savefig(f'mabrax_ippo_ret.png')

# import pdb; pdb.set_trace()

2 changes: 1 addition & 1 deletion baselines/IPPO/ippo_ff_mpe.py
@@ -80,7 +80,7 @@ def make_train(config):
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ACTORS"] # Q: NUM_ACTORS CORRECT?
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
2 changes: 1 addition & 1 deletion baselines/IPPO/ippo_ff_mpe_facmac.py
@@ -84,7 +84,7 @@ def make_train(config):
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ACTORS"] # Q: NUM_ACTORS CORRECT?
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
2 changes: 1 addition & 1 deletion baselines/IPPO/ippo_rnn_hanabi.py
@@ -121,7 +121,7 @@ def make_train(config):
    # env = HanabiGame()
    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ACTORS"]
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
4 changes: 2 additions & 2 deletions baselines/MAPPO/config/mappo_homogenous_rnn_smax.yaml
@@ -14,7 +14,7 @@
"ACTIVATION": "relu"
"OBS_WITH_AGENT_ID": True
"ENV_NAME": "HeuristicEnemySMAX"
"MAP_NAME": "27m_vs_30m"
"MAP_NAME": "2s3z"
"SEED": 0
"ENV_KWARGS":
"see_enemy_actions": True
@@ -25,4 +25,4 @@
# WandB Params
"WANDB_MODE": "disabled"
"ENTITY": ""
"PROJECT": "jaxmarl-smax"
"PROJECT": "jaxmarl-smax"
9 changes: 8 additions & 1 deletion baselines/QLearning/README.md
@@ -1,6 +1,13 @@
# QLearning Baselines

*Pure Jax implementation of **IQL** (Independent Q-Learners), **VDN** (Value Decomposition Network), and **QMix**. These implementations follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase.*

Pure JAX implementations of:
* IQL (Independent Q-Learners)
* VDN (Value Decomposition Network)
* QMIX
* SHAQ (Incorporating Shapley Value Theory into Multi-Agent Q-Learning)

The first three follow the original [Pymarl](https://github.com/oxwhirl/pymarl/blob/master/src/learners/q_learner.py) codebase, while SHAQ follows the [paper code](https://github.com/hsvgbkhgbv/shapley-q-learning).

```
⚠️ The implementations were tested with Python 3.9 and Jax 0.4.11.
Binary file not shown.
Binary file not shown.
Binary file not shown.
30 changes: 30 additions & 0 deletions baselines/QLearning/config/alg/shaq_mpe.yaml
@@ -0,0 +1,30 @@
"NUM_ENVS": 8
"BUFFER_SIZE": 5000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 2050000
"AGENT_HIDDEN_DIM": 64
"AGENT_INIT_SCALE": 2.
"PARAMETERS_SHARING": True
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.05
"EPSILON_ANNEAL_TIME": 100000
"MIXER_EMBEDDING_DIM": 32
"MIXER_HYPERNET_HIDDEN_DIM": 64
"MIXER_INIT_SCALE": 0.00001
"MAX_GRAD_NORM": 25
"TARGET_UPDATE_INTERVAL": 200
"LR": 0.005
"LR_LINEAR_DECAY": True
"EPS_ADAM": 0.001
"WEIGHT_DECAY_ADAM": 0.00001
"TD_LAMBDA_LOSS": True
"TD_LAMBDA": 0.6
"GAMMA": 0.9
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 50000
"SAMPLE_SIZE": 5
"MANUAL_ALPHA_ESTIMATES": null
"LR_ALPHA": 0.001
"ALG_NAME": shaq
30 changes: 30 additions & 0 deletions baselines/QLearning/config/alg/shaq_smax.yaml
@@ -0,0 +1,30 @@
"NUM_ENVS": 8
"BUFFER_SIZE": 3000
"BUFFER_BATCH_SIZE": 32
"TOTAL_TIMESTEPS": 20000000
"AGENT_HIDDEN_DIM": 256
"AGENT_INIT_SCALE": 1.
"PARAMETERS_SHARING": True
"EPSILON_START": 1.0
"EPSILON_FINISH": 0.05
"EPSILON_ANNEAL_TIME": 100000
"MIXER_EMBEDDING_DIM": 64
"MIXER_HYPERNET_HIDDEN_DIM": 256
"MIXER_INIT_SCALE": 0.001
"MAX_GRAD_NORM": 10
"TARGET_UPDATE_INTERVAL": 200
"LR": 0.001
"LR_LINEAR_DECAY": False
"EPS_ADAM": 0.00001
"WEIGHT_DECAY_ADAM": 0.000001
"TD_LAMBDA_LOSS": False
"TD_LAMBDA": 0.6
"GAMMA": 0.99
"VERBOSE": False
"WANDB_ONLINE_REPORT": True
"NUM_TEST_EPISODES": 32
"TEST_INTERVAL": 100000
"SAMPLE_SIZE": 1 # suggest choosing from 1/5/10
"MANUAL_ALPHA_ESTIMATES": null
"LR_ALPHA": 0.0005
"ALG_NAME": shaq
6 changes: 3 additions & 3 deletions baselines/QLearning/config/config.yaml
@@ -3,9 +3,9 @@
"SEED": 30

# wandb params
"ENTITY": "mttga"
"PROJECT": "smax"
"WANDB_MODE": "online"
"ENTITY": ""
"PROJECT": "jaxMARL"
"WANDB_MODE": "disabled"

# where to save the params (if None, will not save)
"SAVE_PATH": "baselines/QLearning/checkpoints"