This repository contains the code for the paper "Model-Based Diffusion for Trajectory Optimization".
Model-based diffusion (MBD) is a novel diffusion-based trajectory optimization framework that uses a dynamics model to approximate the score function. MBD outperforms existing methods, including reinforcement learning (RL), in sample efficiency and generalization.
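For intuition, the score approximation can be written with a standard identity. This is a sketch under stated assumptions, not the paper's exact formulation or notation: it assumes a DDPM-style forward process with cumulative noise schedule \bar\alpha_i and a target density p_0 over trajectories Y (for example, p_0(Y) ∝ exp(J(Y)/λ) for an objective J evaluated by rolling out the dynamics model). By Tweedie's formula, the score only requires the posterior mean of Y^{(0)}, which can be estimated by sampling and weighting with p_0:

```latex
% Assumed forward (noising) process with cumulative schedule \bar\alpha_i
Y^{(i)} = \sqrt{\bar\alpha_i}\,Y^{(0)} + \sqrt{1-\bar\alpha_i}\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)

% Tweedie's formula: the score depends on Y^{(0)} only through its posterior mean
\nabla_{Y^{(i)}} \log p_i\big(Y^{(i)}\big) = \frac{\sqrt{\bar\alpha_i}\,\mathbb{E}\big[Y^{(0)} \mid Y^{(i)}\big] - Y^{(i)}}{1-\bar\alpha_i}

% Monte Carlo estimate: draw Y^{(0)}_k \sim \mathcal{N}\big(Y^{(i)}/\sqrt{\bar\alpha_i},\,(1/\bar\alpha_i - 1) I\big), k = 1..K,
% then weight each sample by the target density p_0
\mathbb{E}\big[Y^{(0)} \mid Y^{(i)}\big] \approx \bar{Y}^{(0)} = \frac{\sum_{k=1}^{K} p_0\big(Y^{(0)}_k\big)\,Y^{(0)}_k}{\sum_{k=1}^{K} p_0\big(Y^{(0)}_k\big)}
```

In this view, the weights only require evaluating p_0 on sampled trajectories, so the dynamics model takes the place of a score network learned from data.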
To install the required packages, run the following commands:
git clone --depth 1 git@github.com:LeCAR-Lab/model-based-diffusion.git
cd model-based-diffusion
pip install -e .
To optimize a trajectory with model-based diffusion, run the following commands:
cd mbd/planners
python mbd_planner.py --env_name $ENV_NAME
where $ENV_NAME is the name of the environment; you can choose from hopper, halfcheetah, walker2d, ant, humanoidrun, humanoidstandup, humanoidtrack, car2d, or pushT.
To run model-based diffusion with demonstrations, run the following commands:
cd mbd/planners
python mbd_planner.py --env_name $ENV_NAME --enable_demos
Currently, only humanoidtrack and car2d support demonstrations.
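If you script many runs, a small wrapper can check the constraints listed above before launching the planner. The helper below, planner_command, is hypothetical and not part of the repository; it only encodes the environment names and the demonstration support stated in this README, and builds the argument list for mbd_planner.py:

```python
import shlex
from typing import List

# Environment names listed in this README.
ENV_NAMES = {"hopper", "halfcheetah", "walker2d", "ant", "humanoidrun",
             "humanoidstandup", "humanoidtrack", "car2d", "pushT"}
# Environments that currently support --enable_demos, per this README.
DEMO_ENVS = {"humanoidtrack", "car2d"}


def planner_command(env_name: str, enable_demos: bool = False) -> List[str]:
    """Build the argv for mbd_planner.py (hypothetical helper; run from mbd/planners)."""
    if env_name not in ENV_NAMES:
        raise ValueError(f"unknown environment: {env_name!r}")
    if enable_demos and env_name not in DEMO_ENVS:
        raise ValueError(f"{env_name!r} does not support demonstrations")
    cmd = ["python", "mbd_planner.py", "--env_name", env_name]
    if enable_demos:
        cmd.append("--enable_demos")
    return cmd


if __name__ == "__main__":
    # Print the shell command for every environment, then the demo-enabled variants.
    for name in sorted(ENV_NAMES):
        print(shlex.join(planner_command(name)))
    for name in sorted(DEMO_ENVS):
        print(shlex.join(planner_command(name, enable_demos=True)))
```

To actually launch a run from the repository root, the list can be passed to the standard library, e.g. subprocess.run(planner_command("car2d", enable_demos=True), cwd="mbd/planners", check=True).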
To run model-based diffusion over multiple random seeds, run the following commands:
cd mbd/scripts
python run_mbd.py --env_name $ENV_NAME
To visualize the diffusion process, run the following commands:
cd mbd/scripts
python vis_diffusion.py --env_name $ENV_NAME
Please make sure you have run the planner first to generate the data.
To run model-based diffusion for black-box optimization, run the following commands:
cd mbd/blackbox
python mbd_opt.py
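As a rough, self-contained illustration of the idea behind model-based diffusion on a black-box objective, here is a plain-Python sketch. It is not the code in mbd/blackbox/mbd_opt.py: the function mbd_maximize, the linear noise schedule, the sample count, and the temperature are all illustrative assumptions. At each reverse step it samples candidates around the current noisy iterate, weights them by exp(objective / temperature), uses the weighted mean as a Monte Carlo estimate of E[Y^(0) | Y^(i)], and plugs that estimate into Tweedie's formula to obtain the score:

```python
import math
import random
from typing import Callable, List


def mbd_maximize(
    objective: Callable[[List[float]], float],
    dim: int,
    num_steps: int = 100,
    num_samples: int = 256,
    temperature: float = 0.1,
    seed: int = 0,
) -> List[float]:
    """Hypothetical sketch of model-based diffusion for black-box maximization.

    Assumes a target density p_0(y) proportional to exp(objective(y) / temperature).
    """
    rng = random.Random(seed)
    # Linear beta schedule (an illustrative choice); alpha_bars[i] = prod_{j<=i} (1 - beta_j).
    betas = [1e-4 + (0.1 - 1e-4) * i / (num_steps - 1) for i in range(num_steps)]
    alpha_bars = []
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bars.append(prod)

    # Start at the noisiest level; alpha_bars[-1] is close to 0, so the iterate is ~N(0, I).
    y = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for i in reversed(range(num_steps)):
        ab = alpha_bars[i]
        alpha = 1.0 - betas[i]
        # Proposal N(y / sqrt(ab), (1/ab - 1) I): the Gaussian factor of p(Y^(0) | Y^(i)).
        mean = [v / math.sqrt(ab) for v in y]
        std = math.sqrt(1.0 / ab - 1.0)
        samples = [[m + std * rng.gauss(0.0, 1.0) for m in mean] for _ in range(num_samples)]
        # Importance weights proportional to p_0; subtracting the max avoids overflow.
        values = [objective(s) for s in samples]
        best = max(values)
        weights = [math.exp((v - best) / temperature) for v in values]
        total = sum(weights)
        # Self-normalized Monte Carlo estimate of E[Y^(0) | Y^(i)].
        y0_bar = [sum(w * s[d] for w, s in zip(weights, samples)) / total for d in range(dim)]
        # Tweedie's formula gives the score estimate at this noise level.
        score = [(math.sqrt(ab) * yb - v) / (1.0 - ab) for yb, v in zip(y0_bar, y)]
        # Reverse update; algebraically equal to sqrt(alpha_bar_{i-1}) * y0_bar (alpha_bar_{-1} = 1).
        y = [(v + (1.0 - ab) * s) / math.sqrt(alpha) for v, s in zip(y, score)]
    return y
```

For example, mbd_maximize(lambda y: -sum((v - 1.5) ** 2 for v in y), dim=2) should return a point close to [1.5, 1.5]. The reverse update chosen here adds no fresh noise, so the sketch is random only through its Monte Carlo samples; reverse updates that inject noise at each step are an equally valid design choice.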
To run the RL baselines, run the following commands:
cd mbd/rl
python train_brax.py --env_name $ENV_NAME
To run the other zeroth-order trajectory optimization baselines, run the following commands:
cd mbd/planners
python path_integral.py --env_name $ENV_NAME --mode $MODE
where $MODE is the planner mode; you can choose from mppi, cem, or cma-es.
- This codebase's environments and RL implementations are built on top of Brax.
@misc{pan2024modelbaseddiffusiontrajectoryoptimization,
title={Model-Based Diffusion for Trajectory Optimization},
author={Chaoyi Pan and Zeji Yi and Guanya Shi and Guannan Qu},
year={2024},
eprint={2407.01573},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2407.01573},
}