
MediaTek-NeuroPilot/mai21-learned-smartphone-isp


Deep Learning for Smartphone ISP

[Figure: PUNET results on full-resolution images]

Overview

[Challenge Report Paper] [Challenge Website] [Workshop Website]

This repository provides the implementation of the baseline model, PUNET, for the Learned Smartphone ISP Challenge in the Mobile AI (MAI) Workshop @ CVPR 2021. The model is trained to convert RAW Bayer data obtained directly from a mobile camera sensor into photos captured with a professional Fujifilm DSLR camera, thus replacing the entire hand-crafted ISP camera pipeline. The provided pre-trained PUNET model can be used to generate full-resolution 12MP photos from RAW image files captured with the Sony IMX586 camera sensor. PUNET is a UNet-like architecture modified from PyNET and serves as an extension of the PyNET project.



Prerequisites



Dataset and model preparation

  • Download MediaTek's pre-trained PUNET model and put it into the models/original/ folder.

  • Download the training data and extract it into the raw_images/train/ folder.

  • Download the validation data and extract it into the raw_images/val/ folder.

  • Download the testing data and extract it into the raw_images/test/ folder.
    The dataset folder (default name: raw_images/) should contain three subfolders: train/, val/ and test/. The download links for the above files can be found on the MAI'21 Learned Smartphone ISP Challenge website (registration required).

  • [Optional] Download the pre-trained VGG-19 model (Mirror) and put it into the vgg_pretrained/ folder.
    The VGG model is used by one of the baseline's loss functions, loss_content, which takes the output of PUNET as its input. You are free to remove this loss (lines 65-72 in train_model.py). This may affect the resulting PSNR, but will not break the overall pipeline.
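Before training, it can help to confirm that the dataset folder has the expected layout. The following minimal sketch is not part of the repository; the function name missing_dataset_subfolders is hypothetical, and it only checks that the three subfolders named above exist, not what they contain.

```python
from pathlib import Path

# Subfolders the dataset folder is expected to contain (see the list above).
REQUIRED_SUBFOLDERS = ("train", "val", "test")


def missing_dataset_subfolders(dataset_dir="raw_images"):
    """Return the required subfolders that do not exist under dataset_dir."""
    root = Path(dataset_dir)
    return [name for name in REQUIRED_SUBFOLDERS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_subfolders()
    if missing:
        print("Missing subfolders:", ", ".join(missing))
    else:
        print("Found train/, val/ and test/.")
```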



Learned ISP Pipeline

The Learned Smartphone ISP pipeline consists of two main steps (assuming an input resolution of H x W):

  1. deBayer pre-processing (in load_dataset.py):
    • Input: RAW data [H x W x 1]
    • Output: deBayered RAW data [(H/2) x (W/2) x 4]
    • You are free to modify the pre-processing method as long as the input and output shapes are preserved.
  2. PUNET model (in model.py): PUNET is a UNet-like architecture modified from PyNET.
    • Input: deBayered RAW data [(H/2) x (W/2) x 4]
    • Output: RGB image [H x W x 3]
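The deBayer step can be pictured as packing each 2x2 Bayer tile into four channels, which halves the height and width. The sketch below assumes an RGGB tile layout (R and Gr on even rows, Gb and B on odd rows) and omits any normalization; the channel order and scaling actually used in load_dataset.py may differ, and pack_bayer is a hypothetical name.

```python
import numpy as np


def pack_bayer(raw):
    """Pack an H x W (or H x W x 1) Bayer frame into an (H/2) x (W/2) x 4 array.

    Assumed 2x2 tile layout (may not match load_dataset.py):
        R   Gr
        Gb  B
    """
    raw = np.asarray(raw)
    if raw.ndim == 3:
        raw = raw[..., 0]  # drop the trailing singleton channel
    h, w = raw.shape
    if h % 2 or w % 2:
        raise ValueError("RAW height and width must both be even")
    r = raw[0::2, 0::2]   # even rows, even columns
    gr = raw[0::2, 1::2]  # even rows, odd columns
    gb = raw[1::2, 0::2]  # odd rows, even columns
    b = raw[1::2, 1::2]   # odd rows, odd columns
    return np.stack([r, gr, gb, b], axis=-1)
```

For example, a 1088 x 1920 x 1 RAW frame becomes a 544 x 960 x 4 array, which matches the input_shape used for TFLite conversion later in this README.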

[Figure: PUNET]



Training

Start training

To train the model, use the following command:

python train_model.py

Optional parameters (and default values):

dataset_dir: raw_images/   -   path to the folder with the dataset
model_dir: models/   -   path to the folder with the model to be restored or saved
vgg_dir: vgg_pretrained/imagenet-vgg-verydeep-19.mat   -   path to the pre-trained VGG-19 network
dslr_dir: fujifilm/   -   path to the folder with the RGB data
phone_dir: mediatek_raw/   -   path to the folder with the Raw data
arch: punet   -   architecture name
num_maps_base: 16   -   base channel number (e.g. 8, 16, 32, etc.)
restore_iter: None   -   iteration to restore
patch_w: 256   -   width of the training images
patch_h: 256   -   height of the training images
batch_size: 32   -   batch size [small values can lead to unstable training]
train_size: 5000   -   the number of training patches randomly loaded every 1000 iterations
learning_rate: 5e-5   -   learning rate
eval_step: 1000   -   every eval_step iterations, the accuracy is computed and the model is saved
num_train_iters: 100000   -   the number of training iterations
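train_model.py takes its options as key=value pairs rather than --flags. The parser below is a hypothetical sketch (not the repository's own argument handling) showing one way such pairs can be mapped onto the defaults listed above, converting "None", "True"/"False", integers and floats from strings.

```python
# Defaults copied from the list above; the parser itself is illustrative only.
DEFAULTS = {
    "dataset_dir": "raw_images/",
    "model_dir": "models/",
    "vgg_dir": "vgg_pretrained/imagenet-vgg-verydeep-19.mat",
    "dslr_dir": "fujifilm/",
    "phone_dir": "mediatek_raw/",
    "arch": "punet",
    "num_maps_base": 16,
    "restore_iter": None,
    "patch_w": 256,
    "patch_h": 256,
    "batch_size": 32,
    "train_size": 5000,
    "learning_rate": 5e-5,
    "eval_step": 1000,
    "num_train_iters": 100000,
}


def _convert(value):
    """Interpret a command-line string as None, bool, int, float, or str."""
    if value == "None":
        return None
    if value in ("True", "False"):
        return value == "True"
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def parse_key_value_args(argv, defaults=DEFAULTS):
    """Return a copy of defaults updated with key=value pairs from argv."""
    params = dict(defaults)
    for arg in argv:
        key, sep, value = arg.partition("=")
        if not sep or key not in params:
            raise ValueError(f"unknown or malformed argument: {arg!r}")
        params[key] = _convert(value)
    return params
```

For instance, parse_key_value_args(sys.argv[1:]) would turn "restore_iter=100000" into the integer 100000 and leave every unspecified option at its default.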


Below we provide an example command used for training the PUNET model on an Nvidia GeForce GTX 1080 GPU with 8GB of RAM.

CUDA_VISIBLE_DEVICES=0 python train_model.py \
  model_dir=models/punet_MAI/ arch=punet num_maps_base=16 \
  patch_w=256 patch_h=256 batch_size=32 \
  eval_step=1000 num_train_iters=100000

After training, the following files will be produced under model_dir:

checkpoint   -   contains all the checkpoint names
logs_[restore_iter]-[num_train_iters].txt   -   training log (including loss, PSNR, etc.)
[arch]_iteration_[iter].ckpt.data   -   part of checkpoint data for the model [arch]_iteration_[iter]
[arch]_iteration_[iter].ckpt.index   -   part of checkpoint data for the model [arch]_iteration_[iter]
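When resuming training, restore_iter has to match a saved iteration. The sketch below finds the most recent one by relying only on the file naming shown above; latest_iteration is a hypothetical helper, not the logic used by train_model.py or test_model.py. TensorFlow's tf.train.latest_checkpoint, which reads the checkpoint file instead, is an alternative.

```python
import re
from pathlib import Path


def latest_iteration(model_dir, arch="punet"):
    """Return the largest N for which [arch]_iteration_N.ckpt.index exists, else None."""
    pattern = re.compile(rf"^{re.escape(arch)}_iteration_(\d+)\.ckpt\.index$")
    iterations = []
    for path in Path(model_dir).iterdir():
        match = pattern.match(path.name)
        if match:
            iterations.append(int(match.group(1)))
    return max(iterations, default=None)
```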

Resume training

To resume training from restore_iter, use a command like the following:

CUDA_VISIBLE_DEVICES=0 python train_model.py \
  model_dir=models/punet_MAI/ arch=punet num_maps_base=16 \
  patch_w=256 patch_h=256 batch_size=32 \
  eval_step=1000 num_train_iters=110000 restore_iter=100000 



Test/Inference

test_model.py runs a model on test images of height img_h and width img_w. Here we use img_h=1088 and img_w=1920 as an example. If save=True, the protobuf (frozen graph) corresponding to the test image resolution will also be produced.


Use the provided pre-trained model

To produce output images and protobuf using the pre-trained model, use the following command:

python test_model.py orig=True

Use a self-trained model

To produce output images and protobuf using your self-trained model, use the following command:

python test_model.py

Optional parameters (and default values):

dataset_dir: raw_images/   -   path to the folder with the dataset
test_dir: fujifilm_full_resolution/   -   path to the folder with the test data
model_dir: models/   -   path to the folder with the models to be restored/loaded
result_dir: results/   -   path to the folder with the produced outputs from the loaded model
arch: punet   -   architecture name
num_maps_base: 16   -   base channel number (e.g. 8, 16, 32, etc.)
orig: True, False   -   use the pre-trained model or not
restore_iter: None   -   iteration to restore (if not specified for a self-trained model, the last saved model will be loaded)
img_h: 1088   -   height of the testing images
img_w: 1920   -   width of the testing images
use_gpu: True, False   -   run the model on GPU or CPU
save: True   -   save the loaded checkpoint and protobuf (frozen graph) again
test_image: True   -   run the loaded model on the test images. Set to False if you only want to save the models.


Below we provide an example command used for testing the model:

CUDA_VISIBLE_DEVICES=0 python test_model.py \
  test_dir=fujifilm_full_resolution/ model_dir=models/punet_MAI/ result_dir=results/full-resolution/ \
  arch=punet num_maps_base=16 orig=False restore_iter=98000 \
  img_h=1088 img_w=1920 use_gpu=True save=True test_image=True

After inference, the output images will be produced under result_dir.

[Optional] If save=True, the following files will be produced under model_dir:

[model_name].ckpt.meta   -   graph data for the model model_name
[model_name].pb   -   protobuf (frozen graph) for the model model_name
[model_name]/   -   a folder containing Tensorboard data for the model model_name

Notes:

  1. To export the protobuf (frozen graph), the output node name must be specified. In this sample code, we use output_l0 for PUNET. If you use a different name, please modify the corresponding argument of the function utils.export_pb (Line #111 in test_model.py). You can also use TensorBoard to check the output node name.
  2. [Important] In the Learned Smartphone ISP Challenge in Mobile AI (MAI) Workshop @ CVPR 2021, you may need to use different models for different evaluation goals (i.e. quality and latency). Therefore, please specify different img_h and img_w for different evaluation goals.
    • Example 1: For evaluating PSNR for validation data (resolution: 256x256) and generating processed images:
    CUDA_VISIBLE_DEVICES=0 python test_model.py \
      test_dir=mediatek_raw/ model_dir=models/punet_MAI/ result_dir=results/ \
      arch=punet num_maps_base=16 orig=False restore_iter=98000 \
      img_h=256 img_w=256 use_gpu=True save=True test_image=True
    • Example 2: For evaluating the inference latency (resolution: 1088x1920) without generating any output images:
    CUDA_VISIBLE_DEVICES=0 python test_model.py \
      model_dir=models/punet_MAI/ \
      arch=punet num_maps_base=16 orig=False restore_iter=98000 \
      img_h=1088 img_w=1920 use_gpu=True save=True test_image=False



Convert checkpoint to pb

test_model.py can produce protobuf automatically if save=True. If you want to directly convert the checkpoint model (including .meta, .data, and .index) to protobuf, we also provide ckpt2pb.py to do so. The main arguments (and default values) are as follows:

--in_path: models/punet_MAI/punet_iteration_100000.ckpt   -   input checkpoint file (including .meta, .data, and .index)
--out_path: models/punet_MAI/punet_iteration_100000.pb   -   output protobuf file
--out_nodes: output_l0   -   output node name

Notes:

  1. As mentioned earlier, the output node name needs to be specified. There are two ways to find it:
    • check the graph in TensorBoard;
    • explicitly name the output node in the source code (e.g., with tf.identity).
  2. The .meta file is required to convert a checkpoint to protobuf, since it contains important model information (e.g., architecture, input size, etc.).

Below we provide an example command:

python ckpt2pb.py \
  --in_path models/punet_MAI/punet_iteration_100000.ckpt \
  --out_path models/punet_MAI/punet_iteration_100000.pb \
  --out_nodes output_l0
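Because ckpt2pb.py needs the .meta file as well as the weights, a quick pre-check of the checkpoint files can save a failed conversion. TensorFlow's checkpoint format usually stores the weights as one or more shards named like [prefix].data-00000-of-00001, which the notes above abbreviate as .data. The helper below is a hypothetical sketch, not part of ckpt2pb.py.

```python
import glob
from pathlib import Path


def missing_checkpoint_parts(ckpt_prefix):
    """List the checkpoint parts (.meta, .index, .data-*) missing for ckpt_prefix."""
    prefix = Path(ckpt_prefix)
    missing = []
    for suffix in (".meta", ".index"):
        if not prefix.with_name(prefix.name + suffix).is_file():
            missing.append(suffix)
    # Weight shards look like <prefix>.data-00000-of-00001.
    shards = list(prefix.parent.glob(glob.escape(prefix.name) + ".data-*"))
    if not shards:
        missing.append(".data-*")
    return missing
```

Calling missing_checkpoint_parts("models/punet_MAI/punet_iteration_100000.ckpt") should return an empty list before running the conversion command above.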



Convert pb to tflite

The last step is converting the frozen graph to TFLite so that the evaluation server can evaluate the performance on MediaTek devices. Please use the official TensorFlow command-line tool tflite_convert. The main arguments (and default values) are as follows:

graph_def_file: models/original/punet_pretrained.pb   -   input protobuf file
output_file: models/original/punet_pretrained.tflite   -   output tflite file
input_shape: 1,544,960,4   -   the network input shape (batch, height, width, channels) after deBayer pre-processing. If the raw image shape is (img_h, img_w, 1), input_shape should be 1,img_h/2,img_w/2,4.
input_arrays: Placeholder   -   input node name (can be found in TensorBoard, or specified in the source code)
output_arrays: output_l0   -   output node name (can be found in TensorBoard, or specified in the source code)

Below we provide an example command:

tflite_convert \
  --graph_def_file=models/punet_MAI/punet_iteration_100000.pb \
  --output_file=models/punet_MAI/punet_iteration_100000.tflite \
  --input_shape=1,544,960,4 \
  --input_arrays=Placeholder \
  --output_arrays=output_l0
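Since input_shape follows directly from the RAW resolution, the tflite_convert arguments can be derived programmatically. This sketch only builds the argument list from the flags shown above (the function names are hypothetical); the list could then be run with subprocess.run(args, check=True) in an environment where tflite_convert is installed.

```python
def tflite_input_shape(img_h, img_w, batch=1):
    """Return the --input_shape value for a RAW image of size img_h x img_w."""
    if img_h % 2 or img_w % 2:
        raise ValueError("img_h and img_w must be even for 2x2 Bayer packing")
    return f"{batch},{img_h // 2},{img_w // 2},4"


def tflite_convert_args(pb_path, tflite_path, img_h, img_w,
                        input_array="Placeholder", output_array="output_l0"):
    """Build a tflite_convert command line using the flags listed above."""
    return [
        "tflite_convert",
        f"--graph_def_file={pb_path}",
        f"--output_file={tflite_path}",
        f"--input_shape={tflite_input_shape(img_h, img_w)}",
        f"--input_arrays={input_array}",
        f"--output_arrays={output_array}",
    ]
```

With img_h=1088 and img_w=1920, tflite_input_shape returns "1,544,960,4", the value used in the example command.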

[Important] In the Learned Smartphone ISP Challenge in Mobile AI (MAI) Workshop @ CVPR 2021, participants are required to submit TWO TFLite models:

  1. model_none.tflite (for evaluating the image quality): input shape [1, None, None, 4] and output shape [1, None, None, 3].
  2. model.tflite (for evaluating the inference latency): input shape [1, 544, 960, 4] and output shape [1, 1088, 1920, 3].
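To sanity-check the two submission files before uploading, their input and output shapes can be compared against the requirements above. In this hypothetical sketch (not an official checker), -1 or None marks a dynamic dimension. With TensorFlow installed, the shapes could be read from tf.lite.Interpreter(model_path=...).get_input_details()[0]["shape_signature"] and the matching get_output_details() entry, where dynamic dimensions are typically reported as -1.

```python
# Expected (input, output) shapes from the challenge rules above; None = dynamic.
EXPECTED_SHAPES = {
    "model_none.tflite": ([1, None, None, 4], [1, None, None, 3]),
    "model.tflite": ([1, 544, 960, 4], [1, 1088, 1920, 3]),
}


def _normalize(shape):
    """Map -1/None to None and every other dimension to int."""
    return [None if d is None or d == -1 else int(d) for d in shape]


def shapes_match(model_name, input_shape, output_shape):
    """Return True if the given shapes satisfy the requirement for model_name."""
    expected_in, expected_out = EXPECTED_SHAPES[model_name]
    return (_normalize(input_shape) == expected_in
            and _normalize(output_shape) == expected_out)
```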

Feel free to use our provided bash script as well:

bash pb2tflite.sh

Note: pb2tflite.sh converts our provided pre-trained model, so its commands are not exactly the same as the example commands above.



TFLite inference

Desktop/Laptop

inference_tflite.py can load a TFLite model and process a folder of images. The main arguments (and default values) are as follows:

dataset_dir: raw_images   -   main folder for input images
dir_type: test,val   -   select validation or testing data
phone_dir: mediatek_raw   -   folder for input RAW images
dslr_dir: None   -   folder for corresponding ground truth
model_file: models/original/punet_pretrained_small.tflite   -   TFLite model
save_results: False   -   save the processed images (pass --save_results to enable)
save_dir: results   -   main folder for output images

Below we provide an example command:

python inference_tflite.py \
  --dir_type=test --phone_dir=mediatek_raw \
  --model_file=models/punet_MAI/punet_iteration_100000_input-256.tflite \
  --save_results --save_dir=results

If ground truth is available (e.g. validation data), inference_tflite.py can also compute PSNR. Please see the following example (assume we don't want to save output images):

python inference_tflite.py \
  --dir_type=val --phone_dir=mediatek_raw --dslr_dir=fujifilm \
  --model_file=models/punet_MAI/punet_iteration_100000_input-256.tflite
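For reference, PSNR compares an output image with its ground truth through the mean squared error: PSNR = 10 * log10(MAX^2 / MSE). The exact convention in inference_tflite.py (value range, averaging over images) is not documented here; this sketch assumes images scaled to [0, 1] by default and averages per-image scores. Both function names are hypothetical.

```python
import numpy as np


def psnr(pred, target, max_val=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE)."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))


def mean_psnr(pairs, max_val=1.0):
    """Average per-image PSNR over an iterable of (pred, target) pairs."""
    scores = [psnr(p, t, max_val) for p, t in pairs]
    return sum(scores) / len(scores)
```

For 8-bit images, pass max_val=255.0 instead of rescaling to [0, 1].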

[Important] The main purpose of inference_tflite.py is to provide an additional sanity check for the submitted solution. However, it will NOT be used for the final evaluation of the Learned Smartphone ISP Challenge in the Mobile AI (MAI) Workshop @ CVPR 2021. Therefore, including inference_tflite.py in the final submission is recommended but NOT mandatory.
inference_tflite.py contains some pre-processing steps that will NOT be included in the Challenge's final evaluation script. Therefore, all the processing should be integrated into the required TFLite models.

Mobile

We provide two ways to evaluate the mobile performance of your TFLite models:

  • AI benchmark: An app allowing you to load your model and run it locally on your own Android devices with various acceleration options (e.g. CPU, GPU, APU, etc.).
  • TFLite Neuron Delegate: You can build MediaTek's neuron delegate runner by yourself.



[Optional] Some useful tools

We also provide some useful tools. Feel free to try them.



Results

We evaluate the pre-trained PUNET on the validation data (resolution: 256x256), and obtain the following results:

  • PSNR: 23.03
  • Visual comparison with the ground truth:

[Figure: PUNET results on validation patches compared with the ground truth]



Folder structure (default)

models/   -   logs and models that are saved during the training process
models/original/   -   the folder with the provided pre-trained PUNET model
raw_images/   -   the folder with the dataset
results/   -   visual results for the produced images
vgg_pretrained/   -   the folder with the pre-trained VGG-19 network
tools/   -   [optional] some useful tools

load_dataset.py   -   python script that loads training data
model.py   -   PUNET implementation (TensorFlow)
train_model.py   -   implementation of the training procedure
test_model.py   -   applying the trained model to testing images
utils.py   -   auxiliary functions
vgg.py   -   loading the pre-trained vgg-19 network
ckpt2pb.py   -   convert checkpoint to protobuf (frozen graph)
pb2tflite.sh   -   bash script that converts protobuf to tflite
inference_tflite.py   -   load TFLite model and process images



Model Optimization

To make your model run faster on device, please use the network operations preferred by the AI accelerator as much as possible to take full advantage of its performance. You may also find some optimization hints in our paper: Deploying Image Deblurring across Mobile Devices: A Perspective of Quality and Latency

The download link for the optimization guide can be found on the MAI'21 Learned Smartphone ISP Challenge website (registration required).



Common FAQ



Acknowledgement

This project is an extension of the PyNET project.



Citation

If you find this repository useful, please cite our MAI'21 work:

@inproceedings{ignatov2021learned,
  title={Learned Smartphone ISP on Mobile NPUs with Deep Learning, Mobile AI 2021 Challenge: Report},
  author={Ignatov, Andrey and Chiang, Cheng-Ming and Kuo, Hsien-Kai and Sycheva, Anastasia and Timofte, Radu and Chen, Min-Hung and Lee, Man-Yu and Xu, Yu-Syuan and Tseng, Yu and Xu, Shusong and Guo, Jin and Chen, Chao-Hung and Hsyu, Ming-Chun and Tsai, Wen-Chia and Chen, Chao-Wei and Malivenko, Grigory and Kwon, Minsu and Lee, Myungje and Yoo, Jaeyoon and Kang, Changbeom and Wang, Shinjo and Shaolong, Zheng and Dejun, Hao and Fen, Xie and Zhuang, Feng and Ma, Yipeng and Peng, Jingyang and Wang, Tao and Song, Fenglong and Hsu, Chih-Chung and Chen, Kwan-Lin and Wu, Mei-Hsuang and Chudasama, Vishal and Prajapati, Kalpesh and Patel, Heena and Sarvaiya, Anjali and Upla, Kishor and Raja, Kiran and Ramachandra, Raghavendra and Busch, Christoph and de Stoutz, Etienne},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops (Mobile AI)},
  year={2021},
  url={https://arxiv.org/abs/2105.07809}
}

@misc{mtk2021mai,
  title={Mobile AI Workshop: Learned Smartphone ISP Challenge},
  year={2021},
  url={https://github.com/MediaTek-NeuroPilot/mai21-learned-smartphone-isp}
}



License

PyNet License: PyNet License
Mediatek License: Mediatek Apache License 2.0
