PyTorch implementation of Patch-I2SB, a combination of I2SB and Patch Diffusion for fast and data-efficient image-to-image translation.
- The data corruption code used for restoration tasks is removed; this code targets general image-to-image translation tasks only.
- The dataloader is changed from LMDB to a customized dataloader: use `dataset/dataloader.py`, adapted from stylegan2-ada-pytorch's dataloader. It also supports the CT DICOM file extension; other medical imaging modalities will be supported later (a minimal DICOM-loading sketch follows this list).
- An `Image512Net` class is added in `network.py` for 512 x 512 resolution image training.
- The evaluation metric for validation is changed from accuracy to RMSE, PSNR, and SSIM (see the metrics sketch below).
- The Patch Diffusion technique is applied; enable it by adding the flag `--run-patch` to the training options.
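Since the customized dataloader reads CT DICOM files, here is a minimal sketch of the kind of DICOM-to-tensor conversion involved, assuming `pydicom` is installed; the helper name, window bounds, and rescale handling are illustrative, not this repo's API:

```python
# Minimal sketch: DICOM -> float tensor in [-1, 1] (illustrative, not the repo's API).
import numpy as np
import pydicom
import torch

def load_dicom_as_tensor(path, lo=-1000.0, hi=1000.0):
    ds = pydicom.dcmread(path)
    img = ds.pixel_array.astype(np.float32)
    # Convert raw values to Hounsfield units when rescale tags are present.
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    img = img * slope + intercept
    # Clip to an illustrative window and map to [-1, 1], the model's input range.
    img = np.clip(img, lo, hi)
    img = 2.0 * (img - lo) / (hi - lo) - 1.0
    return torch.from_numpy(img).unsqueeze(0)  # (1, H, W)
```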
For general tasks, the I2SB authors recommend adding the flag `--cond-x1` to the training options to overcome the large information loss in the new priors.
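For reference, here is a minimal sketch of the three validation metrics on images in [-1, 1]; `eval_metrics` is a hypothetical helper that delegates SSIM to scikit-image, not necessarily how this repo computes them:

```python
# Minimal sketch of RMSE, PSNR, and SSIM for images in [-1, 1] (illustrative helper).
import numpy as np
from skimage.metrics import structural_similarity

def eval_metrics(pred, target):
    # pred, target: float arrays in [-1, 1], shape (H, W) or (H, W, C)
    mse = np.mean((pred - target) ** 2)
    rmse = float(np.sqrt(mse))
    data_range = 2.0  # max - min for images in [-1, 1]
    psnr = float(10.0 * np.log10(data_range ** 2 / mse))
    ssim = structural_similarity(
        pred, target, data_range=data_range,
        channel_axis=-1 if pred.ndim == 3 else None)
    return rmse, psnr, ssim
```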
This code is developed with Python 3, and we recommend PyTorch >= 1.11. Install the other packages in `requirements.txt` with:

```bash
pip install -r requirements.txt
```
Use the flag `--dataset-dir $DATA_DIR` to specify the dataset directory, and the flags `--src $SRC` and `--trg $TRG` to specify the corrupt and clean data folder names. Images should be normalized to [-1, 1]; a minimal normalization sketch follows the layout below. All training and sampling results will be stored in `results/`. The overall file structure is:
```
$DATA_DIR/                              # dataset directory
├── train/                              # train folder
│   ├── $SRC/                           # corrupt data folder name (--src $SRC)
│   │   ├── .../                        # sub-folder
│   │   │   ├── 0001.png                # image file
│   │   │   ├── 0002.png
│   │   │   └── 0003.png
│   │   ├── ...
│   │   └── ...
│   └── $TRG/                           # clean data folder name (--trg $TRG)
│       ├── ...
│       ├── ...
│       └── ...
├── valid/                              # valid folder
└── test/                               # test folder

results/
├── $NAME/                              # experiment ID set in train.py --name $NAME
│   ├── $NUM_ITR.pt                     # latest checkpoint: network, ema, optimizer
│   ├── options.pkl                     # full training options
│   └── samples_nfe$NFE_iter$NUM_ITR/   # images reconstructed from sample.py --nfe $NFE --num-itr $NUM_ITR
│       └── recon.pt
├── ...
```
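Because inputs must lie in [-1, 1], here is a minimal normalization sketch using torchvision; the transform composition and file path are illustrative, and the repo's dataloader may normalize differently:

```python
# Minimal sketch: map uint8 images in [0, 255] to float tensors in [-1, 1].
from PIL import Image
import torchvision.transforms as T

to_model_range = T.Compose([
    T.ToTensor(),                        # uint8 HWC [0, 255] -> float CHW [0, 1]
    T.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]; broadcasts over channels
])

x = to_model_range(Image.open("0001.png"))  # placeholder path; tensor in [-1, 1]
```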
To train a Patch-I2SB model on a single node, run:

```bash
python train.py --name $NAME --n-gpu-per-node $N_GPU \
    --src $SRC --trg $TRG --dataset-dir $DATA_DIR \
    --batch-size $BATCH --microbatch $MICRO_BATCH [--ot-ode] \
    --beta-max $BMAX --log-dir $LOG_DIR [--log-writer $LOGGER] [--run-patch]
```
where `NAME` is the experiment ID, `N_GPU` is the number of GPUs on each node, `DATA_DIR` is the path to the aligned dataset, and `BMAX` determines the noise scheduling. The default training on a 32GB V100 GPU uses `BATCH=256` and `MICRO_BATCH=2`. If your GPUs have less than 32GB of memory, consider lowering `MICRO_BATCH` (see the micro-batching sketch below) or using a smaller network.
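The micro-batch option behaves like standard gradient accumulation: the full batch is processed in small chunks whose gradients are summed before a single optimizer step. A self-contained toy sketch under that assumption (the network, shapes, and loss are illustrative; see train.py for the actual loop):

```python
# Toy sketch of gradient accumulation: a 256-image effective batch
# processed in micro-batches of 2 before one optimizer step.
import torch

net = torch.nn.Conv2d(1, 1, 3, padding=1)          # stand-in for the real network
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
batch, micro = 256, 2

x1 = torch.randn(batch, 1, 64, 64)  # corrupt images (prior samples)
x0 = torch.randn(batch, 1, 64, 64)  # clean targets

optimizer.zero_grad()
for x1_m, x0_m in zip(x1.split(micro), x0.split(micro)):
    loss = torch.nn.functional.mse_loss(net(x1_m), x0_m)
    (loss * micro / batch).backward()  # accumulate gradients averaged over the full batch
optimizer.step()
```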
Add `--ot-ode` to optionally train an OT-ODE model, i.e., the limit when the diffusion vanishes. By default, the model is discretized into 1000 steps; you can change this by adding `--interval $INTERVAL`.
Note that we initialize the network with ADM (`256x256_diffusion_uncond.pt` and `512x512_diffusion_uncond.pt`), which will be downloaded automatically to `data/` on the first call.
Images and losses can be logged with either TensorBoard (`LOGGER="tensorboard"`) or W&B (`LOGGER="wandb"`) in the directory `LOG_DIR`. To log in to W&B automatically, additionally specify the flags `--wandb-api-key $WANDB_API_KEY --wandb-user $WANDB_USER`, where `WANDB_API_KEY` is the unique API key (about 40 characters) of your account and `WANDB_USER` is your user name.
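For example, a training run that logs to W&B might look like this (all `$VAR` values are placeholders):

```bash
python train.py --name $NAME --n-gpu-per-node $N_GPU \
    --src $SRC --trg $TRG --dataset-dir $DATA_DIR \
    --batch-size $BATCH --microbatch $MICRO_BATCH \
    --beta-max $BMAX --log-dir $LOG_DIR --log-writer wandb \
    --wandb-api-key $WANDB_API_KEY --wandb-user $WANDB_USER
```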
To resume previous training from a checkpoint, add the flag `--ckpt $CKPT`. To run patch-based training, add the flag `--run-patch`.
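For example, resuming a patch-based run from a saved checkpoint might look like this (whether the remaining flags must be repeated depends on train.py; the values are placeholders):

```bash
python train.py --name $NAME --n-gpu-per-node $N_GPU \
    --src $SRC --trg $TRG --dataset-dir $DATA_DIR \
    --ckpt $CKPT --run-patch
```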
If you use this code, please consider citing:

```bibtex
@article{liu2023i2sb,
  title={I{$^2$}SB: Image-to-Image Schr{\"o}dinger Bridge},
  author={Liu, Guan-Horng and Vahdat, Arash and Huang, De-An and Theodorou, Evangelos A and Nie, Weili and Anandkumar, Anima},
  journal={arXiv preprint arXiv:2302.05872},
  year={2023}
}

@article{wang2023patch,
  title={Patch Diffusion: Faster and More Data-Efficient Training of Diffusion Models},
  author={Wang, Zhendong and Jiang, Yifan and Zheng, Huangjie and Wang, Peihao and He, Pengcheng and Wang, Zhangyang and Chen, Weizhu and Zhou, Mingyuan},
  journal={arXiv preprint arXiv:2304.12526},
  year={2023}
}
```
This code is heavily based on I2SB and Patch Diffusion. `dataloader.py` is inspired by stylegan2-ada-pytorch's `dataset.py`.
Copyright © 2023, NVIDIA Corporation. All rights reserved.
This work is made available under the NVIDIA Source Code License-NC.
The model checkpoints are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing.