
Distributed Multi-Agent Reinforcement Learning in JAX


Welcome to Mava! 🦁

Mava allows researchers to experiment with multi-agent reinforcement learning (MARL) at lightning speed. The single-file JAX implementations are built for rapid research iteration: hack, modify, and test new ideas fast. Our state-of-the-art algorithms scale seamlessly across devices. Created for researchers, by The Research Team at InstaDeep.

Highlights 🦜

  • 🥑 Implementations of MARL algorithms: distributed implementations of current state-of-the-art MARL algorithms that make effective use of available accelerators.
  • 🍬 Environment Wrappers: We provide first-class support for a few JAX-based MARL environment suites through wrappers, and new environments can easily be added by using the existing wrappers as a guide.
  • 🧪 Statistically robust evaluation: Mava natively supports logging to JSON files that adhere to the standard suggested by Gorsane et al. (2022). This enables easy downstream experiment plotting and aggregation using the tools found in the MARL-eval library.
  • 🖥️ JAX Distribution Architectures for Reinforcement Learning: Mava supports both Podracer architectures for scaling RL systems. The first of these is Anakin, which can be used when environments are written in JAX. This enables end-to-end JIT compilation of the full MARL training loop for fast experiment run times on hardware accelerators. The second is Sebulba, which can be used when environments are not written in JAX. Sebulba is particularly useful when running RL experiments where a hardware accelerator can interact with many CPU cores at a time.
  • Blazingly fast experiments: Together, these features give our experiments very short run times, especially when compared with MARL libraries that are not based on JAX.
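To make the Anakin idea concrete, here is a minimal sketch, assuming a toy environment written as pure JAX functions. The `env_reset`/`env_step` functions, the random `policy` and all sizes are hypothetical and are not Mava's API. Because the environment is pure JAX, the rollout can be vectorised over environments with `jax.vmap`, looped with `jax.lax.scan`, and compiled once with `jax.jit`.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS = 3    # hypothetical sizes for this sketch
NUM_ENVS = 8
ROLLOUT_LEN = 16


def env_reset(key):
    # Toy state: one scalar position per agent.
    return jax.random.uniform(key, (NUM_AGENTS,))


def env_step(state, actions):
    # Toy dynamics: each agent moves by (action - 1) * 0.1 with action in {0, 1, 2};
    # the shared reward is the negative mean distance from the origin.
    next_state = state + 0.1 * (actions - 1)
    reward = -jnp.mean(jnp.abs(next_state))
    return next_state, reward


def policy(key, state):
    # Placeholder random policy: one of 3 discrete actions per agent.
    return jax.random.randint(key, (NUM_AGENTS,), 0, 3)


@jax.jit
def rollout(key):
    reset_key, scan_key = jax.random.split(key)
    # vmap: reset NUM_ENVS independent environments in parallel.
    states = jax.vmap(env_reset)(jax.random.split(reset_key, NUM_ENVS))

    def one_step(states, step_key):
        action_keys = jax.random.split(step_key, NUM_ENVS)
        actions = jax.vmap(policy)(action_keys, states)
        states, rewards = jax.vmap(env_step)(states, actions)
        return states, rewards

    # scan: the time loop stays inside the compiled program.
    step_keys = jax.random.split(scan_key, ROLLOUT_LEN)
    final_states, rewards = jax.lax.scan(one_step, states, step_keys)
    return final_states, rewards  # rewards has shape [ROLLOUT_LEN, NUM_ENVS]


final_states, rewards = rollout(jax.random.PRNGKey(0))
```

In a full Anakin-style system, the learner update would also sit inside the same jitted function, so that acting and learning compile into one program.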

Installation 🎬

At the moment Mava is not meant to be installed as a library, but rather to be used as a research tool. We recommend cloning the Mava repo and pip installing as follows:

git clone https://github.com/instadeepai/mava.git
cd mava
pip install -e .

We have tested Mava on Python 3.11 and 3.12, but earlier versions may also work. In particular, the Quickstart notebook on Google Colab uses Python 3.10, since Colab uses Python 3.10 by default. Note that because the installation of JAX differs depending on your hardware accelerator, we advise users to explicitly install the correct JAX version (see the official installation guide). For more in-depth installation guides, including Docker builds and virtual environments, please see our detailed installation guide.
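After installing, one quick sanity check is to ask JAX which backend and devices it is using. This uses only standard JAX calls and nothing Mava-specific. If you expected a GPU or TPU but the backend reports `cpu`, the installed JAX build probably does not match your accelerator.

```python
import jax

# Backend JAX selected (e.g. "cpu", "gpu" or "tpu") and the devices it can see.
print(jax.default_backend())
print(jax.devices())
```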

Getting started ⚡

To get started with training your first Mava system, simply run one of the system files:

python mava/systems/ppo/anakin/ff_ippo.py

Mava uses Hydra for config management, and our default system configs can be found in the mava/configs/ directory. A benefit of Hydra is that configs can either be set in YAML config files or overridden from the terminal on the fly. For example, to run a system on the Level-based Foraging environment, the above command can simply be adapted as follows:

python mava/systems/ppo/anakin/ff_ippo.py env=lbf

Different scenarios can also be run by making the following config updates from the terminal:

python mava/systems/ppo/anakin/ff_ippo.py env=rware env/scenario=tiny-4ag

We also provide a Quickstart notebook that can be used to quickly create and train your first multi-agent system.

Algorithms

Mava has implementations of multiple on- and off-policy multi-agent algorithms that follow the independent learners (IL), centralised training with decentralised execution (CTDE) and heterogeneous agent learning paradigms. Aside from MARL learning paradigms, we also include implementations that follow the Anakin and Sebulba architectures to enable scalable training by default. Which architecture is relevant for a given problem depends on whether the environment being used is written in JAX or not. For more information on these paradigms, please see here.
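The difference between the IL and CTDE paradigms is easiest to see in what the critic receives. The toy sketch below uses plain JAX arrays with hypothetical shapes (no Mava networks) and builds critic inputs two ways: an IPPO-style critic sees only each agent's own observation, while a MAPPO-style centralised critic sees a shared global view, assumed here to be the concatenation of all agents' observations copied to every agent. In both cases the actors would still act from local observations, which is what "decentralised execution" refers to.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS, OBS_DIM = 4, 5  # hypothetical sizes

# One observation vector per agent: shape [NUM_AGENTS, OBS_DIM].
obs = jax.random.normal(jax.random.PRNGKey(0), (NUM_AGENTS, OBS_DIM))


def independent_critic_inputs(obs):
    # IL (e.g. IPPO): each agent's critic conditions on its own observation only.
    return obs  # [NUM_AGENTS, OBS_DIM]


def centralised_critic_inputs(obs):
    # CTDE (e.g. MAPPO): during training the critic may use a global view.
    # Assumption: the "global state" is all observations concatenated,
    # broadcast so that every agent's critic receives the same vector.
    global_state = obs.reshape(-1)  # [NUM_AGENTS * OBS_DIM]
    return jnp.broadcast_to(global_state, (obs.shape[0], global_state.shape[0]))


print(independent_critic_inputs(obs).shape)
print(centralised_critic_inputs(obs).shape)
```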

Algorithm Variants Continuous Discrete Anakin Sebulba Paper Docs
PPO ff_ippo.py Link Link
ff_mappo.py Link Link
rec_ippo.py Link Link
rec_mappo.py Link Link
Q Learning rec_iql.py Link Link
rec_qmix.py Link Link
SAC ff_isac.py Link Link
ff_masac.py Link
ff_hasac.py Link Link
MAT mat.py Link Link
Sable ff_sable.py Link Link
rec_sable.py Link Link

Environments

These are the environments that Mava supports out of the box. To add a new environment, please use the existing wrapper implementations as an example. We also indicate whether each environment is implemented in JAX or not: JAX-based environments can be used with algorithms that follow the Anakin distribution architecture, while non-JAX environments can be used with algorithms that follow the Sebulba architecture.

Environment Action space JAX Non-JAX Paper JAX Source Non-JAX Source
Multi-Robot Warehouse Discrete Link Link Link
Level-based Foraging Discrete Link Link Link
StarCraft Multi-Agent Challenge Discrete Link Link Link
Multi-Agent Brax Continuous Link Link
Matrax Discrete Link Link
Multi Particle Environments Discrete/Continuous Link Link
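A wrapper typically adapts an environment's native interface to the one a training system expects. As an illustration only, the hypothetical wrapper below converts an environment that returns per-agent dictionaries into stacked arrays with a leading agent axis, a layout that vectorised JAX code handles naturally. `DictMultiAgentEnv`, `StackAgentsWrapper` and their `reset`/`step` signatures are assumptions made for this sketch; for Mava's real wrapper interface, follow the existing wrappers in the repository.

```python
import jax.numpy as jnp


class DictMultiAgentEnv:
    """Hypothetical environment returning one entry per agent in a dict."""

    agents = ("agent_0", "agent_1")

    def reset(self):
        return {a: jnp.zeros(3) for a in self.agents}

    def step(self, actions):
        # Toy dynamics: each agent observes its own last action, repeated 3 times.
        obs = {a: jnp.full(3, float(actions[a])) for a in self.agents}
        rewards = {a: 1.0 for a in self.agents}
        return obs, rewards


class StackAgentsWrapper:
    """Hypothetical wrapper: per-agent dicts -> arrays with a leading agent axis."""

    def __init__(self, env):
        self._env = env
        self.agents = tuple(env.agents)  # a fixed agent order defines the axis

    def _stack(self, per_agent):
        return jnp.stack([jnp.asarray(per_agent[a]) for a in self.agents])

    def reset(self):
        return self._stack(self._env.reset())

    def step(self, actions):
        # actions: array of shape [num_agents], in the same agent order.
        obs, rewards = self._env.step({a: actions[i] for i, a in enumerate(self.agents)})
        return self._stack(obs), self._stack(rewards)


env = StackAgentsWrapper(DictMultiAgentEnv())
first_obs = env.reset()
obs, rewards = env.step(jnp.array([1, 2]))
```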

Performance and Speed 🚀

We have performed a rigorous benchmark across 45 different scenarios and 6 different environment suites to validate the performance of Mava's algorithm implementations. For more detailed results, please see our Sable paper, and for all hyperparameters, please see the following website.

[Figure: Mava performance across 15 Robot Warehouse, 7 Level-based Foraging, 11 SMAX, 4 Connector, 5 MaBrax and 3 Multi-Particle scenarios]

Mava's algorithm performance: each algorithm was tuned for 40 trials with the TPE optimizer and benchmarked over 10 seeds for each scenario. Environments, from top left: Multi-Robot Warehouse (aggregated over 15 scenarios), Level-based Foraging (aggregated over 7 scenarios), StarCraft Multi-Agent Challenge in JAX (aggregated over 11 scenarios), Connector (aggregated over 4 scenarios), Multi-Agent Brax (aggregated over 5 scenarios) and Multi Particle Environments (aggregated over 3 scenarios).
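The caption does not say which aggregate the plots report. As a hedged illustration of the kind of robust statistic used in the Gorsane et al. (2022) evaluation protocol and the MARL-eval tools, here is a minimal interquartile mean (IQM): the mean of the middle portion of sorted scores, which is insensitive to a few outlier seeds. The function name and the trimming rule (drop `n // 4` scores from each tail, the same rounding as `scipy.stats.trim_mean(scores, 0.25)`) are assumptions of this sketch, not MARL-eval's implementation, which also reports bootstrapped confidence intervals.

```python
import jax.numpy as jnp


def interquartile_mean(scores):
    """Mean of the middle ~50% of a flat array of scores (e.g. one score per seed)."""
    sorted_scores = jnp.sort(jnp.ravel(scores))
    n = sorted_scores.shape[0]
    cut = n // 4  # number of scores trimmed from each tail
    return jnp.mean(sorted_scores[cut : n - cut])


# Hypothetical normalised returns for 10 seeds on one scenario.
scores = jnp.array([0.2, 0.9, 0.5, 0.6, 0.55, 0.58, 0.61, 0.1, 0.62, 0.95])
print(interquartile_mean(scores))
```

To aggregate over several scenarios, the same function can be applied to the pooled scores of all seeds and scenarios.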

Code Philosophy 🧘

The original code in Mava was adapted from PureJaxRL, which provides high-quality single-file implementations with research-friendly features. In turn, PureJaxRL is inspired by the code philosophy of CleanRL. In this vein of easy-to-use and understandable RL codebases, Mava is not designed to be a modular library and is not meant to be imported. Our repository focuses on simplicity and clarity in its implementations while utilising the advantages offered by JAX, such as pmap and vmap, making it an excellent resource for researchers and practitioners to build upon. A notable difference between Mava and CleanRL is that Mava creates small utilities for heavily re-used elements, such as networks and logging. We've found that this, in addition to Hydra configs, greatly improves the readability of the algorithms.
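As an example of one of the JAX transformations named above, the sketch below uses `jax.pmap` to run the same update on every local device and `jax.lax.pmean` to average gradients across devices, the data-parallel pattern that Anakin-style training is built around. The linear model, squared-error loss and learning rate are hypothetical. On a CPU-only machine there is a single device, so the code still runs, just without parallelism.

```python
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical


def loss_fn(params, x, y):
    # Toy linear-regression loss.
    pred = x @ params
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # Average gradients over all devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - LEARNING_RATE * grads


n_dev = jax.local_device_count()
params = jnp.zeros(3)
# Replicate parameters: leading axis of size n_dev, one copy per device.
replicated = jnp.broadcast_to(params, (n_dev,) + params.shape)
# Shard a toy batch: each device gets its own 8 examples.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
for _ in range(100):
    replicated = update(replicated, x, y)
```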

Contributing 🤝

Please read our contributing docs for details on how to submit pull requests, our Contributor License Agreement and community guidelines.

Roadmap 🛤️

We plan to iteratively expand Mava in the following increments:

  • Support for more environments.
  • More robust recurrent systems.
  • Support for non-JAX-based environments.
  • Add Sebulba versions of more algorithms.
  • Support for off-policy algorithms.
  • Continuous action space environments and algorithms.
  • Allow systems to easily scale across multiple TPUs/GPUs.

Please do follow along as we develop this next phase!

See Also 🔎

InstaDeep's MARL ecosystem in JAX. In particular, we suggest users check out the following sister repositories:

  • 🔌 OG-MARL: datasets with baselines for offline MARL in JAX.
  • 🌴 Jumanji: a diverse suite of scalable reinforcement learning environments in JAX.
  • 😎 Matrax: a collection of matrix games in JAX.
  • Flashbax: accelerated replay buffers in JAX.
  • 📈 MARL-eval: standardised experiment data aggregation and visualisation for MARL.

Related. Other libraries related to accelerated MARL in JAX.

  • 🦊 JaxMARL: accelerated MARL environments with baselines in JAX.
  • 🌀 DeepMind Anakin for the Anakin podracer architecture to train RL agents at scale.
  • ♟️ Pgx: JAX implementations of classic board games, such as Chess, Go and Shogi.
  • 🔼 Minimax: JAX implementations of autocurricula baselines for RL.

Citing Mava 📚

If you use Mava in your work, please cite the accompanying technical report:

@article{dekock2023mava,
    title={Mava: a research library for distributed multi-agent reinforcement learning in JAX},
    author={Ruan de Kock and Omayma Mahjoub and Sasha Abramowitz and Wiem Khlifi and Callum Rhys Tilbury
    and Claude Formanek and Andries P. Smit and Arnu Pretorius},
    year={2023},
    journal={arXiv preprint arXiv:2107.01460},
    url={https://arxiv.org/pdf/2107.01460.pdf},
}

Acknowledgements 🙏

We would like to thank all the authors who contributed to the previous TF version of Mava: Kale-ab Tessera, St John Grimbly, Kevin Eloff, Siphelele Danisa, Lawrence Francis, Jonathan Shock, Herman Kamper, Willie Brink, Herman Engelbrecht, Alexandre Laterre, Karim Beguir. Their contributions can be found in our TF technical report.

The development of Mava was supported with Cloud TPUs from Google's TPU Research Cloud (TRC) 🌤.