From a7829a9cd6e8c288c66878f8bbd318c9a4d3e3cc Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Sat, 4 Mar 2023 00:15:52 +0530 Subject: [PATCH 01/13] Add the example for Super-Resolution --- examples/super_resolution/README.md | 37 +++++++ examples/super_resolution/data.py | 78 ++++++++++++++ examples/super_resolution/dataset.py | 37 +++++++ examples/super_resolution/main.py | 115 +++++++++++++++++++++ examples/super_resolution/model.py | 30 ++++++ examples/super_resolution/super_resolve.py | 42 ++++++++ 6 files changed, 339 insertions(+) create mode 100644 examples/super_resolution/README.md create mode 100644 examples/super_resolution/data.py create mode 100644 examples/super_resolution/dataset.py create mode 100644 examples/super_resolution/main.py create mode 100644 examples/super_resolution/model.py create mode 100644 examples/super_resolution/super_resolve.py diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md new file mode 100644 index 00000000000..d86594ab0e0 --- /dev/null +++ b/examples/super_resolution/README.md @@ -0,0 +1,37 @@ +# Superresolution using an efficient sub-pixel convolutional neural network + +ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) + +This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. + +``` +usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batchSize BATCHSIZE] + [--testBatchSize TESTBATCHSIZE] [--nEpochs NEPOCHS] [--lr LR] + [--cuda] [--threads THREADS] [--seed SEED] + +PyTorch Super Res Example + +optional arguments: + -h, --help show this help message and exit + --upscale_factor super resolution upscale factor + --batchSize training batch size + --testBatchSize testing batch size + --nEpochs number of epochs to train for + --lr Learning Rate. Default=0.01 + --cuda use cuda + --mps enable GPU on macOS + --threads number of threads for data loader to use Default=4 + --seed random seed to use. Default=123 +``` + +This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename model*epoch*.pth + +## Example Usage: + +### Train + +`python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001` + +### Super Resolve + +`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png` diff --git a/examples/super_resolution/data.py b/examples/super_resolution/data.py new file mode 100644 index 00000000000..e199a86c87c --- /dev/null +++ b/examples/super_resolution/data.py @@ -0,0 +1,78 @@ +import tarfile +from os import makedirs, remove +from os.path import basename, exists, join + +from dataset import DatasetFromFolder +from six.moves import urllib +from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor + + +def download_bsd300(dest="dataset"): + output_image_dir = join(dest, "BSDS300/images") + + if not exists(output_image_dir): + makedirs(dest) + url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" + print("downloading url ", url) + + data = urllib.request.urlopen(url) + + file_path = join(dest, basename(url)) + with open(file_path, "wb") as f: + f.write(data.read()) + + print("Extracting data") + with tarfile.open(file_path) as tar: + for item in tar: + tar.extract(item, dest) + + remove(file_path) + + return output_image_dir + + +def calculate_valid_crop_size(crop_size, upscale_factor): + return crop_size - (crop_size % upscale_factor) + + +def input_transform(crop_size, upscale_factor): + return Compose( + [ + CenterCrop(crop_size), + Resize(crop_size // upscale_factor), + ToTensor(), + ] + ) + + +def target_transform(crop_size): + return Compose( + [ + CenterCrop(crop_size), + ToTensor(), + ] + ) + + +def get_training_set(upscale_factor): + root_dir = download_bsd300() + train_dir = join(root_dir, "train") + crop_size = calculate_valid_crop_size(256, upscale_factor) + + return DatasetFromFolder( + train_dir, + input_transform=input_transform(crop_size, upscale_factor), + target_transform=target_transform(crop_size), + ) + + +def get_test_set(upscale_factor): + root_dir = download_bsd300() + test_dir = join(root_dir, "test") + crop_size = calculate_valid_crop_size(256, upscale_factor) + + return DatasetFromFolder( + test_dir, + input_transform=input_transform(crop_size, upscale_factor), + target_transform=target_transform(crop_size), + ) diff --git a/examples/super_resolution/dataset.py b/examples/super_resolution/dataset.py new file mode 100644 index 00000000000..a02ce2172df --- /dev/null +++ b/examples/super_resolution/dataset.py @@ -0,0 +1,37 @@ +from os import listdir +from os.path import join + +import torch.utils.data as data +from PIL import Image + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) + + +def load_img(filepath): + img = Image.open(filepath).convert("YCbCr") + y, _, _ = img.split() + return y + + +class DatasetFromFolder(data.Dataset): + def __init__(self, image_dir, input_transform=None, target_transform=None): + super(DatasetFromFolder, self).__init__() + self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] + + self.input_transform = input_transform + self.target_transform = target_transform + + def __getitem__(self, index): + input = load_img(self.image_filenames[index]) + target = input.copy() + if self.input_transform: + input = self.input_transform(input) + if self.target_transform: + target = self.target_transform(target) + + return input, target + + def __len__(self): + return len(self.image_filenames) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py new file mode 100644 index 00000000000..28115e2da74 --- /dev/null +++ b/examples/super_resolution/main.py @@ -0,0 +1,115 @@ +from __future__ import print_function + +import argparse +from math import log10 + +import torch +import torch.nn as nn +import torch.optim as optim +from data import get_test_set, get_training_set +from model import Net +from torch.utils.data import DataLoader + +from ignite.engine import Engine, Events +from ignite.handlers import Checkpoint +from ignite.metrics import PSNR + +# Training settings +parser = argparse.ArgumentParser(description="PyTorch Super Res Example") +parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor") +parser.add_argument("--batchSize", type=int, default=64, help="training batch size") +parser.add_argument("--testBatchSize", type=int, default=10, help="testing batch size") +parser.add_argument("--nEpochs", type=int, default=2, help="number of epochs to train for") +parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01") +parser.add_argument("--cuda", action="store_true", help="use cuda?") +parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training") +parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use") +parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123") +opt = parser.parse_args() + +print(opt) + +if opt.cuda and not torch.cuda.is_available(): + raise Exception("No GPU found, please run without --cuda") +if not opt.mps and torch.backends.mps.is_available(): + raise Exception("Found mps device, please run with --mps to enable macOS GPU") + +torch.manual_seed(opt.seed) +use_mps = opt.mps and torch.backends.mps.is_available() + +if opt.cuda: + device = torch.device("cuda") +elif use_mps: + device = torch.device("mps") +else: + device = torch.device("cpu") + +print("===> Loading datasets") +train_set = get_training_set(opt.upscale_factor) +test_set = get_test_set(opt.upscale_factor) +training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) +testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False) + +print("===> Building model") +model = Net(upscale_factor=opt.upscale_factor).to(device) +criterion = nn.MSELoss() + +optimizer = optim.Adam(model.parameters(), lr=opt.lr) + + +def train_step(engine, batch): + input, target = batch[0].to(device), batch[1].to(device) + + optimizer.zero_grad() + loss = criterion(model(input), target) + loss.backward() + optimizer.step() + + return loss.item() + + +def validation_step(engine, batch): + model.eval() + with torch.no_grad(): + x, y = batch[0].to(device), batch[1].to(device) + y_pred = model(x) + + return y_pred, y + + +trainer = Engine(train_step) +evaluator = Engine(validation_step) +psnr = PSNR(data_range=1) +psnr.attach(evaluator, "psnr") +validate_every = 1 +log_interval = 10 + + +@trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) +def log_training_loss(engine): + print( + "===> Epoch[{}]({}/{}): Loss: {:.4f}".format( + engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output + ) + ) + + +@trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) +def run_validation(): + evaluator.run(testing_data_loader) + + +@trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) +def log_validation(): + metrics = evaluator.state.metrics + print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB") + + +@trainer.on(Events.EPOCH_COMPLETED) +def checkpoint(): + model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch) + torch.save(model, model_out_path) + print("Checkpoint saved to {}".format(model_out_path)) + + +trainer.run(training_data_loader, opt.nEpochs) diff --git a/examples/super_resolution/model.py b/examples/super_resolution/model.py new file mode 100644 index 00000000000..5ad5418c16f --- /dev/null +++ b/examples/super_resolution/model.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + + +class Net(nn.Module): + def __init__(self, upscale_factor): + super(Net, self).__init__() + + self.relu = nn.ReLU() + self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) + self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) + self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) + self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor) + + self._initialize_weights() + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.relu(self.conv3(x)) + x = self.pixel_shuffle(self.conv4(x)) + return x + + def _initialize_weights(self): + init.orthogonal_(self.conv1.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv2.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv3.weight, init.calculate_gain("relu")) + init.orthogonal_(self.conv4.weight) diff --git a/examples/super_resolution/super_resolve.py b/examples/super_resolution/super_resolve.py new file mode 100644 index 00000000000..8b9e5ea1a1d --- /dev/null +++ b/examples/super_resolution/super_resolve.py @@ -0,0 +1,42 @@ +from __future__ import print_function + +import argparse + +import numpy as np +import torch +from PIL import Image +from torchvision.transforms import ToTensor + +# Training settings +parser = argparse.ArgumentParser(description="PyTorch Super Res Example") +parser.add_argument("--input_image", type=str, required=True, help="input image to use") +parser.add_argument("--model", type=str, required=True, help="model file to use") +parser.add_argument("--output_filename", type=str, help="where to save the output image") +parser.add_argument("--cuda", action="store_true", help="use cuda") +opt = parser.parse_args() + +print(opt) +img = Image.open(opt.input_image).convert("YCbCr") +y, cb, cr = img.split() + +model = torch.load(opt.model) +img_to_tensor = ToTensor() +input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0]) + +if opt.cuda: + model = model.cuda() + input = input.cuda() + +out = model(input) +out = out.cpu() +out_img_y = out[0].detach().numpy() +out_img_y *= 255.0 +out_img_y = out_img_y.clip(0, 255) +out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L") + +out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) +out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) +out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB") + +out_img.save(opt.output_filename) +print("output image saved to ", opt.output_filename) From 1b0baf3522a2a7b43ea1d662d8f1d5a7bf08662e Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Sat, 4 Mar 2023 00:29:31 +0530 Subject: [PATCH 02/13] Made some changes --- examples/super_resolution/main.py | 2 -- examples/super_resolution/model.py | 1 - 2 files changed, 3 deletions(-) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index 28115e2da74..0e12a5bbad9 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -1,7 +1,6 @@ from __future__ import print_function import argparse -from math import log10 import torch import torch.nn as nn @@ -11,7 +10,6 @@ from torch.utils.data import DataLoader from ignite.engine import Engine, Events -from ignite.handlers import Checkpoint from ignite.metrics import PSNR # Training settings diff --git a/examples/super_resolution/model.py b/examples/super_resolution/model.py index 5ad5418c16f..1f80c95d064 100644 --- a/examples/super_resolution/model.py +++ b/examples/super_resolution/model.py @@ -1,4 +1,3 @@ -import torch import torch.nn as nn import torch.nn.init as init From 7ebee4908c42db43c374e890cf4922e47452cd29 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 7 Mar 2023 01:34:50 +0530 Subject: [PATCH 03/13] Made some changes --- examples/super_resolution/README.md | 4 ++-- examples/super_resolution/data.py | 2 +- examples/super_resolution/main.py | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index d86594ab0e0..39d462f993f 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -1,4 +1,4 @@ -# Superresolution using an efficient sub-pixel convolutional neural network +# Super-Resolution using an efficient sub-pixel convolutional neural network ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) @@ -24,7 +24,7 @@ optional arguments: --seed random seed to use. Default=123 ``` -This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename model*epoch*.pth +This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename `model_epoch_.pth` ## Example Usage: diff --git a/examples/super_resolution/data.py b/examples/super_resolution/data.py index e199a86c87c..38372973680 100644 --- a/examples/super_resolution/data.py +++ b/examples/super_resolution/data.py @@ -1,9 +1,9 @@ import tarfile +import urllib from os import makedirs, remove from os.path import basename, exists, join from dataset import DatasetFromFolder -from six.moves import urllib from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index 0e12a5bbad9..211eaf336b9 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import argparse import torch From 3982d7b7cdd083dbfd270f1821e9bd82ed463d11 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Thu, 16 Mar 2023 00:18:14 +0530 Subject: [PATCH 04/13] Add the time profiling features --- examples/super_resolution/main.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index 211eaf336b9..b9794334ef4 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -101,6 +101,16 @@ def log_validation(): print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB") +@trainer.on(Events.EPOCH_COMPLETED) +def log_epoch_time(): + print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}") + + +@trainer.on(Events.COMPLETED) +def log_total_time(): + print(f"Total Time: {trainer.state.times['COMPLETED']}") + + @trainer.on(Events.EPOCH_COMPLETED) def checkpoint(): model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch) From 982a0ebd4b8772efa002b5dc30efc86b4b706483 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Sat, 18 Mar 2023 00:53:24 +0530 Subject: [PATCH 05/13] Added torchvision dataset --- examples/super_resolution/data.py | 78 ---------------------------- examples/super_resolution/dataset.py | 37 ------------- examples/super_resolution/main.py | 37 +++++++++++-- 3 files changed, 32 insertions(+), 120 deletions(-) delete mode 100644 examples/super_resolution/data.py delete mode 100644 examples/super_resolution/dataset.py diff --git a/examples/super_resolution/data.py b/examples/super_resolution/data.py deleted file mode 100644 index 38372973680..00000000000 --- a/examples/super_resolution/data.py +++ /dev/null @@ -1,78 +0,0 @@ -import tarfile -import urllib -from os import makedirs, remove -from os.path import basename, exists, join - -from dataset import DatasetFromFolder -from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor - - -def download_bsd300(dest="dataset"): - output_image_dir = join(dest, "BSDS300/images") - - if not exists(output_image_dir): - makedirs(dest) - url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" - print("downloading url ", url) - - data = urllib.request.urlopen(url) - - file_path = join(dest, basename(url)) - with open(file_path, "wb") as f: - f.write(data.read()) - - print("Extracting data") - with tarfile.open(file_path) as tar: - for item in tar: - tar.extract(item, dest) - - remove(file_path) - - return output_image_dir - - -def calculate_valid_crop_size(crop_size, upscale_factor): - return crop_size - (crop_size % upscale_factor) - - -def input_transform(crop_size, upscale_factor): - return Compose( - [ - CenterCrop(crop_size), - Resize(crop_size // upscale_factor), - ToTensor(), - ] - ) - - -def target_transform(crop_size): - return Compose( - [ - CenterCrop(crop_size), - ToTensor(), - ] - ) - - -def get_training_set(upscale_factor): - root_dir = download_bsd300() - train_dir = join(root_dir, "train") - crop_size = calculate_valid_crop_size(256, upscale_factor) - - return DatasetFromFolder( - train_dir, - input_transform=input_transform(crop_size, upscale_factor), - target_transform=target_transform(crop_size), - ) - - -def get_test_set(upscale_factor): - root_dir = download_bsd300() - test_dir = join(root_dir, "test") - crop_size = calculate_valid_crop_size(256, upscale_factor) - - return DatasetFromFolder( - test_dir, - input_transform=input_transform(crop_size, upscale_factor), - target_transform=target_transform(crop_size), - ) diff --git a/examples/super_resolution/dataset.py b/examples/super_resolution/dataset.py deleted file mode 100644 index a02ce2172df..00000000000 --- a/examples/super_resolution/dataset.py +++ /dev/null @@ -1,37 +0,0 @@ -from os import listdir -from os.path import join - -import torch.utils.data as data -from PIL import Image - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) - - -def load_img(filepath): - img = Image.open(filepath).convert("YCbCr") - y, _, _ = img.split() - return y - - -class DatasetFromFolder(data.Dataset): - def __init__(self, image_dir, input_transform=None, target_transform=None): - super(DatasetFromFolder, self).__init__() - self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] - - self.input_transform = input_transform - self.target_transform = target_transform - - def __getitem__(self, index): - input = load_img(self.image_filenames[index]) - target = input.copy() - if self.input_transform: - input = self.input_transform(input) - if self.target_transform: - target = self.target_transform(target) - - return input, target - - def __len__(self): - return len(self.image_filenames) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index b9794334ef4..8839006a7b7 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn import torch.optim as optim -from data import get_test_set, get_training_set +import torchvision +import torchvision.transforms as transforms from model import Net from torch.utils.data import DataLoader @@ -41,10 +42,36 @@ device = torch.device("cpu") print("===> Loading datasets") -train_set = get_training_set(opt.upscale_factor) -test_set = get_test_set(opt.upscale_factor) -training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) -testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False) + + +class SRDataset(torch.utils.data.Dataset): + def __init__(self, dataset, scale_factor): + self.dataset = dataset + self.transform = transforms.Resize( + (len(dataset[0][0][0]) * scale_factor, len(dataset[0][0][0][0]) * scale_factor) + ) + + def __getitem__(self, index): + lr_image, _ = self.dataset[index] + hr_image = self.transform(lr_image) + return lr_image, hr_image + + def __len__(self): + return len(self.dataset) + + +transform = transforms.Compose([transforms.ToTensor()]) + +trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) +testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) + +trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) +testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) + +training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) +testing_data_loader = DataLoader( + dataset=testset_sr, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False +) print("===> Building model") model = Net(upscale_factor=opt.upscale_factor).to(device) From 0cd5c59d973739940ac77bd906e39b2dabcb0621 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Sat, 18 Mar 2023 01:31:25 +0530 Subject: [PATCH 06/13] Changed the dataset used in README to cifar10 --- examples/super_resolution/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index 39d462f993f..9674940aaf9 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -24,7 +24,7 @@ optional arguments: --seed random seed to use. Default=123 ``` -This example trains a super-resolution network on the [BSD300 dataset](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), using crops from the 200 training images, and evaluating on crops of the 100 test images. A snapshot of the model after every epoch with filename `model_epoch_.pth` +This example trains a super-resolution network on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). A snapshot of the model after every epoch with filename `model_epoch_.pth` ## Example Usage: From 7bcea2f5c2dbc705864a6bbc8e39df625a77b032 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 00:11:22 +0530 Subject: [PATCH 07/13] Used snake case in arguments --- examples/super_resolution/README.md | 2 +- examples/super_resolution/main.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index 9674940aaf9..c80ab42d653 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -30,7 +30,7 @@ This example trains a super-resolution network on the [Cifar10 dataset](https:// ### Train -`python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001` +`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001` ### Super Resolve diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index 8839006a7b7..dcba9334345 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -14,9 +14,9 @@ # Training settings parser = argparse.ArgumentParser(description="PyTorch Super Res Example") parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor") -parser.add_argument("--batchSize", type=int, default=64, help="training batch size") -parser.add_argument("--testBatchSize", type=int, default=10, help="testing batch size") -parser.add_argument("--nEpochs", type=int, default=2, help="number of epochs to train for") +parser.add_argument("--batch_size", type=int, default=64, help="training batch size") +parser.add_argument("--test_batch_size", type=int, default=10, help="testing batch size") +parser.add_argument("--n_epochs", type=int, default=2, help="number of epochs to train for") parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01") parser.add_argument("--cuda", action="store_true", help="use cuda?") parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training") @@ -68,9 +68,9 @@ def __len__(self): trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) -training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) +training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader( - dataset=testset_sr, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False + dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False ) print("===> Building model") @@ -105,7 +105,7 @@ def validation_step(engine, batch): psnr = PSNR(data_range=1) psnr.attach(evaluator, "psnr") validate_every = 1 -log_interval = 10 +log_interval = 100 @trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) @@ -145,4 +145,4 @@ def checkpoint(): print("Checkpoint saved to {}".format(model_out_path)) -trainer.run(training_data_loader, opt.nEpochs) +trainer.run(training_data_loader, opt.n_epochs) From 698d76f43a0c0cea94e6869e1291a0df53446f5e Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 01:58:37 +0530 Subject: [PATCH 08/13] Made some changes --- examples/super_resolution/main.py | 69 +++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index dcba9334345..22addc08fbf 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -4,9 +4,9 @@ import torch.nn as nn import torch.optim as optim import torchvision -import torchvision.transforms as transforms from model import Net from torch.utils.data import DataLoader +from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor from ignite.engine import Engine, Events from ignite.metrics import PSNR @@ -45,28 +45,67 @@ class SRDataset(torch.utils.data.Dataset): - def __init__(self, dataset, scale_factor): + def __init__(self, dataset, scale_factor, input_transform=None, target_transform=None): self.dataset = dataset - self.transform = transforms.Resize( - (len(dataset[0][0][0]) * scale_factor, len(dataset[0][0][0][0]) * scale_factor) - ) + self.input_transform = input_transform + self.target_transform = target_transform def __getitem__(self, index): - lr_image, _ = self.dataset[index] - hr_image = self.transform(lr_image) + image, _ = self.dataset[index] + img = image.convert("YCbCr") + lr_image, _, _ = img.split() + + hr_image = lr_image.copy() + if self.input_transform: + lr_image = self.input_transform(lr_image) + if self.target_transform: + hr_image = self.target_transform(hr_image) return lr_image, hr_image def __len__(self): return len(self.dataset) -transform = transforms.Compose([transforms.ToTensor()]) +def calculate_valid_crop_size(crop_size, upscale_factor): + return crop_size - (crop_size % upscale_factor) + + +def input_transform(crop_size, upscale_factor): + return Compose( + [ + CenterCrop(crop_size), + Resize(crop_size // upscale_factor), + ToTensor(), + ] + ) + + +def target_transform(crop_size): + return Compose( + [ + CenterCrop(crop_size), + ToTensor(), + ] + ) + + +crop_size = calculate_valid_crop_size(256, opt.upscale_factor) -trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) -testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) +trainset = torchvision.datasets.Caltech101(root="./data", download=True) +testset = torchvision.datasets.Caltech101(root="./data", download=False) -trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) -testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) +trainset_sr = SRDataset( + trainset, + scale_factor=opt.upscale_factor, + input_transform=input_transform(crop_size, opt.upscale_factor), + target_transform=target_transform(crop_size), +) +testset_sr = SRDataset( + testset, + scale_factor=opt.upscale_factor, + input_transform=input_transform(crop_size, opt.upscale_factor), + target_transform=target_transform(crop_size), +) training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader( @@ -117,13 +156,9 @@ def log_training_loss(engine): ) -@trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) -def run_validation(): - evaluator.run(testing_data_loader) - - @trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) def log_validation(): + evaluator.run(testing_data_loader) metrics = evaluator.state.metrics print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB") From 51f47b4c6d6f13f05ffbf7937292680b109b0a6e Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 02:34:43 +0530 Subject: [PATCH 09/13] Make some formatting changes --- examples/super_resolution/README.md | 6 +++--- examples/super_resolution/super_resolve.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index c80ab42d653..1d4003b8260 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -14,9 +14,9 @@ PyTorch Super Res Example optional arguments: -h, --help show this help message and exit --upscale_factor super resolution upscale factor - --batchSize training batch size - --testBatchSize testing batch size - --nEpochs number of epochs to train for + --batch_size training batch size + --test_batch_size testing batch size + --n_epochs number of epochs to train for --lr Learning Rate. Default=0.01 --cuda use cuda --mps enable GPU on macOS diff --git a/examples/super_resolution/super_resolve.py b/examples/super_resolution/super_resolve.py index 8b9e5ea1a1d..964d7a1344d 100644 --- a/examples/super_resolution/super_resolve.py +++ b/examples/super_resolution/super_resolve.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import argparse import numpy as np From 235c908a2056c1a1840c3cd73588ab6b3c0ad5b0 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 02:43:21 +0530 Subject: [PATCH 10/13] Make the formatting changes --- examples/super_resolution/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index 1d4003b8260..8292e5b9450 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -5,8 +5,8 @@ ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/sup This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. ``` -usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batchSize BATCHSIZE] - [--testBatchSize TESTBATCHSIZE] [--nEpochs NEPOCHS] [--lr LR] +usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCHSIZE] + [--test_batch_size TESTBATCHSIZE] [--n_epochs NEPOCHS] [--lr LR] [--cuda] [--threads THREADS] [--seed SEED] PyTorch Super Res Example From 3b2fde9af74c8e21b4bc58340ab2aa9a0c5c0589 Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 03:31:06 +0530 Subject: [PATCH 11/13] some changes --- examples/super_resolution/README.md | 4 +- examples/super_resolution/main.py | 56 ++++++---------------- examples/super_resolution/super_resolve.py | 5 +- 3 files changed, 18 insertions(+), 47 deletions(-) diff --git a/examples/super_resolution/README.md b/examples/super_resolution/README.md index 8292e5b9450..f9be6c92f56 100644 --- a/examples/super_resolution/README.md +++ b/examples/super_resolution/README.md @@ -24,7 +24,7 @@ optional arguments: --seed random seed to use. Default=123 ``` -This example trains a super-resolution network on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). A snapshot of the model after every epoch with filename `model_epoch_.pth` +This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model after every epoch with filename `model_epoch_.pth` ## Example Usage: @@ -34,4 +34,4 @@ This example trains a super-resolution network on the [Cifar10 dataset](https:// ### Super Resolve -`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png` +`python super_resolve.py --input_image .jpg --model model_epoch_500.pth --output_filename out.png` diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index 22addc08fbf..f39b4629b22 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -6,7 +6,7 @@ import torchvision from model import Net from torch.utils.data import DataLoader -from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor +from torchvision.transforms.functional import center_crop, resize, to_tensor from ignite.engine import Engine, Events from ignite.metrics import PSNR @@ -45,21 +45,22 @@ class SRDataset(torch.utils.data.Dataset): - def __init__(self, dataset, scale_factor, input_transform=None, target_transform=None): + def __init__(self, dataset, scale_factor, crop_size=256): self.dataset = dataset - self.input_transform = input_transform - self.target_transform = target_transform + self.scale_factor = scale_factor + self.crop_size = crop_size def __getitem__(self, index): image, _ = self.dataset[index] img = image.convert("YCbCr") - lr_image, _, _ = img.split() - - hr_image = lr_image.copy() - if self.input_transform: - lr_image = self.input_transform(lr_image) - if self.target_transform: - hr_image = self.target_transform(hr_image) + hr_image, _, _ = img.split() + hr_image = center_crop(hr_image, self.crop_size) + lr_image = hr_image.copy() + if self.scale_factor != 1: + dim = self.crop_size // self.scale_factor + lr_image = resize(lr_image, [dim, dim]) + hr_image = to_tensor(hr_image) + lr_image = to_tensor(lr_image) return lr_image, hr_image def __len__(self): @@ -70,42 +71,13 @@ def calculate_valid_crop_size(crop_size, upscale_factor): return crop_size - (crop_size % upscale_factor) -def input_transform(crop_size, upscale_factor): - return Compose( - [ - CenterCrop(crop_size), - Resize(crop_size // upscale_factor), - ToTensor(), - ] - ) - - -def target_transform(crop_size): - return Compose( - [ - CenterCrop(crop_size), - ToTensor(), - ] - ) - - crop_size = calculate_valid_crop_size(256, opt.upscale_factor) trainset = torchvision.datasets.Caltech101(root="./data", download=True) testset = torchvision.datasets.Caltech101(root="./data", download=False) -trainset_sr = SRDataset( - trainset, - scale_factor=opt.upscale_factor, - input_transform=input_transform(crop_size, opt.upscale_factor), - target_transform=target_transform(crop_size), -) -testset_sr = SRDataset( - testset, - scale_factor=opt.upscale_factor, - input_transform=input_transform(crop_size, opt.upscale_factor), - target_transform=target_transform(crop_size), -) +trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=crop_size) +testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=crop_size) training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader( diff --git a/examples/super_resolution/super_resolve.py b/examples/super_resolution/super_resolve.py index 964d7a1344d..5c5f3c87acc 100644 --- a/examples/super_resolution/super_resolve.py +++ b/examples/super_resolution/super_resolve.py @@ -3,7 +3,7 @@ import numpy as np import torch from PIL import Image -from torchvision.transforms import ToTensor +from torchvision.transforms.functional import to_tensor # Training settings parser = argparse.ArgumentParser(description="PyTorch Super Res Example") @@ -18,8 +18,7 @@ y, cb, cr = img.split() model = torch.load(opt.model) -img_to_tensor = ToTensor() -input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0]) +input = to_tensor(y).view(1, -1, y.size[1], y.size[0]) if opt.cuda: model = model.cuda() From 0e2f9a3a979edacb24ab4c55ac86b469d59c176c Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 18:16:04 +0530 Subject: [PATCH 12/13] update the crop method --- examples/super_resolution/main.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index f39b4629b22..db9c1d94148 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -67,17 +67,11 @@ def __len__(self): return len(self.dataset) -def calculate_valid_crop_size(crop_size, upscale_factor): - return crop_size - (crop_size % upscale_factor) - - -crop_size = calculate_valid_crop_size(256, opt.upscale_factor) - trainset = torchvision.datasets.Caltech101(root="./data", download=True) testset = torchvision.datasets.Caltech101(root="./data", download=False) -trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor, crop_size=crop_size) -testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor, crop_size=crop_size) +trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) +testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) testing_data_loader = DataLoader( From 3d9dda7ebe7ce0c93fa656b1a193200666d81d6a Mon Sep 17 00:00:00 2001 From: Aryan Gupta Date: Tue, 21 Mar 2023 21:39:37 +0530 Subject: [PATCH 13/13] Made the suggested changes --- examples/super_resolution/main.py | 9 ++++----- examples/super_resolution/super_resolve.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/super_resolution/main.py b/examples/super_resolution/main.py index db9c1d94148..d46deec1701 100644 --- a/examples/super_resolution/main.py +++ b/examples/super_resolution/main.py @@ -57,8 +57,8 @@ def __getitem__(self, index): hr_image = center_crop(hr_image, self.crop_size) lr_image = hr_image.copy() if self.scale_factor != 1: - dim = self.crop_size // self.scale_factor - lr_image = resize(lr_image, [dim, dim]) + size = self.crop_size // self.scale_factor + lr_image = resize(lr_image, [size, size]) hr_image = to_tensor(hr_image) lr_image = to_tensor(lr_image) return lr_image, hr_image @@ -74,9 +74,7 @@ def __len__(self): testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True) -testing_data_loader = DataLoader( - dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False -) +testing_data_loader = DataLoader(dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size) print("===> Building model") model = Net(upscale_factor=opt.upscale_factor).to(device) @@ -86,6 +84,7 @@ def __len__(self): def train_step(engine, batch): + model.train() input, target = batch[0].to(device), batch[1].to(device) optimizer.zero_grad() diff --git a/examples/super_resolution/super_resolve.py b/examples/super_resolution/super_resolve.py index 5c5f3c87acc..05c84103769 100644 --- a/examples/super_resolution/super_resolve.py +++ b/examples/super_resolution/super_resolve.py @@ -24,7 +24,9 @@ model = model.cuda() input = input.cuda() -out = model(input) +model.eval() +with torch.no_grad(): + out = model(input) out = out.cpu() out_img_y = out[0].detach().numpy() out_img_y *= 255.0