
Pytorch implementation of the paper "A Wavelet Diffusion GAN for Image Super-Resolution"


A Wavelet Diffusion GAN for Image Super-Resolution

Accepted at WIRN 2024. Presentation.

Abstract

In recent years, Diffusion Models have emerged as a superior alternative to Generative Adversarial Networks (GANs) for high-fidelity image generation, with wide applications in text-to-image generation, image-to-image translation, and super-resolution. However, their real-time feasibility is hindered by slow training and inference speeds. This study addresses this challenge by proposing a wavelet-based conditional Diffusion GAN scheme for Single-Image Super-Resolution (SISR). Our approach utilizes the Diffusion GAN paradigm to reduce the number of timesteps required by the reverse diffusion process and the Discrete Wavelet Transform (DWT) to achieve dimensionality reduction, decreasing training and inference times significantly. The results of an experimental validation on the CelebA-HQ dataset confirm the effectiveness of our proposed scheme. Our approach outperforms other state-of-the-art methodologies, successfully ensuring high-fidelity output while overcoming the inherent drawbacks of diffusion models in time-sensitive applications.
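The DWT step in the abstract decomposes each image into four half-resolution subbands (LL, LH, HL, HH), so the networks operate at a quarter of the spatial positions per subband. As an illustration only (not the repository's implementation, which works on PyTorch tensors), a single-level orthonormal Haar DWT can be sketched in pure Python:

```python
def haar_dwt2(img):
    """Single-level 2-D orthonormal Haar DWT.

    img: H x W list of lists (H and W even).
    Returns four H/2 x W/2 subbands: LL, LH, HL, HH.
    """
    H, W = len(img), len(img[0])
    LL, LH, HL, HH = [], [], [], []
    for i in range(0, H, 2):
        ll_row, lh_row, hl_row, hh_row = [], [], [], []
        for j in range(0, W, 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            ll_row.append((a + b + c + d) / 2)  # low-low (local average)
            lh_row.append((a - b + c - d) / 2)  # horizontal detail
            hl_row.append((a + b - c - d) / 2)  # vertical detail
            hh_row.append((a - b - c + d) / 2)  # diagonal detail
        LL.append(ll_row)
        LH.append(lh_row)
        HL.append(hl_row)
        HH.append(hh_row)
    return LL, LH, HL, HH

# A 128x128 image becomes four 64x64 subbands: the same information,
# but each subband has a quarter of the spatial positions.
subbands = haar_dwt2([[float(i + j) for j in range(128)] for i in range(128)])
```

The transform is invertible, so the model can predict wavelet coefficients and reconstruct the full-resolution image afterwards.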


Installation

Python 3.10.12 and PyTorch 1.11.0 are used in this implementation.

You can install the necessary libraries as follows:

pip install -r requirements.txt

Dataset preparation

We trained on CelebA-HQ (16x16 -> 128x128).

If you don't have the data, you can prepare it in the following way:

Download CelebaHQ 256x256.

Use the following script to prepare the dataset in PNG or LMDB format.
IMPORTANT: make sure the images in the folder are sequentially numbered, otherwise the conversion to LMDB won't work.

# Resize to get 16×16 LR_IMGS and 128×128 HR_IMGS, then prepare 128×128 Fake SR_IMGS by bicubic interpolation
# Specify -l for LMDB format
python datasets_prep/prepare_data.py  --path [dataset root]  --out [output root] --size 16,128 -l
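Because the LMDB conversion indexes images by position, the sequential-numbering requirement above is easy to violate silently. A small stdlib-only sanity check (a hypothetical helper, not part of the repository) that verifies a list of PNG filenames forms an unbroken 0..N-1 sequence:

```python
import re

def check_sequential(filenames):
    """True if the PNG filenames form an unbroken 0..N-1 numeric sequence."""
    nums = sorted(
        int(m.group(1))
        for f in filenames
        if (m := re.fullmatch(r"(\d+)\.png", f))
    )
    return nums == list(range(len(nums)))

# Example: run it on os.listdir(dataset_root) before converting to LMDB.
```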

Once a dataset is downloaded and prepared, please put it in data/ directory as follows:

data/
├── celebahq_16_128

How to run

We provide a bash script for our experiments. The syntax is as follows:

bash run.sh <DATASET> <MODE> <#GPUS>

where:

  • <DATASET>: celebahq_16_128.
  • <MODE>: train or test.
  • <#GPUS>: the number of gpus (e.g. 1, 2, 4, 8).

Note: please set the --exp argument correspondingly for both train and test modes. All detailed configurations are set in run.sh.

GPU allocation: our experiments were run on a single NVIDIA Tesla T4 GPU (15 GB).

Results

Comparisons between our model and ESRGAN, SR3, and DiWa (all trained for 25k iterations) are shown below:

Metric    ESRGAN   SR3     DiWa    Ours
PSNR ↑    21.13    14.65   13.68   23.38
SSIM ↑    0.59     0.42    0.13    0.68
LPIPS ↓   0.082    0.365   0.336   0.061
FID ↓     20.8     99.4    270     47.2

The checkpoint we used to compute these results is provided here.

Inference time is computed over 300 trials on a single NVIDIA Tesla T4 GPU for a batch size of 64.

Downloaded pre-trained models should be placed in the saved_info/srwavediff/<DATASET>/<EXP> directory, where <DATASET> is defined in the "How to run" section and <EXP> is the folder name of the pre-trained checkpoint.

            ESRGAN   SR3     DiWa    Ours
Runtime     0.04s    60.3s   34.7s   0.12s
Parameters  31M      98M     92M     57M
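The runtime figures above average many trials to smooth out jitter. A hedged sketch of how per-batch latency can be measured in Python (with a placeholder callable standing in for the model; on GPU one would also synchronize the device, e.g. with torch.cuda.synchronize(), before reading the clock):

```python
import time

def mean_latency(fn, trials=300, warmup=10):
    """Average wall-clock latency of fn() over `trials` runs, after warm-up."""
    for _ in range(warmup):  # warm-up runs exclude one-off setup costs
        fn()
    start = time.perf_counter()
    for _ in range(trials):
        fn()
    return (time.perf_counter() - start) / trials
```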

Evaluation

FID, PSNR, SSIM and LPIPS are computed on the whole test set (6000 samples).

Inference

Samples can be generated by calling run.sh in test mode.

FID

To compute the FID of pretrained models at a specific epoch, add the arguments --compute_fid and --real_img_dir /path/to/real/images to the corresponding experiment in run.sh.
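For intuition, FID is the Fréchet distance between Gaussians fitted to real and generated feature statistics; in one dimension it reduces to (μ₁ − μ₂)² + (σ₁ − σ₂)². A toy stdlib-only sketch of that reduced form (the actual evaluation uses Inception features and full covariance matrices, so this is an illustration, not the repository's FID code):

```python
from statistics import mean, pstdev

def frechet_distance_1d(xs, ys):
    """Squared Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu_x, mu_y = mean(xs), mean(ys)
    sd_x, sd_y = pstdev(xs), pstdev(ys)  # population std, matching Gaussian fit
    return (mu_x - mu_y) ** 2 + (sd_x - sd_y) ** 2
```

Identical sample distributions give a distance of 0; the score grows as the generated statistics drift from the real ones.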

PSNR, SSIM and LPIPS

A simple script is provided to compute PSNR, SSIM and LPIPS for the results. Note that you must run inference without the --compute_fid and --measure_time options before executing the script.

python benchmark/eval.py -p [result root]
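For reference, PSNR over 8-bit images is 10·log10(255²/MSE). A minimal stdlib implementation of that formula (the repository's eval.py may differ in details such as color-space conversion or border cropping):

```python
import math

def psnr(img_a, img_b, max_val=255.0):
    """PSNR in dB between two equal-length flat pixel sequences."""
    mse = sum((a - b) ** 2 for a, b in zip(img_a, img_b)) / len(img_a)
    if mse == 0:
        return float("inf")  # identical images: infinite PSNR
    return 10 * math.log10(max_val ** 2 / mse)
```

Higher is better; identical images give infinity, and a maximally wrong 8-bit image gives 0 dB.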

Acknowledgments
