A Deep Reinforcement Learning codebase in TensorFlow 2.0 with a unified, flexible, and highly customizable structure for fast prototyping.
Features | Unstable Baselines | Stable-Baselines3 | OpenAI Baselines |
---|---|---|---|
State-of-the-art RL methods | ➖ (1) | ✔️ | ✔️ |
Documentation | ❌ | ✔️ | ❌ |
Custom callback (2) | ❌ | 🤮 | ➖ |
TensorFlow 2.0 support | ✔️ | ❌ | ❌ |
Clean, elegant code | ✔️ | ❌ | ❌ |
Easy to trace, customize | ✔️ | ❌ (3) | ❌ (3) |
Standalone implementations | ✔️ | ➖ | ❌ (4) |
(1) Currently supports only DQN, C51, PPO, TD3, etc. Other algorithms are still in progress.
(2) For example, in Stable-Baselines you need to write a disgusting custom callback just to save the best-performing model 🤮, while in Unstable Baselines the best-performing checkpoints are saved automatically.
(3) If you have ever traced through Stable-Baselines or OpenAI/baselines, you'll never want to do it again.
(4) The many cross-dependencies across all of the algorithms make the code very hard to trace, for example baselines/common/policies.py, baselines/a2c/a2c.py.... Great job, OpenAI! :cat:
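To see what point (2) means in practice, here is a rough, hypothetical sketch of the kind of callback the old Stable-Baselines `callback(locals_, globals_)` interface required just to keep the best model around. The `episode_rewards` key is algorithm-dependent and only an assumption here; none of this is needed in Unstable Baselines.

```python
# NOT Unstable Baselines code -- a hedged sketch of an old Stable-Baselines-style
# "save the best model" callback. The (locals_, globals_) signature and
# locals_['self'] follow that old interface; 'episode_rewards' is an assumed,
# algorithm-dependent key, so treat this as illustrative only.
import numpy as np

best_mean_reward = -np.inf

def save_best_model(locals_, globals_):
    """Passed to model.learn(callback=...); return True to keep training."""
    global best_mean_reward
    episode_rewards = locals_.get('episode_rewards', [])
    if len(episode_rewards) >= 100:
        mean_reward = np.mean(episode_rewards[-100:])
        if mean_reward > best_mean_reward:
            best_mean_reward = mean_reward
            locals_['self'].save('best_model')  # the training algorithm passes itself in
    return True
```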
We don't have any documentation yet.
Basic requirements:
- Python >= 3.6
- TensorFlow (CPU/GPU) >= 2.3.0
You can install it from PyPI:
$ pip install unstable_baselines
Or install the latest version directly from this repository:
$ pip install git+https://github.com/Ending2015a/unstable_baselines.git@master
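Either way, a quick import check (using only the names that appear in the Quick Start below) confirms the package is installed:

```python
# Sanity check: these imports should succeed after installation.
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO

print(ub.__name__, PPO)  # prints the package name and the PPO class
```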
Done! Now you can
- Go through the Quick Start section
- Or run the example scripts in the example folder.
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
DQN | ❌ | ✔️ | ❌ | ❌ |
PPO | ✔️ | ✔️ | ❌ | ❌ |
TD3 | ✔️ | ❌ | ❌ | ❌ |
SD3 | ✔️ | ❌ | ❌ | ❌ |
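The Box/Discrete/MultiDiscrete/MultiBinary columns are Gym action-space types. If you are not sure which type your environment uses, a plain Gym check (no Unstable Baselines involved) is enough to match it against the table:

```python
# Check an environment's action space to see which algorithms in the table apply.
import gym

env = gym.make('CartPole-v0')
print(env.action_space)                                   # Discrete(2) for CartPole-v0
print(isinstance(env.action_space, gym.spaces.Box))       # False -> TD3/SD3 don't apply
print(isinstance(env.action_space, gym.spaces.Discrete))  # True  -> DQN/PPO apply
env.close()
```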
- 2021.09.17: DQN supports
  - Multi-step learning
  - Prioritized experience replay: arXiv:1511.05952
  - Dueling network: arXiv:1511.06581
- 2021.04.19: Implemented DQN
  - From paper: arXiv:1509.06461
- 2021.03.27: PPO supports continuous (Box) action spaces
- 2021.03.23: Implemented SD3
  - From paper: arXiv:2010.09177
- 2021.03.20: Implemented TD3
  - From paper: arXiv:1802.09477
- 2021.03.10: Implemented PPO
  - From paper: arXiv:1707.06347
Algorithm | Box | Discrete | MultiDiscrete | MultiBinary |
---|---|---|---|---|
C51 | ❌ | ✔️ | ❌ | ❌ |
QRDQN | ❌ | ✔️ | ❌ | ❌ |
IQN | ❌ | ✔️ | ❌ | ❌ |
- 2021.04.28: Implemented IQN
  - From paper: arXiv:1806.06923
- 2021.04.21: Implemented QRDQN
  - From paper: arXiv:1710.10044
- 2021.04.20: Implemented C51
  - From paper: arXiv:1707.06887
This example shows how to train a PPO agent to play CartPole-v0. You can find the full script in example/cartpole/train_ppo.py.
First, import dependencies
import gym
import unstable_baselines as ub
from unstable_baselines.algo.ppo import PPO
Create vectorized environments for training and a single environment for evaluation
# create environments
env = ub.envs.VecEnv([gym.make('CartPole-v0') for _ in range(10)])
eval_env = gym.make('CartPole-v0')
Create a PPO model and train it
model = PPO(
    env,
    learning_rate=1e-3,
    gamma=0.8,
    batch_size=128,
    n_steps=500
).learn(  # train for 20000 steps
    20000,
    verbose=1
)
Save and load the trained model
model.save('./my_ppo_model')
model = PPO.load('./my_ppo_model')
Evaluate the training results
model.eval(eval_env, 20, 200, render=True)
# don't forget to close the environments!
env.close()
eval_env.close()
More examples:
- 2021.05.22: Added benchmarks
- 2021.04.27: Updated to framework v2: supports saving/loading the best-performing checkpoints.