World Model implementation with PPO in PyTorch. This repository builds on world-models for the VAE and MDN-RNN implementations and on firedup for the PPO optimization of the Controller network. See the firedup setup file for requirements.
First, save a number of CarRacing-v0 Gym environment rollouts to the data_dir folder; these rollouts form the train and test sets:
python env/carracing.py --data_dir './env/data' --n_fold_train 20 --n_fold_test 1
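For reference, this collection step amounts to running a policy in the environment and saving the observed frames and actions. The sketch below is illustrative only: it assumes a random policy and compressed .npz output, while the actual flags and file format are defined in env/carracing.py.

# Illustrative sketch of rollout collection (the real logic lives in env/carracing.py).
# Assumptions: random actions and .npz output; not the script's guaranteed format.
import os
import numpy as np
import gym

def collect_rollout(env, max_steps=1000):
    obs_list, act_list = [], []
    obs = env.reset()
    for _ in range(max_steps):
        action = env.action_space.sample()  # random policy for data collection
        obs_list.append(obs)
        act_list.append(action)
        obs, _, done, _ = env.step(action)
        if done:
            break
    return np.array(obs_list), np.array(act_list)

env = gym.make('CarRacing-v0')
os.makedirs('./env/data', exist_ok=True)
for i in range(20):  # e.g. the 20 training rollouts from --n_fold_train
    observations, actions = collect_rollout(env)
    np.savez_compressed(f'./env/data/rollout_train_{i}.npz',
                        observations=observations, actions=actions)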
Then train the Variational Autoencoder (VAE) using the stored rollouts:
from vae.train import run
run(data_dir='./env/data', vae_dir='./vae/model', epochs=5)
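The VAE compresses each observation frame into a low-dimensional latent vector. As a rough sketch of the objective it minimizes (reconstruction error plus a KL term), assuming a diagonal-Gaussian posterior; the actual model and training loop live in the vae package:

# Illustrative sketch of the VAE objective on rollout frames (assumption:
# diagonal-Gaussian posterior; the real implementation is in the vae package).
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logsigma):
    # Reconstruction term: how closely the decoded frame matches the input frame.
    recon = F.mse_loss(recon_x, x, reduction='sum')
    # KL term: keeps the latent posterior close to a standard normal prior.
    kl = -0.5 * torch.sum(1 + 2 * logsigma - mu.pow(2) - (2 * logsigma).exp())
    return recon + kl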
Using the pretrained VAE, train the Mixture Density Network RNN (MDN-RNN) to predict the next latent state:
from mdnrnn.train import run
run(data_dir='./env/data', vae_dir='./vae/model', mdnrnn_dir='./mdnrnn/model', epochs=5)
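The MDN-RNN models the transition in latent space: given the current latent state and action, it outputs a Gaussian mixture over the next latent state. A minimal sketch of the mixture negative log-likelihood it minimizes, with illustrative tensor names (the actual implementation is in the mdnrnn package):

# Illustrative sketch of the MDN-RNN objective: negative log-likelihood of the
# next latent state under the predicted Gaussian mixture. Names are assumptions.
import torch

def gmm_nll(next_z, mu, sigma, logpi):
    # next_z: (batch, latent_dim); mu, sigma: (batch, n_gauss, latent_dim)
    # logpi:  (batch, n_gauss) log mixture weights
    next_z = next_z.unsqueeze(1)  # broadcast over the mixtures
    log_prob = torch.distributions.Normal(mu, sigma).log_prob(next_z).sum(dim=-1)
    return -torch.logsumexp(logpi + log_prob, dim=1).mean()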
Finally, train the Controller network, which steers the car, using PPO:
from rl.algos.ppo.ppo import run
run(exp_name='carracing_ppo', epochs=100)
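Under the hood, PPO updates the Controller's policy with a clipped surrogate objective. A minimal sketch of that loss, with illustrative names and a typical clip ratio (the actual optimization is handled by the firedup PPO code):

# Illustrative sketch of the PPO clipped surrogate loss optimized for the
# Controller; tensor names and clip_ratio are assumptions, not firedup's API.
import torch

def ppo_clip_loss(logp_new, logp_old, advantages, clip_ratio=0.2):
    ratio = torch.exp(logp_new - logp_old)  # importance ratio pi_new / pi_old
    clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages
    return -torch.min(ratio * advantages, clipped).mean()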