This repo contains the image diffusion models and training/sampling code for our paper exploring the Combinatorial Stochasticity for Diffusion Generative Models. [Project] [Arxiv]
@misc{xu2024combostoccombinatorialstochasticitydiffusion,
title={ComboStoc: Combinatorial Stochasticity for Diffusion Generative Models},
author={Rui Xu and Jiepeng Wang and Hao Pan and Yang Liu and Xin Tong and Shiqing Xin and Changhe Tu and Taku Komura and Wenping Wang},
year={2024},
eprint={2405.13729},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2405.13729},
}
In this paper, we study an under-explored but important factor of diffusion generative models, i.e., the combinatorial complexity. Data samples are generally high-dimensional, and for various structured generation tasks, additional attributes are combined to associate with data samples. We show that the space spanned by the combination of dimensions and attributes is insufficiently sampled by existing training scheme of diffusion generative models, causing degraded test time performance. We present a simple fix to this problem by constructing stochastic processes that fully exploit the combinatorial structures, hence the name ComboStoc. Using this simple strategy, we show that network training is significantly accelerated across diverse data modalities, including images and 3D structured shapes. Moreover, ComboStoc enables a new way of test time generation which uses asynchronous time steps for different dimensions and attributes, thus allowing for varying degrees of control over them.
This repository contains:
- 🪐 A simple PyTorch implementation of ComboStoc following SiT
- 🛸 A ComboStoc training script using PyTorch DDP
- 📫 A pre-trained ComboStoc-XL-2 model with 'UNSYNC_ALL' setting at 800K. [Google Drive]
We provide an environment.yml
file that can be used to create a Conda environment. If you only want
to run pre-trained models locally on CPU, you can remove the cudatoolkit
and pytorch-cuda
requirements from the file.
conda env create -f environment.yml
conda activate ComboStoc
You can sample from ComboStoc checkpoints with sample.py
. The script has various arguments to adjust sampler configurations (ODE & SDE), sampling steps, change the classifier-free guidance scale, etc. For example, to sample from
our 256x256 ComboStoc-XL model with default ODE setting, you can use:
python sample.py SDE --model ComboStoc-XL/2 --image-size 256 --ckpt /path/to/model.pt
ODE | --atol |
float |
Absolute error tolerance |
--rtol |
float |
Relative error tolenrace | |
--sampling-method |
str |
Sampling methods (refer to [torchdiffeq ] ) |
SDE | --diffusion-form |
str |
Form of SDE's diffusion coefficient (refer to Tab. 2 in paper) |
--diffusion-norm |
float |
Magnitude of SDE's diffusion coefficient | |
--last-step |
str |
Form of SDE's last step | |
None - Single SDE integration step | |||
"Mean" - SDE integration step without diffusion coefficient | |||
"Tweedie" - [Tweedie's denoising] step | |||
"Euler" - Single ODE integration step | |||
--sampling-method |
str |
Sampling methods | |
"Euler" - First order integration | |||
"Heun" - Second order integration |
There are some more options; refer to train_utils.py
for details.
We provide a training script for ComboStoc in train.py
. To launch ComboStoc-XL/2 (256x256) training with N
GPUs on
one node:
torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-XL/2 --data-path /path/to/imagenet/train
Logging. To enable wandb
, firstly set WANDB_KEY
, ENTITY
, and PROJECT
as environment variables:
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
Then in training command add the --wandb
flag:
torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-XL/2 --data-path /path/to/imagenet/train --wandb
Resume training. To resume training from custom checkpoint:
torchrun --nnodes=1 --nproc_per_node=N train.py --model ComboStoc-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt
Caution. Resuming training will automatically restore both model, EMA, and optimizer states and training configs to be the same as in the checkpoint.
We include a sample_ddp.py
script which samples a large number of images from a ComboStoc model in parallel. This script
generates a folder of samples as well as a .npz
file which can be directly used with [PyTorch-FID] to compute FID, Inception Score and
other metrics. For example, to sample 50K images from our pre-trained ComboStoc-XL/2 model over N
GPUs under default ODE sampler settings, run:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py SDE --model ComboStoc-XL/2 --num-fid-samples 50000