This directory contains code to reproduce the main results from the paper titled "The Benefits of Model-Based Generalization in Reinforcement Learning" by Kenny Young, Aditya Ramesh, Louis Kirsch and Jürgen Schmidhuber.
Model-Based Reinforcement Learning (RL) is widely believed to have the potential to improve sample efficiency by allowing an agent to synthesize large amounts of imagined experience. Experience Replay (ER) can be considered a simple kind of model, which has proved extremely effective at improving the stability and efficiency of deep RL. In principle, a learned parametric model could improve on ER by generalizing from real experience to augment the dataset with additional plausible experience. However, owing to the many design choices involved in empirically successful algorithms, it can be very hard to establish where the benefits are actually coming from. Here, we provide theoretical and empirical insight into when, and how, we can expect data generated by a learned model to be useful. First, we provide a general theorem motivating how learning a model as an intermediate step can narrow down the set of possible value functions more than learning a value function directly from data using the Bellman equation. Second, we provide an illustrative example showing empirically how a similar effect occurs in a more concrete setting with neural network function approximation. Finally, we provide extensive experiments showing the benefit of model-based learning for online RL in environments with combinatorial complexity, but factored structure that allows a learned model to generalize. In these experiments, we take care to control for other factors in order to isolate, insofar as possible, the benefit of using experience generated by a learned model relative to ER alone.
This directory contains scripts to train each of the reinforcement learning agents tested in Figure 4 of the paper. The code is written in Python and requires JAX and Haiku. The training scripts can be run with commands like the following:
python train_ER_DQN.py --seed=0 --config=configs/ER.json --output=ER
python train_latent_model_DQN.py --seed=0 --config=configs/categorical_latent_model.json --output=categorical_latent_model
python train_latent_model_DQN.py --seed=0 --config=configs/gaussian_latent_model.json --output=gaussian_latent_model
python train_perfect_model_DQN.py --seed=0 --config=configs/perfect_model.json --output=perfect_model
python train_simple_model_DQN.py --seed=0 --config=configs/simple_model.json --output=simple_model
The included config files will reproduce the High Data ProcMaze results with GridSize 4 and with tuned step-size and softmax exploration parameters. Results for other experiments can be reproduced by modifying the environment and the other hyperparameters in the config file appropriately. For the ButtonGrid and PanFlute experiments, set "episodic_env":false in the config file to indicate that the environment is continuing. For the low data regime, change the following config options (see the sketch after this list):
"target_update_frequency":10
"updates_per_step":10
"num_steps":100000
"eval_frequency":50
Upon completion, each training script will produce a file <output>.out containing evaluation metrics for processing, where <output> is the string passed with the --output flag. If "save_params":true is set, the file <output>.params will contain the final trained agent parameters. The data in the resulting <output>.out files can be plotted using the included script plot_metrics.py. For example, to plot the returns resulting from each of the above training scripts, run the following command:
python plot_metrics.py --data ER.out categorical_latent_model.out gaussian_latent_model.out perfect_model.out simple_model.out
Each script is written to allow multiple random repeats of an experiment to be run in parallel on a single GPU using automatic batching in JAX. In each file, the main loop is implemented by the function agent_environment_interaction_loop_function, which takes a set of arguments representing the current state of an experiment and returns their updated values along with some metrics computed along the way. We apply the JAX transformation vmap to produce a version of this interaction loop that processes multiple runs in parallel. The number of parallel seeds to run is specified by "num_seeds" in the config file and is set to 30 in each of the included config files.
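The sketch below illustrates the general pattern, not the repository's actual loop: a toy interaction loop, whose "experiment state" is just an RNG key and a running return, is vmapped so that all seeds advance in parallel on one device.

```python
import jax
import jax.numpy as jnp

# Minimal sketch (not the repository's actual loop) of running many seeds in
# parallel with jax.vmap. The "experiment state" here is just an RNG key and a
# running return; the real agent_environment_interaction_loop_function also
# carries agent parameters, optimizer state, environment state, etc.
def interaction_loop(rng, running_return):
    def step(carry, _):
        rng, ret = carry
        rng, key = jax.random.split(rng)
        reward = jax.random.uniform(key)  # stand-in for one environment step
        return (rng, ret + reward), reward
    (rng, running_return), _ = jax.lax.scan(
        step, (rng, running_return), None, length=100)
    return rng, running_return

num_seeds = 30  # matches "num_seeds" in the included config files
rngs = jax.random.split(jax.random.PRNGKey(0), num_seeds)
returns = jnp.zeros(num_seeds)

batched_loop = jax.vmap(interaction_loop)    # one batched program over all seeds
rngs, returns = batched_loop(rngs, returns)  # all seeds advance together on one device
```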
The following table provides the optimal hyperparameters for each Agent, Environment, Data Regime combination, which were used in Figure 4.
Agent | Environment | Data Regime | Step-Size | Temperature |
---|---|---|---|---|
Simple Model | ProcMaze | High Data | 0.0002 | 0.8 |
Simple Model | ProcMaze | Low Data | 2.5e-05 | 0.4 |
Simple Model | ButtonGrid | High Data | 0.0002 | 1.6 |
Simple Model | ButtonGrid | Low Data | 5e-05 | 0.4 |
Simple Model | PanFlute | High Data | 0.0001 | 0.4 |
Simple Model | PanFlute | Low Data | 0.0002 | 0.2 |
Perfect Model | ProcMaze | High Data | 0.0001 | 0.4 |
Perfect Model | ProcMaze | Low Data | 0.0004 | 0.4 |
Perfect Model | ButtonGrid | High Data | 0.0002 | 0.0125 |
Perfect Model | ButtonGrid | Low Data | 0.0002 | 0.0125 |
Perfect Model | PanFlute | High Data | 0.0002 | 0.1 |
Perfect Model | PanFlute | Low Data | 0.0001 | 0.1 |
ER | ProcMaze | High Data | 5e-05 | 0.1 |
ER | ProcMaze | Low Data | 2.5e-05 | 1.6 |
ER | ButtonGrid | High Data | 5e-05 | 0.0125 |
ER | ButtonGrid | Low Data | 2.5e-05 | 0.1 |
ER | PanFlute | High Data | 0.0001 | 0.1 |
ER | PanFlute | Low Data | 3.125e-06 | 0.8 |
Categorical Latent Model | ProcMaze | High Data | 0.0002 | 1.6 |
Categorical Latent Model | ProcMaze | Low Data | 0.0001 | 0.8 |
Categorical Latent Model | ButtonGrid | High Data | 0.0001 | 0.05 |
Categorical Latent Model | ButtonGrid | Low Data | 0.0002 | 0.1 |
Categorical Latent Model | PanFlute | High Data | 0.0008 | 0.05 |
Categorical Latent Model | PanFlute | Low Data | 0.0016 | 0.05 |
Gaussian Latent Model | ProcMaze | High Data | 3.125e-06 | 12.8 |
Gaussian Latent Model | ProcMaze | Low Data | 3.125e-06 | 0.8 |
Gaussian Latent Model | ButtonGrid | High Data | 1.25e-05 | 0.1 |
Gaussian Latent Model | ButtonGrid | Low Data | 2.5e-05 | 0.8 |
Gaussian Latent Model | PanFlute | High Data | 0.0002 | 0.1 |
Gaussian Latent Model | PanFlute | Low Data | 0.0002 | 0.1 |
These hyperparameters resulted from the grid search described in the Experiment Design portion of Section 4 in the paper.
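For context, the Temperature column refers to the softmax (Boltzmann) exploration parameter. The sketch below shows one common form of softmax action selection; whether the agents in this repository use exactly this parameterization (dividing the action values by the temperature) is an assumption.

```python
import jax
import jax.numpy as jnp

# Minimal sketch of softmax (Boltzmann) exploration: actions are sampled with
# probability proportional to exp(Q(s, a) / temperature). Lower temperatures
# behave more greedily; the exact parameterization used in the repository is
# an assumption here.
def softmax_action(rng, q_values, temperature):
    return jax.random.categorical(rng, q_values / temperature)

rng = jax.random.PRNGKey(0)
q_values = jnp.array([1.0, 0.5, -0.2, 0.1])  # toy action values
action = softmax_action(rng, q_values, temperature=0.4)
```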