From caff6c20610d123a6662e9e1da5682b0ce113490 Mon Sep 17 00:00:00 2001 From: Aryan Gupta <97878444+guptaaryan16@users.noreply.github.com> Date: Wed, 22 Mar 2023 01:45:15 +0530 Subject: [PATCH] Add the example of super_resolution (#2885) * Add the example for Super-Resolution * Made some changes * Made some changes * Add the time profiling features * Added torchvision dataset * Changed the dataset used in README to cifar10 * Used snake case in arguments * Made some changes * Make some formatting changes * Make the formatting changes * some changes * update the crop method * Made the suggested changes --- examples/super_resolution/README.md | 37 ++++++ examples/super_resolution/main.py | 148 +++++++++++++++++++++ examples/super_resolution/model.py | 29 ++++ examples/super_resolution/super_resolve.py | 41 ++++++ 4 files changed, 255 insertions(+) create mode 100644 examples/super_resolution/README.md create mode 100644 examples/super_resolution/main.py create mode 100644 examples/super_resolution/model.py create mode 100644 examples/super_resolution/super_resolve.py diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md new file mode 100644 index 00000000000..f9be6c92f56 --- /dev/null +++ b/examples/super_resolution/README.md @@ -0,0 +1,37 @@ +# Super-Resolution using an efficient sub-pixel convolutional neural network + +ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) + +This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. 
+
+```
+usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE]
+               [--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR]
+               [--cuda] [--mps] [--threads THREADS] [--seed SEED]
+
+PyTorch Super Res Example
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --upscale_factor      super resolution upscale factor
+  --batch_size          training batch size
+  --test_batch_size     testing batch size
+  --n_epochs            number of epochs to train for
+  --lr                  Learning Rate. Default=0.01
+  --cuda                use cuda
+  --mps                 enable GPU on macOS
+  --threads             number of threads for data loader to use Default=4
+  --seed                random seed to use. Default=123
+```
+
+This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model is saved after every epoch with the filename `model_epoch_{epoch_number}.pth`.
+
+## Example Usage:
+
+### Train
+
+`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`
+
+### Super Resolve
+
+`python super_resolve.py --input_image input.jpg --model model_epoch_30.pth --output_filename out.png`
diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py
new file mode 100644
index 00000000000..d46deec1701
--- /dev/null
+++ b/examples/super_resolution/main.py
@@ -0,0 +1,148 @@
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+from model import Net
+from torch.utils.data import DataLoader
+from torchvision.transforms.functional import center_crop, resize, to_tensor
+
+from ignite.engine import Engine, Events
+from ignite.metrics import PSNR
+
+# Training settings
+parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
+parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor")
+parser.add_argument("--batch_size", type=int, default=64, help="training batch size")
+parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size")
+parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs to train for")
+parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01")
+parser.add_argument("--cuda", action="store_true", help="use cuda?")
+parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training")
+parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use")
+parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123")
+opt = parser.parse_args()
+
+print(opt)
+
+# Fail fast when the requested accelerator does not match the machine.
+if opt.cuda and not torch.cuda.is_available():
+    raise Exception("No GPU found, please run without --cuda")
+# NOTE: whenever an MPS device is present the script refuses to run without --mps.
+if not opt.mps and torch.backends.mps.is_available():
+    raise Exception("Found mps device, please run with --mps to enable macOS GPU")
+
+torch.manual_seed(opt.seed)
+use_mps = opt.mps and torch.backends.mps.is_available()
+
+# Device priority: CUDA, then MPS, then CPU.
+if opt.cuda:
+    device = torch.device("cuda")
+elif use_mps:
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+print("===> Loading datasets")
+
+
+class SRDataset(torch.utils.data.Dataset):
+    """Turn an image dataset into (low-res, high-res) luminance training pairs.
+
+    Args:
+        dataset: indexable dataset whose items are ``(PIL image, target)`` pairs;
+            the target is ignored.
+        scale_factor: integer factor between the low- and high-resolution images.
+        crop_size: requested side length of the square high-resolution crop. It is
+            rounded down to the nearest multiple of ``scale_factor``.
+    """
+
+    def __init__(self, dataset, scale_factor, crop_size=256):
+        self.dataset = dataset
+        self.scale_factor = scale_factor
+        # The network upsamples by exactly ``scale_factor``, so the target side must be
+        # a multiple of it. Otherwise (e.g. 256 with factor 3: 85 * 3 = 255) prediction
+        # and target shapes differ and the loss / PSNR computation fails.
+        self.crop_size = crop_size - (crop_size % scale_factor)
+
+    def __getitem__(self, index):
+        """Return ``(lr_image, hr_image)`` tensors built from the Y channel of item ``index``."""
+        image, _ = self.dataset[index]
+        # Only the luminance channel of the YCbCr representation is used.
+        hr_image, _, _ = image.convert("YCbCr").split()
+        hr_image = center_crop(hr_image, self.crop_size)
+        lr_image = hr_image.copy()
+        if self.scale_factor != 1:
+            size = self.crop_size // self.scale_factor
+            lr_image = resize(lr_image, [size, size])
+        return to_tensor(lr_image), to_tensor(hr_image)
+
+    def __len__(self):
+        """Return the number of items in the wrapped dataset."""
+        return len(self.dataset)
+
+
+trainset = torchvision.datasets.Caltech101(root="./data", download=True)
+testset = torchvision.datasets.Caltech101(root="./data", 
download=False)
+
+# NOTE: trainset and testset both wrap the same Caltech101 images, so the PSNR reported
+# by the evaluator is measured on images that are also used for training.
+trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
+testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)
+
+training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
+testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size)
+
+print("===> Building model")
+model = Net(upscale_factor=opt.upscale_factor).to(device)
+criterion = nn.MSELoss()
+
+optimizer = optim.Adam(model.parameters(), lr=opt.lr)
+
+
+def train_step(engine, batch):
+    """Run one optimisation step on ``batch`` and return the loss as a Python float.
+
+    ``batch`` is a ``(low_res, high_res)`` pair as produced by ``training_data_loader``.
+    """
+    model.train()
+    lr_batch, hr_batch = batch[0].to(device), batch[1].to(device)
+
+    optimizer.zero_grad()
+    loss = criterion(model(lr_batch), hr_batch)
+    loss.backward()
+    optimizer.step()
+
+    return loss.item()
+
+
+def validation_step(engine, batch):
+    """Return ``(prediction, target)`` for ``batch`` without tracking gradients.
+
+    The pair is the evaluator output consumed by the attached PSNR metric.
+    """
+    model.eval()
+    with torch.no_grad():
+        x, y = batch[0].to(device), batch[1].to(device)
+        y_pred = model(x)
+
+    return y_pred, y
+
+
+trainer = Engine(train_step)
+evaluator = Engine(validation_step)
+# data_range=1 because to_tensor scales the images to [0, 1].
+psnr = PSNR(data_range=1)
+psnr.attach(evaluator, "psnr")
+validate_every = 1
+log_interval = 100
+
+
+@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
+def log_training_loss(engine):
+    """Print the current loss every ``log_interval`` iterations."""
+    epoch_length = len(training_data_loader)
+    # engine.state.iteration keeps counting across epochs; convert it to a position
+    # inside the current epoch so it is comparable with the loader length.
+    iteration_in_epoch = (engine.state.iteration - 1) % epoch_length + 1
+    print(
+        "===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
+            engine.state.epoch, iteration_in_epoch, epoch_length, engine.state.output
+        )
+    )
+
+
+@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
+def log_validation():
+    """Run the evaluator over the test loader and print the average PSNR."""
+    evaluator.run(testing_data_loader)
+    metrics = evaluator.state.metrics
+    print(f"Epoch: {trainer.state.epoch}, Avg. 
PSNR: {metrics['psnr']} dB")
+
+
+@trainer.on(Events.EPOCH_COMPLETED)
+def log_epoch_time():
+    """Print how long the epoch that just finished took."""
+    state = trainer.state
+    elapsed = state.times["EPOCH_COMPLETED"]
+    print(f"Epoch {state.epoch}, Time Taken : {elapsed}")
+
+
+@trainer.on(Events.COMPLETED)
+def log_total_time():
+    """Print the duration of the whole training run."""
+    total = trainer.state.times["COMPLETED"]
+    print(f"Total Time: {total}")
+
+
+@trainer.on(Events.EPOCH_COMPLETED)
+def checkpoint():
+    """Save the whole model object to ``model_epoch_<epoch>.pth`` after every epoch."""
+    checkpoint_path = "model_epoch_{}.pth".format(trainer.state.epoch)
+    # The full module is pickled (not just its state_dict); super_resolve.py loads it back as-is.
+    torch.save(model, checkpoint_path)
+    print("Checkpoint saved to {}".format(checkpoint_path))
+
+
+trainer.run(training_data_loader, opt.n_epochs)
diff --git a/examples/super_resolution/model.py b/examples/super_resolution/model.py
new file mode 100644
index 00000000000..1f80c95d064
--- /dev/null
+++ b/examples/super_resolution/model.py
@@ -0,0 +1,29 @@
+import torch.nn as nn
+import torch.nn.init as init
+
+
+class Net(nn.Module):
+    """Sub-pixel convolution network for single-channel super-resolution.
+
+    Three ReLU-activated convolutions are followed by a convolution producing
+    ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle`` rearranges into
+    one channel whose height and width are ``upscale_factor`` times larger.
+    """
+
+    def __init__(self, upscale_factor):
+        super().__init__()
+
+        self.relu = nn.ReLU()
+        # Every convolution uses stride 1 and "same" padding, so spatial size is preserved.
+        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
+        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
+        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1)
+        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+        self._initialize_weights()
+
+    def forward(self, x):
+        """Upscale the single-channel batch ``x`` by the configured factor."""
+        for conv in (self.conv1, self.conv2, self.conv3):
+            x = self.relu(conv(x))
+        return self.pixel_shuffle(self.conv4(x))
+
+    def _initialize_weights(self):
+        """Orthogonally initialise all convolutions, using the ReLU gain for the first three."""
+        relu_gain = init.calculate_gain("relu")
+        for conv in (self.conv1, self.conv2, self.conv3):
+            init.orthogonal_(conv.weight, relu_gain)
+        # conv4 is not followed by a ReLU, so it keeps the default gain.
+        init.orthogonal_(self.conv4.weight)
diff --git a/examples/super_resolution/super_resolve.py b/examples/super_resolution/super_resolve.py
new file mode 100644
index 00000000000..05c84103769
--- /dev/null
+++ b/examples/super_resolution/super_resolve.py
@@ -0,0 +1,41 @@
+import argparse
+
+import 
numpy as np
+import torch
+from PIL import Image
+from torchvision.transforms.functional import to_tensor
+
+# Inference settings
+parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
+parser.add_argument("--input_image", type=str, required=True, help="input image to use")
+parser.add_argument("--model", type=str, required=True, help="model file to use")
+# A default is needed: without one, omitting the flag made the final save(None) call fail.
+parser.add_argument("--output_filename", type=str, default="out.png", help="where to save the output image")
+parser.add_argument("--cuda", action="store_true", help="use cuda")
+opt = parser.parse_args()
+
+print(opt)
+if opt.cuda and not torch.cuda.is_available():
+    raise Exception("No GPU found, please run without --cuda")
+device = torch.device("cuda" if opt.cuda else "cpu")
+
+# Only the luminance (Y) channel is super-resolved by the network; the chroma channels
+# are upscaled with bicubic interpolation further down.
+img = Image.open(opt.input_image).convert("YCbCr")
+y, cb, cr = img.split()
+
+# main.py checkpoints the whole module with torch.save(model, ...), i.e. a pickle:
+# * weights_only=False is required to unpickle a full module on recent PyTorch releases;
+#   unpickling can execute arbitrary code, so only load checkpoints you trust.
+# * map_location places the weights on the requested device even if the checkpoint was
+#   written from another one (e.g. trained with --cuda, resolved on CPU).
+model = torch.load(opt.model, map_location=device, weights_only=False)
+model.eval()
+
+# PIL reports size as (width, height); the network expects (batch, channel, height, width).
+lr_input = to_tensor(y).view(1, -1, y.size[1], y.size[0]).to(device)
+
+with torch.no_grad():
+    out = model(lr_input)
+
+# Convert the [0, 1] prediction back to an 8-bit greyscale image.
+out_img_y = out[0].cpu().numpy()
+out_img_y *= 255.0
+out_img_y = out_img_y.clip(0, 255)
+out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L")
+
+out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
+out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
+out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")
+
+out_img.save(opt.output_filename)
+print("output image saved to ", opt.output_filename)