Skip to content

Commit

Permalink
Add the example of super_resolution (#2885)
Browse files Browse the repository at this point in the history
* Add the example for Super-Resolution

* Made some changes

* Made some changes

* Add the time profiling features

* Added torchvision dataset

* Changed the dataset used in README to cifar10

* Used snake case in arguments

* Made some changes

* Make some formatting changes

* Make the formatting changes

* some changes

* update the crop method

* Made the suggested changes
  • Loading branch information
guptaaryan16 authored Mar 21, 2023
1 parent 107a282 commit caff6c2
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 0 deletions.
37 changes: 37 additions & 0 deletions examples/super_resolution/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Super-Resolution using an efficient sub-pixel convolutional neural network

ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution)

This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution.

```
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCH_SIZE]
[--test_batch_size TEST_BATCH_SIZE] [--n_epochs N_EPOCHS] [--lr LR]
[--cuda] [--mps] [--threads THREADS] [--seed SEED]
PyTorch Super Res Example
optional arguments:
-h, --help show this help message and exit
--upscale_factor super resolution upscale factor
--batch_size training batch size
--test_batch_size testing batch size
--n_epochs number of epochs to train for
--lr Learning Rate. Default=0.01
--cuda use cuda
--mps enable GPU on macOS
--threads number of threads for data loader to use Default=4
--seed random seed to use. Default=123
```

This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model is saved after every epoch with the filename `model_epoch_<epoch_number>.pth`.

## Example Usage:

### Train

`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001`

### Super Resolve

`python super_resolve.py --input_image <in>.jpg --model model_epoch_30.pth --output_filename out.png`
148 changes: 148 additions & 0 deletions examples/super_resolution/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from model import Net
from torch.utils.data import DataLoader
from torchvision.transforms.functional import center_crop, resize, to_tensor

from ignite.engine import Engine, Events
from ignite.metrics import PSNR

# Command-line options for the training script.
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")

# (flag, add_argument keyword arguments), registered in the order shown by --help.
_CLI_OPTIONS = (
    ("--upscale_factor", dict(type=int, required=True, help="super resolution upscale factor")),
    ("--batch_size", dict(type=int, default=64, help="training batch size")),
    ("--test_batch_size", dict(type=int, default=10, help="testing batch size")),
    ("--n_epochs", dict(type=int, default=2, help="number of epochs to train for")),
    ("--lr", dict(type=float, default=0.01, help="Learning Rate. Default=0.01")),
    ("--cuda", dict(action="store_true", help="use cuda?")),
    ("--mps", dict(action="store_true", default=False, help="enables macOS GPU training")),
    ("--threads", dict(type=int, default=4, help="number of threads for data loader to use")),
    ("--seed", dict(type=int, default=123, help="random seed to use. Default=123")),
)
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)

opt = parser.parse_args()

# Echo the resolved configuration so each run's log records its settings.
print(opt)

# An explicitly requested CUDA device that is missing is a hard error.
if opt.cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

mps_available = torch.backends.mps.is_available()
if mps_available and not (opt.mps or opt.cuda):
    # Only a hint: training on the CPU of a Mac is a legitimate choice, so the
    # mere presence of an MPS device must not abort the run.
    print("Found mps device, run with --mps to enable macOS GPU")
if opt.mps and not mps_available:
    # Keep the CPU fallback, but make it visible instead of silent.
    print("No mps device found, falling back to CPU")

torch.manual_seed(opt.seed)
use_mps = opt.mps and mps_available

# --cuda takes precedence over --mps; the CPU is the fallback.
if opt.cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("===> Loading datasets")


