Taehwan Lee*, Kyeongkook Seo*, Jaejun Yoo**, Sung Whan Yoon**
(*: Equal contribution, **: Co-corresponding author)
Flat minima, known to enhance generalization and robustness in supervised learning, remain largely unexplored in generative models. In this work, we systematically investigate the role of loss surface flatness in generative models, both theoretically and empirically, with a particular focus on diffusion models. We establish a theoretical claim that flatter minima improve robustness against perturbations in target prior distributions, leading to benefits such as reduced exposure bias -- where errors in noise estimation accumulate over iterations -- and significantly improved resilience to model quantization, preserving generative performance even under strong quantization constraints. We further observe that Sharpness-Aware Minimization (SAM), which explicitly controls the degree of flatness, effectively enhances flatness in diffusion models, whereas other well-known methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA), which promote flatness indirectly via ensembling, are less effective. Through extensive experiments on CIFAR-10, LSUN Tower, and FFHQ, we demonstrate that flat minima in diffusion models indeed improve not only generative performance but also robustness.
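As a rough intuition for the exposure-bias claim above (a toy sketch, not taken from this codebase): if the network makes a small, consistent noise-estimation error at each reverse step, that error compounds over the sampling trajectory. The numbers below (per_step_err, decay, T) are purely illustrative.

```python
# Toy illustration (not from this repo): a small, consistent noise-estimation
# error at every reverse step compounds over the sampling trajectory.
T = 100               # number of reverse steps (illustrative)
per_step_err = 1e-3   # hypothetical error injected at each step
decay = 0.99          # hypothetical factor by which past error is carried forward

accumulated = 0.0
for _ in range(T):
    accumulated = decay * accumulated + per_step_err

print(f"error after one step: {per_step_err:.4f}")
print(f"error after {T} steps: {accumulated:.4f}")  # much larger than one step
```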
This codebase is based on ADM-ES and openai/guided-diffusion.
The installation is the same as for guided-diffusion:
git clone https://github.com/forever208/DDPM-IP.git
cd DDPM-IP
conda create -n ADM python=3.8
conda activate ADM
pip install -e .
(Note that PyTorch 1.10~1.13 is recommended, as the experiments in the paper were done with PyTorch 1.10; PyTorch 2.0 has not been tested in this repo.)
# install the missing packages
conda install mpi4py
conda install numpy
pip install Pillow
pip install opencv-python
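After installation, a quick import check (optional, not part of the repo) can confirm that the main dependencies are visible to Python:

```python
# Optional sanity check for the environment (suggested, not part of this repo).
import torch
import mpi4py
import PIL
import cv2

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("mpi4py:", mpi4py.__version__, "| Pillow:", PIL.__version__, "| OpenCV:", cv2.__version__)
```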
Please refer to README.md for the detailed data preparation.
CIFAR10 32x32 (uncond)
MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule cosine"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --end_step 200000"
DATA_DIR="/DATA-PATH"
LSUN Tower 64x64 (uncond)
MODEL_FLAGS="--image_size 64 --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.1"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --rescale_learned_sigmas True"
TRAIN_FLAGS="--lr 1e-4 --batch_size 32 --end_step 200000"
DATA_DIR="/DATA-PATH"
FFHQ 64x64 (uncond)
MODEL_FLAGS="--image_size 64 --class_cond False --num_channels 128 --num_res_blocks 3 --attention_resolutions 32,16,8 --resblock_updown True --use_new_attention_order True --learn_sigma True --dropout 0.1 --use_fp16 False"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule cosine --use_scale_shift_norm True --rescale_learned_sigmas True"
TRAIN_FLAGS="--lr 1e-4 --batch_size 64 --end_step 200000"
DATA_DIR="/DATA-PATH"
w/ SAM
OPENAI_LOGDIR='./Logs/EXP-NAME' PYTHONPATH='.' CUDA_VISIBLE_DEVICES=0,1 \
mpirun -n 2 python scripts/image_train.py --data_dir $DATA_DIR $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS \
--optimizer adam-sam --rho 0.1
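For reference, the generic SAM update behind the `--optimizer adam-sam --rho 0.1` flags takes an ascent step of radius rho in the gradient direction before letting the base optimizer apply its update. The sketch below shows this two-step procedure in plain PyTorch; it is a minimal illustration of the general technique under stated assumptions, not the implementation used in this repo.

```python
# Minimal sketch of a generic SAM step wrapped around a base optimizer (e.g., Adam).
# Illustrates the idea behind --optimizer adam-sam --rho 0.1; the actual
# implementation in this repo may differ.
import torch

def sam_step(model, loss_fn, base_opt, rho=0.1):
    # Assumes loss_fn(model) returns a scalar loss and every trainable
    # parameter receives a gradient.
    params = [p for p in model.parameters() if p.requires_grad]

    # 1) ascent step: gradient at the current weights, then move distance rho along it
    loss_fn(model).backward()
    grad_norm = torch.norm(torch.stack([p.grad.norm() for p in params])) + 1e-12
    with torch.no_grad():
        eps = [rho * p.grad / grad_norm for p in params]
        for p, e in zip(params, eps):
            p.add_(e)
    base_opt.zero_grad()

    # 2) descent step: gradient at the perturbed weights, then restore the
    #    original weights and let the base optimizer apply the update
    loss_fn(model).backward()
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)
    base_opt.step()
    base_opt.zero_grad()
```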
w/ SWA (requires intermediate checkpoints, e.g., model180000.pt)
OPENAI_LOGDIR='./Logs/EXP-NAME' PYTHONPATH='.' CUDA_VISIBLE_DEVICES=0,1 \
mpirun -n 2 python scripts/image_train.py --optimizer adam --data_dir $DATA_DIR \
$MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS \
--end_step 200000 \
--lr_anneal_steps 1 --use_swa True --swa_window 100 \
--resume_checkpoint $CHKPT_PATH --swa_lr_min 1e-6
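The `--use_swa` run above builds a flat-minimum candidate by averaging weights along the tail of training. Conceptually this is the running average behind SWA; the sketch below simply averages a few saved checkpoints to convey that idea. The checkpoint paths and state_dict layout are hypothetical, and the repo's `--use_swa` option handles the averaging internally during training.

```python
# Illustrative sketch: averaging intermediate checkpoints into an SWA model.
# Checkpoint paths and state_dict layout are hypothetical.
import torch

ckpt_paths = ["./Logs/EXP-NAME/model180000.pt",
              "./Logs/EXP-NAME/model190000.pt",
              "./Logs/EXP-NAME/model200000.pt"]

avg_state = None
for i, path in enumerate(ckpt_paths, start=1):
    state = torch.load(path, map_location="cpu")
    if avg_state is None:
        avg_state = {k: v.clone().float() for k, v in state.items()}
    else:
        # running mean: avg <- avg + (w - avg) / i
        for k in avg_state:
            avg_state[k] += (state[k].float() - avg_state[k]) / i

torch.save(avg_state, "./Logs/EXP-NAME/model_swa.pt")
```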
SAMPLE_FLAGS="--batch_size 100 --num_samples 50000 --timestep_respacing 100 --eps_scaler 1"
CHKPT_PATH="MODEL-PATH"
You can adjust timestep_respacing; using fewer timesteps makes the noise estimation harder.
OPENAI_LOGDIR='./Logs/EXP-NAME/SAMPLE' PYTHONPATH='.' CUDA_VISIBLE_DEVICES=0,1 \
mpirun -n 2 python scripts/image_sample.py \
--model_path $CHKPT_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS
Optionally, you can run image_sample_quant.py for quantized sampling.
The command is the same as in the original sampling case, except for the additional --quant 4 (or 8) flag.
OPENAI_LOGDIR='./Logs/EXP-NAME/SAMPLE' PYTHONPATH='.' CUDA_VISIBLE_DEVICES=0,1 \
mpirun -n 2 python scripts/image_sample.py --quant 8 \
--model_path $CHKPT_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $SAMPLE_FLAGS
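As a rough picture of what k-bit uniform weight quantization does (not necessarily how `image_sample_quant.py` implements `--quant 4/8`), one can fake-quantize each weight tensor as in the sketch below.

```python
# Rough sketch of k-bit uniform (fake) weight quantization, per tensor.
# Only meant to convey the idea behind --quant 4/8; the repo's
# image_sample_quant.py may use a different scheme.
import torch

def fake_quantize(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    qmax = 2 ** (bits - 1) - 1                       # symmetric signed range
    scale = w.abs().max() / qmax + 1e-12
    return torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale

@torch.no_grad()
def quantize_model(model, bits: int = 8):
    for p in model.parameters():
        p.copy_(fake_quantize(p, bits))
    return model
```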
Please refer to README.md for the full instructions.
CHKPT_PATH="MODEL-PATH"
DATA_DIR="DATA-PATH"
PYTHONPATH='.' CUDA_VISIBLE_DEVICES=0 python scripts/flatness.py \
--data_dir $DATA_DIR $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS \
--flatness {MC, LPF} --model_path $CHKPT_PATH --t_range 1000
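As a point of reference for the `--flatness MC` option: a Monte-Carlo flatness estimate typically perturbs the weights with random noise of a fixed radius and measures how much the loss increases on average. The sketch below conveys that idea only; the exact procedure (and the LPF variant) is defined in `scripts/flatness.py`.

```python
# Sketch of a Monte-Carlo flatness proxy: average loss increase under random
# weight perturbations of radius sigma. Illustrative only; see scripts/flatness.py
# for the procedure actually used.
import copy
import torch

@torch.no_grad()
def mc_flatness(model, loss_fn, sigma=0.01, n_samples=8):
    # Assumes loss_fn(model) returns a scalar loss tensor.
    base_loss = loss_fn(model)
    increases = []
    for _ in range(n_samples):
        noisy = copy.deepcopy(model)
        for p in noisy.parameters():
            p.add_(sigma * torch.randn_like(p))
        increases.append(loss_fn(noisy) - base_loss)
    return torch.stack(increases).mean()   # smaller => flatter around the minimum
```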
@article{lee2025understanding,
title={Understanding Flatness in Generative Models: Its Role and Benefits},
author={Lee, Taehwan and Seo, Kyeongkook and Yoo, Jaejun and Yoon, Sung Whan},
journal={arXiv preprint arXiv:2503.11078},
year={2025}
}