Code for "The Benefits of Model-Based Generalization in Reinforcement Learning"

nemo1120/Model_Generalization_Code_supplement


Code Supplement for "The Benefits of Model-Based Generalization in Reinforcement Learning"

This directory contains code to reproduce the main results from the paper titled "The Benefits of Model-Based Generalization in Reinforcement Learning" by Kenny Young, Aditya Ramesh, Louis Kirsch and Jürgen Schmidhuber.

Model-Based Reinforcement Learning (RL) is widely believed to have the potential to improve sample efficiency by allowing an agent to synthesize large amounts of imagined experience. Experience Replay (ER) can be considered a simple kind of model, which has proved extremely effective at improving the stability and efficiency of deep RL. In principle, a learned parametric model could improve on ER by generalizing from real experience to augment the dataset with additional plausible experience. However, owing to the many design choices involved in empirically successful algorithms, it can be very hard to establish where the benefits are actually coming from. Here, we provide theoretical and empirical insight into when, and how, we can expect data generated by a learned model to be useful. First, we provide a general theorem motivating how learning a model as an intermediate step can narrow down the set of possible value functions more than learning a value function directly from data using the Bellman equation. Second, we provide an illustrative example showing empirically how a similar effect occurs in a more concrete setting with neural network function approximation. Finally, we provide extensive experiments showing the benefit of model-based learning for online RL in environments with combinatorial complexity, but factored structure that allows a learned model to generalize. In these experiments, we take care to control for other factors in order to isolate, insofar as possible, the benefit of using experience generated by a learned model relative to ER alone.

Reproducing the main results

This directory contains scripts to train each of the reinforcement learning agents tested in Figure 4 of the paper. The code is written in Python and requires JAX and Haiku. The training scripts can be run with commands like the following:

python train_ER_DQN.py --seed=0 --config=configs/ER.json --output=ER
python train_latent_model_DQN.py --seed=0 --config=configs/categorical_latent_model.json --output=categorical_latent_model
python train_latent_model_DQN.py --seed=0 --config=configs/gaussian_latent_model.json --output=gaussian_latent_model
python train_perfect_model_DQN.py --seed=0 --config=configs/perfect_model.json --output=perfect_model
python train_simple_model_DQN.py --seed=0 --config=configs/simple_model.json --output=simple_model
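To launch all five agents in sequence, a small driver script along the following lines could be used. This is a hypothetical helper, not part of the repository: the (script, config, output) triples are copied from the commands above, while build_commands and run_all are illustrative names. By default it only prints the commands (a dry run).

```python
import subprocess
import sys

# (script, config, output) triples copied from the example commands above.
AGENTS = [
    ("train_ER_DQN.py", "configs/ER.json", "ER"),
    ("train_latent_model_DQN.py", "configs/categorical_latent_model.json", "categorical_latent_model"),
    ("train_latent_model_DQN.py", "configs/gaussian_latent_model.json", "gaussian_latent_model"),
    ("train_perfect_model_DQN.py", "configs/perfect_model.json", "perfect_model"),
    ("train_simple_model_DQN.py", "configs/simple_model.json", "simple_model"),
]


def build_commands(seed=0, python=sys.executable):
    """Return one argument list per agent, mirroring the commands above."""
    return [
        [python, script, f"--seed={seed}", f"--config={config}", f"--output={output}"]
        for script, config, output in AGENTS
    ]


def run_all(seed=0, dry_run=True):
    """Print each training command and, unless dry_run is set, execute it."""
    for cmd in build_commands(seed):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first script that exits with an error.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(seed=0, dry_run=True)
```

Passing dry_run=False runs the scripts one after another from the current directory, so it should be invoked from the repository root where the configs/ folder lives.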

The included config files reproduce the High Data ProcMaze results with GridSize 4, using the tuned step-size and softmax exploration parameters. Results for the other experiments can be reproduced by modifying the environment and the other hyperparameters in the config file accordingly. For the ButtonGrid and PanFlute experiments, set "episodic_env":false in the config file to indicate that the environment is continuing. For the low data regime, change the following config options:

"target_update_frequency":10
"updates_per_step":10
"num_steps":100000
"eval_frequency":50
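Rather than editing the JSON by hand, these overrides can be applied programmatically. The sketch below assumes only what is stated here: configs are JSON objects, and the keys listed above (plus "episodic_env") live at the top level. make_config and write_variant are hypothetical helpers, and the key that selects the environment is not documented in this README, so switching to ButtonGrid or PanFlute still has to be done separately.

```python
import json

# Low data regime overrides, copied from the list above.
LOW_DATA_OVERRIDES = {
    "target_update_frequency": 10,
    "updates_per_step": 10,
    "num_steps": 100000,
    "eval_frequency": 50,
}


def make_config(base, low_data=False, continuing=False):
    """Return a modified copy of a config dict, leaving `base` untouched.

    low_data:   apply the low data regime overrides.
    continuing: set "episodic_env" to False (ButtonGrid and PanFlute).
    """
    config = dict(base)
    if low_data:
        config.update(LOW_DATA_OVERRIDES)
    if continuing:
        config["episodic_env"] = False
    return config


def write_variant(src_path, dst_path, **kwargs):
    """Load a JSON config, apply make_config, and write the result."""
    with open(src_path) as f:
        base = json.load(f)
    with open(dst_path, "w") as f:
        # json.dump writes Python False as JSON false.
        json.dump(make_config(base, **kwargs), f, indent=2)
```

For example, write_variant("configs/ER.json", "configs/ER_low_data.json", low_data=True) would write a low data variant of the ER config; the destination filename is arbitrary and can then be passed via --config.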

Upon completion, each training script writes evaluation metrics to a file <output>.out, where <output> is the string passed with the --output flag. If "save_params":true is set in the config file, the final trained agent parameters are also saved to <output>.params. The metrics in the resulting <output>.out files can be plotted using the included script plot_metrics.py. For example, to plot the returns from each of the training runs above, run the following command:

python3 plot_metric.py --data ER.out categorical_latent_model.out gaussian_latent_model.out perfect_model.out simple_model.out

Running Multiple Random Seeds in Parallel on a Single GPU

Each script is written so that multiple random repeats of an experiment can run in parallel on a single GPU using automatic batching in JAX. In each file, the main loop is implemented by the function agent_environment_interaction_loop_function, which takes a set of arguments representing the current state of an experiment and returns their updated values along with metrics computed along the way. We apply the JAX function vmap to this function to obtain a version of the interaction loop that processes multiple runs in parallel. The number of parallel seeds is specified by "num_seeds" in the config file and is set to 30 in each of the included config files.
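The batching pattern described above can be illustrated with a toy example. The function toy_interaction_loop below is a hypothetical stand-in for agent_environment_interaction_loop_function: it runs a softmax bandit rather than the paper's DQN agents, but it has the same shape, taking the experiment state as arguments and returning updated values plus a metric. Because it is a pure function of its inputs, jax.vmap can map it over a leading seed axis; num_seeds = 30 mirrors the included configs.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 100  # hypothetical loop length for this toy example


def toy_interaction_loop(key, q_values):
    """Hypothetical stand-in for agent_environment_interaction_loop_function.

    Runs NUM_STEPS of a 3-armed bandit: sample an action from a softmax over
    q_values, observe a noisy reward, and move the chosen action's value
    toward it. Returns the updated state and the mean reward as a metric.
    """
    true_means = jnp.array([0.0, 0.5, 1.0])
    temperature = 0.4
    step_size = 0.1

    def body(i, carry):
        key, q, total_reward = carry
        key, action_key, reward_key = jax.random.split(key, 3)
        action = jax.random.categorical(action_key, q / temperature)
        reward = true_means[action] + 0.1 * jax.random.normal(reward_key)
        q = q.at[action].add(step_size * (reward - q[action]))
        return key, q, total_reward + reward

    init = (key, q_values, jnp.zeros((), dtype=jnp.float32))
    key, q_values, total_reward = jax.lax.fori_loop(0, NUM_STEPS, body, init)
    return key, q_values, total_reward / NUM_STEPS


# vmap maps over axis 0 of every argument, giving one independent run per seed.
batched_loop = jax.jit(jax.vmap(toy_interaction_loop))

num_seeds = 30
keys = jax.random.split(jax.random.PRNGKey(0), num_seeds)
init_q = jnp.zeros((num_seeds, 3))
keys, q_values, mean_rewards = batched_loop(keys, init_q)
print(mean_rewards.shape)  # (30,)
```

Since every run receives its own PRNG key, the runs diverge even though they share the same code and initial values, and the whole batch executes as a single compiled computation on the device.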

Optimal Hyperparameters for Figure 4

The following table lists the optimal hyperparameters for each agent, environment, and data regime combination; these were used to produce Figure 4.

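| Agent | Environment | Data Regime | Step-Size | Softmax Temperature |
|---|---|---|---|---|
| Simple Model | ProcMaze | High Data | 0.0002 | 0.8 |
| Simple Model | ProcMaze | Low Data | 2.5e-05 | 0.4 |
| Simple Model | ButtonGrid | High Data | 0.0002 | 1.6 |
| Simple Model | ButtonGrid | Low Data | 5e-05 | 0.4 |
| Simple Model | PanFlute | High Data | 0.0001 | 0.4 |
| Simple Model | PanFlute | Low Data | 0.0002 | 0.2 |
| Perfect Model | ProcMaze | High Data | 0.0001 | 0.4 |
| Perfect Model | ProcMaze | Low Data | 0.0004 | 0.4 |
| Perfect Model | ButtonGrid | High Data | 0.0002 | 0.0125 |
| Perfect Model | ButtonGrid | Low Data | 0.0002 | 0.0125 |
| Perfect Model | PanFlute | High Data | 0.0002 | 0.1 |
| Perfect Model | PanFlute | Low Data | 0.0001 | 0.1 |
| ER | ProcMaze | High Data | 5e-05 | 0.1 |
| ER | ProcMaze | Low Data | 2.5e-05 | 1.6 |
| ER | ButtonGrid | High Data | 5e-05 | 0.0125 |
| ER | ButtonGrid | Low Data | 2.5e-05 | 0.1 |
| ER | PanFlute | High Data | 0.0001 | 0.1 |
| ER | PanFlute | Low Data | 3.125e-06 | 0.8 |
| Categorical Latent Model | ProcMaze | High Data | 0.0002 | 1.6 |
| Categorical Latent Model | ProcMaze | Low Data | 0.0001 | 0.8 |
| Categorical Latent Model | ButtonGrid | High Data | 0.0001 | 0.05 |
| Categorical Latent Model | ButtonGrid | Low Data | 0.0002 | 0.1 |
| Categorical Latent Model | PanFlute | High Data | 0.0008 | 0.05 |
| Categorical Latent Model | PanFlute | Low Data | 0.0016 | 0.05 |
| Gaussian Latent Model | ProcMaze | High Data | 3.125e-06 | 12.8 |
| Gaussian Latent Model | ProcMaze | Low Data | 3.125e-06 | 0.8 |
| Gaussian Latent Model | ButtonGrid | High Data | 1.25e-05 | 0.1 |
| Gaussian Latent Model | ButtonGrid | Low Data | 2.5e-05 | 0.8 |
| Gaussian Latent Model | PanFlute | High Data | 0.0002 | 0.1 |
| Gaussian Latent Model | PanFlute | Low Data | 0.0002 | 0.1 |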

These hyperparameters resulted from the grid search described in the Experiment Design portion of Section 4 in the paper.
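To give a feel for what the temperature values in the table mean, the sketch below assumes the standard Boltzmann (softmax) exploration form, in which action a is chosen with probability proportional to exp(Q(a) / temperature). The training scripts' actual implementation may differ in detail (for example, it is presumably written in JAX); softmax_policy and sample_action are hypothetical names used only for illustration.

```python
import numpy as np


def softmax_policy(q_values, temperature):
    """Boltzmann action probabilities: pi(a) proportional to exp(Q(a) / temperature)."""
    logits = np.asarray(q_values, dtype=np.float64) / temperature
    logits -= logits.max()  # subtract the max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()


def sample_action(rng, q_values, temperature):
    """Draw one action index from the softmax policy."""
    probs = softmax_policy(q_values, temperature)
    return rng.choice(len(probs), p=probs)


# Compare a few temperatures that appear in the table on the same Q-values.
q = [1.0, 0.5, 0.0]
for tau in (0.0125, 0.4, 1.6):
    print(tau, np.round(softmax_policy(q, tau), 3))
```

Lower temperatures concentrate almost all probability on the greedy action (0.0125 is nearly greedy here), while higher temperatures spread probability across actions and so explore more.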
