Official implementation of MADFormer, including training, sampling, and evaluation scripts as described in "MADFormer: Mixed Autoregressive and Diffusion Transformers for Continuous Image Generation".


MADFormer: Mixed Autoregressive and Diffusion Transformers for Continuous Image Generation

This repository contains the official implementation of MADFormer, a unified generative model that fuses the global modeling of autoregressive transformers with the fine-grained refinement capabilities of diffusion models. MADFormer introduces a flexible, two-axis hybrid framework—mixing AR and diffusion across spatial blocks and model layers—delivering strong performance under compute constraints while maintaining high visual fidelity across image generation tasks.

Overview

MADFormer (Mixed Autoregressive and Diffusion Transformer) bridges the strengths of autoregressive (AR) and diffusion modeling in continuous image generation. It predicts M image patches at a time and uses the predicted patches as gold context for predicting the next patches. During each multi-patch prediction, the first N transformer layers act as AR layers that make a single-pass prediction, while the remaining layers act as diffusion layers that are applied recursively to refine it. Users can choose how "AR" or "diffusion" the model is along two axes: the horizontal token axis (by adjusting M) and the vertical model-layer axis (by adjusting N). One interesting finding from our paper is that models with a stronger AR presence perform better under constrained inference budgets.
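As a rough illustration of the two axes, the sketch below estimates per-image layer invocations under a simple cost model. This is our own back-of-the-envelope accounting, not code from the repository; the function name and parameters are assumptions. It reflects the description above: AR layers run once per block, while diffusion layers are called once per refinement step per block.

```python
def forward_passes(total_patches, M, total_layers, N, diffusion_steps):
    """Estimate layer invocations to generate one image (illustrative cost model).

    total_patches:   number of image patches to generate
    M:               patches predicted per block (horizontal axis)
    total_layers:    transformer depth
    N:               leading layers used as one-pass AR layers (vertical axis)
    diffusion_steps: recursive refinement calls per block
    """
    num_blocks = total_patches // M
    ar_cost = N                                        # one pass per block
    diffusion_cost = (total_layers - N) * diffusion_steps  # called recursively
    return num_blocks * (ar_cost + diffusion_cost)

# Shifting layers from diffusion to AR (larger N) cuts repeated
# diffusion-layer calls, so cost drops under a fixed step budget:
ar_heavy = forward_passes(total_patches=256, M=16, total_layers=24, N=12, diffusion_steps=50)
diff_heavy = forward_passes(total_patches=256, M=16, total_layers=24, N=4, diffusion_steps=50)
assert ar_heavy < diff_heavy
```

This is why a stronger AR presence helps under tight inference budgets: AR layers amortize to one pass per block, while every diffusion layer pays its cost at every refinement step.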

High-level overview of the MADFormer architecture.

The generation process of MADFormer: each image block is autoregressively predicted, then refined through a conditioned diffusion process.

MADFormer serves not only as a performant generator for high-resolution data such as FFHQ-1024 and standard-resolution images such as ImageNet-256, but also as a testbed for exploring hybrid design choices. Notably, we show that increasing the AR layer allocation can improve FID by up to 60–75% under constrained inference budgets. Our modular design supports controlled experiments on inference cost, block granularity, loss objectives, and layer allocation, offering actionable insights for hybrid model design in multimodal generation.

Setup

To set up the runtime environment for this project, install the required dependencies using the provided requirements.txt file:

pip install -r requirements.txt

Training

To train MADFormer on FFHQ-1024, first download the dataset locally using Hugging Face datasets:

python -c "from datasets import load_dataset; load_dataset('gaunernst/ffhq-1024-wds', num_proc=24).save_to_disk('./datasets/ffhq-1024')"

Our training configurations are provided in the configs directory, complete with model and training hyperparameters. You can use the following command to start training (arguments are set to reproduce FFHQ-1024 baseline results by default):

torchrun \
    --rdzv_backend c10d \
    --rdzv_id=456 \
    --nproc-per-node=8 \
    --nnodes=1 \
    --node_rank=0 \
    --rdzv-endpoint=<rdvz_endpoint>  \
    src/train.py --id=<experiment_id>

Sampling

We provide the pretrained model weights for MADFormer trained on FFHQ-1024. You can download the checkpoint using this link or the CLI command below:

mkdir -p ./ckpts/madformer_ffhq_baseline/

huggingface-cli download JunhaoC/MADFormer-FFHQ \
    --include ckpts.pt \
    --local-dir ./ckpts/madformer_ffhq_baseline/ \
    --local-dir-use-symlinks False

Once downloaded (or after training your own checkpoint), you can sample images with:

python src/sample.py \
    --ckpt ./ckpts/madformer_ffhq_baseline/ckpts.pt \
    --range_start 0 --range_end 7
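The `--range_start`/`--range_end` flags appear to select which sample indices a given run generates (our reading of the example above, not documented here). Under that assumption, a full evaluation set can be sharded into per-worker index ranges with a small helper; the function below is a hypothetical sketch, not part of the repository:

```python
def shard_ranges(num_samples, num_workers):
    """Split [0, num_samples) into contiguous per-worker index ranges."""
    per_worker = num_samples // num_workers
    ranges = []
    for w in range(num_workers):
        start = w * per_worker
        # Last worker absorbs any remainder.
        end = num_samples if w == num_workers - 1 else start + per_worker
        ranges.append((start, end))
    return ranges

# e.g. 8,000 FID samples over 8 workers:
print(shard_ranges(8000, 8)[0])   # -> (0, 1000)
print(shard_ranges(8000, 8)[-1])  # -> (7000, 8000)
```

Each tuple could then be passed to one `sample.py` invocation via `--range_start`/`--range_end`, assuming those flags take inclusive-start/exclusive-end indices.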

Evaluation

We adopt Fréchet Inception Distance (FID) as our primary evaluation metric for image quality. For FFHQ-1024, FID is computed over 8,000 generated samples. Image generation is performed with the DDIM sampler, using 250 sampling steps for FFHQ. To ensure stability, final FID scores are averaged across the last five checkpoints (saved every 10,000 steps).
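The checkpoint-averaging step of this protocol can be sketched as follows. The helper and the per-checkpoint scores are hypothetical stand-ins (real scores would come from an FID computation such as pytorch-fid over each checkpoint's samples):

```python
def final_fid(fid_by_step, last_k=5):
    """Average FID over the `last_k` most recent checkpoints.

    fid_by_step: dict mapping training step -> FID of that checkpoint's samples
    """
    last_steps = sorted(fid_by_step)[-last_k:]
    return sum(fid_by_step[s] for s in last_steps) / len(last_steps)

# Hypothetical FID per checkpoint, saved every 10,000 steps:
scores = {10_000: 30.0, 20_000: 25.0, 30_000: 22.0,
          40_000: 21.0, 50_000: 20.5, 60_000: 20.0}
print(final_fid(scores))  # mean over steps 20k..60k -> 21.7
```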

FID scores are computed using the pytorch-fid library.

Acknowledgements

This code is mainly built upon the ACDiT repository.

License

This project is licensed under the Apache-2.0 license.

Citation

If you find MADFormer useful in your research, please consider citing our paper:

@article{MADFormer,
    title={MADFormer: Mixed Autoregressive and Diffusion Transformers for Continuous Image Generation}, 
    author={Junhao Chen and Yulia Tsvetkov and Xiaochuang Han},
    journal={arXiv preprint arXiv:2506.07999},
    year={2025}
}