class SRDataset(torch.utils.data.Dataset):
    """Turn an image dataset into (low-resolution, high-resolution) training pairs.

    Each item of the wrapped ``dataset`` must be an ``(image, label)`` pair whose
    image supports ``convert("YCbCr")`` (e.g. a PIL image). Only the luminance
    (Y) channel is used: it is center-cropped to a square of ``crop_size``
    pixels to form the high-resolution target, and resized down by
    ``scale_factor`` to form the network input.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs.
        scale_factor: integer downscaling factor (>= 1) between target and input.
        crop_size: requested side length of the high-resolution crop. It is
            rounded down to the nearest multiple of ``scale_factor`` so that
            ``low_res_size * scale_factor == crop_size`` holds exactly.

    Raises:
        ValueError: if ``scale_factor`` is smaller than 1 or ``crop_size`` is
            smaller than ``scale_factor``.
    """

    def __init__(self, dataset, scale_factor, crop_size=256):
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        self.dataset = dataset
        self.scale_factor = scale_factor
        # The network upsamples by exactly ``scale_factor``; with a crop that is
        # not a multiple of it (e.g. 256 and factor 3 -> 85 * 3 = 255) the
        # prediction and the target would have different spatial sizes.
        self.crop_size = crop_size - (crop_size % scale_factor)
        if self.crop_size <= 0:
            raise ValueError(f"crop_size ({crop_size}) must be at least scale_factor ({scale_factor})")

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the sample at ``index``."""
        image, _ = self.dataset[index]
        # Keep the luminance channel only; the chroma channels are discarded.
        hr_image, _, _ = image.convert("YCbCr").split()
        hr_image = center_crop(hr_image, self.crop_size)
        if self.scale_factor != 1:
            size = self.crop_size // self.scale_factor
            lr_image = resize(hr_image, [size, size])
        else:
            lr_image = hr_image
        return to_tensor(lr_image), to_tensor(hr_image)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)


# The Caltech101 dataset object is built without any split argument, so a slice
# is held out here to keep the validation images out of the training set.
full_dataset = torchvision.datasets.Caltech101(root="./data", download=True)
n_test = max(1, len(full_dataset) // 10)
n_train = len(full_dataset) - n_test
# Seed the split so the same images are held out on every run with the same --seed.
split_generator = torch.Generator().manual_seed(opt.seed)
trainset, testset = torch.utils.data.random_split(
    full_dataset, [n_train, n_test], generator=split_generator
)

trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)

# Only the training loader shuffles; evaluation order does not matter.
training_data_loader = DataLoader(
    dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True
)
testing_data_loader = DataLoader(
    dataset=testset_sr, num_workers=opt.threads, batch_size=opt.test_batch_size
)

print("===> Building model")
# Pixel-wise mean squared error between the upscaled output and the target.
criterion = nn.MSELoss()

# Build the network on the CPU, then move it to the selected device.
model = Net(upscale_factor=opt.upscale_factor)
model = model.to(device)

optimizer = optim.Adam(params=model.parameters(), lr=opt.lr)


def train_step(engine, batch):
    """Run one optimisation step on ``batch`` and return the scalar loss.

    ``batch`` is indexed as ``(input, target)``; both elements are moved to the
    module-level ``device`` before the forward pass. ``engine`` is unused.
    """
    model.train()
    low_res = batch[0].to(device)
    high_res = batch[1].to(device)

    optimizer.zero_grad()
    prediction = model(low_res)
    loss = criterion(prediction, high_res)
    loss.backward()
    optimizer.step()

    return loss.item()


def validation_step(engine, batch):
    """Run the model on ``batch`` without gradients and return ``(y_pred, y)``.

    ``batch`` is indexed as ``(input, target)``; ``engine`` is unused.
    """
    model.eval()
    low_res = batch[0].to(device)
    high_res = batch[1].to(device)
    with torch.no_grad():
        upscaled = model(low_res)

    return upscaled, high_res


# Cadence settings: validate every N epochs, log the loss every N iterations.
validate_every = 1
log_interval = 100

# One engine drives optimisation, the other runs evaluation only.
trainer = Engine(train_step)
evaluator = Engine(validation_step)

# PSNR is computed on the evaluator's (y_pred, y) output and exposed as metrics["psnr"].
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")


@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    """Print the current epoch, the position inside that epoch and the last loss."""
    epoch_length = len(training_data_loader)
    # ``state.iteration`` keeps counting across epochs, so fold it back into the
    # 1..epoch_length range before printing it next to the per-epoch length.
    iteration_in_epoch = (engine.state.iteration - 1) % epoch_length + 1
    print(
        f"===> Epoch[{engine.state.epoch}]({iteration_in_epoch}/{epoch_length}): "
        f"Loss: {engine.state.output:.4f}"
    )


@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
    """Evaluate on the test loader and print the resulting PSNR for this epoch."""
    evaluator.run(testing_data_loader)
    avg_psnr = evaluator.state.metrics["psnr"]
    epoch = trainer.state.epoch
    print(f"Epoch: {epoch}, Avg. PSNR: {avg_psnr} dB")


@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
    """Print the time the trainer recorded for the epoch that just finished."""
    state = trainer.state
    elapsed = state.times["EPOCH_COMPLETED"]
    print(f"Epoch {state.epoch}, Time Taken : {elapsed}")


@trainer.on(Events.COMPLETED)
def log_total_time():
    """Print the total time the trainer recorded for the whole run."""
    total = trainer.state.times["COMPLETED"]
    print(f"Total Time: {total}")


@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
    """Save the model after every epoch as ``model_epoch_<epoch>.pth``.

    The whole model object is serialised (not just its ``state_dict``).
    """
    model_out_path = f"model_epoch_{trainer.state.epoch}.pth"
    torch.save(model, model_out_path)
    print(f"Checkpoint saved to {model_out_path}")


# Start training; all logging, validation and checkpointing happen in the handlers above.
trainer.run(training_data_loader, max_epochs=opt.n_epochs)
29 changes: 29 additions & 0 deletions examples/super_resolution/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
    """Sub-pixel convolutional network that upsamples a single-channel image.

    Three ReLU-activated convolutions extract features; a fourth convolution
    produces ``upscale_factor ** 2`` channels, which ``nn.PixelShuffle``
    rearranges into one channel whose height and width are ``upscale_factor``
    times larger than the input's.

    Args:
        upscale_factor: integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        super().__init__()

        # Layers are created in the same order as they are applied; all
        # convolutions use stride 1 and padding that preserves spatial size.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(
            in_channels=32, out_channels=upscale_factor ** 2, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Return ``x`` upscaled by the configured factor."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialise all convolution weights.

        The ReLU-activated layers use the ReLU gain; the final layer, which is
        not followed by an activation, uses the default gain.
        """
        relu_gain = init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
41 changes: 41 additions & 0 deletions examples/super_resolution/super_resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import argparse

import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

# Inference settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--input_image", type=str, required=True, help="input image to use")
parser.add_argument("--model", type=str, required=True, help="model file to use")
# A default keeps the script from failing at the final save step, after all
# the work is done, when the flag is omitted.
parser.add_argument(
    "--output_filename", type=str, default="out.png", help="where to save the output image. Default=out.png"
)
parser.add_argument("--cuda", action="store_true", help="use cuda")
opt = parser.parse_args()

print(opt)
# The network only upscales luminance; chroma is upscaled separately below.
img = Image.open(opt.input_image).convert("YCbCr")
y, cb, cr = img.split()

# NOTE(security): the checkpoint is a fully pickled model object, so loading it
# executes arbitrary code from the file -- only load checkpoints you trust.
# ``map_location="cpu"`` lets a checkpoint saved on a GPU open on a CPU-only
# machine; ``weights_only=False`` is required because the file holds a whole
# module rather than a state dict.
model = torch.load(opt.model, map_location="cpu", weights_only=False)
# Add a leading batch dimension to the tensor produced from the Y channel.
lr_input = to_tensor(y).unsqueeze(0)

if opt.cuda:
    if not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    model = model.cuda()
    lr_input = lr_input.cuda()

model.eval()
with torch.no_grad():
    out = model(lr_input)

# Take the first image and first channel of the output, then rescale to 8-bit.
out_img_y = out[0, 0].cpu().numpy()
out_img_y = (out_img_y * 255.0).clip(0, 255)
# A 2-D uint8 array is interpreted as a single-channel ("L") image.
out_img_y = Image.fromarray(np.uint8(out_img_y))

# Bring the chroma channels to the upscaled size and recombine.
out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")

out_img.save(opt.output_filename)
print("output image saved to ", opt.output_filename)

0 comments on commit caff6c2

Please sign in to comment.