-
-
Notifications
You must be signed in to change notification settings - Fork 634
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the example of super_resolution (#2885)
* Add the example for Super-Resolution * Made some changes * Made some changes * Add the time profiling features * Added torchvision dataset * Changed the dataset used in README to cifar10 * Used snake case in arguments * Made some changes * Make some formatting changes * Make the formatting changes * some changes * update the crop method * Made the suggested changes
- Loading branch information
1 parent
107a282
commit caff6c2
Showing
4 changed files
with
255 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Super-Resolution using an efficient sub-pixel convolutional neural network | ||
|
||
ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) | ||
|
||
This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as super-resolution. | ||
|
||
``` | ||
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batch_size BATCH_SIZE] | ||
[--test_batch_size TEST_BATCH_SIZE] [--n_epochs N_EPOCHS] [--lr LR] | ||
[--cuda] [--mps] [--threads THREADS] [--seed SEED] | ||
PyTorch Super Res Example | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--upscale_factor super resolution upscale factor | ||
--batch_size training batch size | ||
--test_batch_size testing batch size | ||
--n_epochs number of epochs to train for | ||
--lr Learning Rate. Default=0.01 | ||
--cuda use cuda | ||
--mps enable GPU on macOS | ||
--threads number of threads for data loader to use. Default=4 | ||
--seed random seed to use. Default=123 | ||
``` | ||
|
||
This example trains a super-resolution network on the [Caltech101 dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.Caltech101.html). A snapshot of the model is saved after every epoch with the filename `model_epoch_<epoch_number>.pth`. | ||
|
||
## Example Usage: | ||
|
||
### Train | ||
|
||
`python main.py --upscale_factor 3 --batch_size 4 --test_batch_size 100 --n_epochs 30 --lr 0.001` | ||
|
||
### Super Resolve | ||
|
||
`python super_resolve.py --input_image <in>.jpg --model model_epoch_30.pth --output_filename out.png` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import argparse | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchvision | ||
from model import Net | ||
from torch.utils.data import DataLoader | ||
from torchvision.transforms.functional import center_crop, resize, to_tensor | ||
|
||
from ignite.engine import Engine, Events | ||
from ignite.metrics import PSNR | ||
|
||
# Command-line options of the training script, registered in display order.
_ARGUMENTS = (
    ("--upscale_factor", {"type": int, "required": True, "help": "super resolution upscale factor"}),
    ("--batch_size", {"type": int, "default": 64, "help": "training batch size"}),
    ("--test_batch_size", {"type": int, "default": 10, "help": "testing batch size"}),
    ("--n_epochs", {"type": int, "default": 2, "help": "number of epochs to train for"}),
    ("--lr", {"type": float, "default": 0.01, "help": "Learning Rate. Default=0.01"}),
    ("--cuda", {"action": "store_true", "help": "use cuda?"}),
    ("--mps", {"action": "store_true", "default": False, "help": "enables macOS GPU training"}),
    ("--threads", {"type": int, "default": 4, "help": "number of threads for data loader to use"}),
    ("--seed", {"type": int, "default": 123, "help": "random seed to use. Default=123"}),
)

parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
for _flag, _kwargs in _ARGUMENTS:
    parser.add_argument(_flag, **_kwargs)
opt = parser.parse_args()

print(opt)

# Refuse to start when the requested accelerator and the machine disagree:
# --cuda needs a CUDA device, and a machine with an MPS device must opt in via --mps.
if opt.cuda:
    if not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
if not opt.mps:
    if torch.backends.mps.is_available():
        raise Exception("Found mps device, please run with --mps to enable macOS GPU")

torch.manual_seed(opt.seed)
use_mps = torch.backends.mps.is_available() if opt.mps else False

# CUDA takes precedence over MPS; CPU is the fallback.
device = torch.device("cuda" if opt.cuda else ("mps" if use_mps else "cpu"))

print("===> Loading datasets")
|
||
|
||
class SRDataset(torch.utils.data.Dataset):
    """Turn an image dataset into (low-resolution, high-resolution) luminance pairs.

    Each item of the wrapped ``dataset`` is expected to be an ``(image, label)``
    pair whose image supports ``convert("YCbCr")``; the label is discarded.
    Only the Y channel is used.

    The high-resolution target is center-cropped to ``hr_size`` and the
    low-resolution input is that crop resized to ``lr_size``. ``hr_size`` is the
    largest multiple of ``scale_factor`` that does not exceed ``crop_size``, so
    that ``lr_size * scale_factor == hr_size`` always holds. Without this, a
    ``crop_size`` that is not divisible by ``scale_factor`` (e.g. 256 and 3)
    produces targets one or more pixels larger than what a network upscaling
    by exactly ``scale_factor`` can output.

    Args:
        dataset: indexable, sized collection of ``(image, label)`` pairs.
        scale_factor: integer down-scaling factor between target and input.
        crop_size: requested side length of the square high-resolution crop.

    Raises:
        ValueError: if ``scale_factor`` is smaller than 1 or larger than
            ``crop_size``.
    """

    def __init__(self, dataset, scale_factor, crop_size=256):
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        if crop_size < scale_factor:
            raise ValueError(f"crop_size ({crop_size}) must be >= scale_factor ({scale_factor})")

        self.dataset = dataset
        self.scale_factor = scale_factor
        self.crop_size = crop_size
        # Side lengths actually used for the input and the target.
        self.lr_size = crop_size // scale_factor
        self.hr_size = self.lr_size * scale_factor

    def __getitem__(self, index):
        """Return the ``(lr_tensor, hr_tensor)`` pair for the image at ``index``."""
        image, _ = self.dataset[index]
        hr_image, _, _ = image.convert("YCbCr").split()
        hr_image = center_crop(hr_image, self.hr_size)
        if self.scale_factor != 1:
            lr_image = resize(hr_image, [self.lr_size, self.lr_size])
        else:
            # No scaling requested: input and target are the same crop.
            lr_image = hr_image
        return to_tensor(lr_image), to_tensor(hr_image)

    def __len__(self):
        """Return the number of images in the wrapped dataset."""
        return len(self.dataset)
|
||
|
||
# NOTE(review): both handles are created from the same root with no split
# argument, so the "test" loader iterates the same images as the training loader.
trainset = torchvision.datasets.Caltech101(root="./data", download=True)
testset = torchvision.datasets.Caltech101(root="./data", download=False)

# Wrap the raw datasets so they yield (low-res, high-res) tensor pairs.
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor)
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor)

training_data_loader = DataLoader(
    dataset=trainset_sr,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.threads,
)
testing_data_loader = DataLoader(
    dataset=testset_sr,
    batch_size=opt.test_batch_size,
    num_workers=opt.threads,
)

print("===> Building model")
model = Net(upscale_factor=opt.upscale_factor)
model = model.to(device)
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=opt.lr)
|
||
|
||
def train_step(engine, batch):
    """Run one optimisation step on ``batch`` and return the loss as a float.

    ``batch`` holds the low-resolution inputs at index 0 and the
    high-resolution targets at index 1. ``engine`` is unused.
    """
    model.train()
    lr_batch = batch[0].to(device)
    hr_batch = batch[1].to(device)

    optimizer.zero_grad()
    prediction = model(lr_batch)
    loss = criterion(prediction, hr_batch)
    loss.backward()
    optimizer.step()

    return loss.item()
|
||
|
||
def validation_step(engine, batch):
    """Run the model on ``batch`` without gradients and return ``(prediction, target)``.

    ``batch`` holds the low-resolution inputs at index 0 and the
    high-resolution targets at index 1. ``engine`` is unused.
    """
    model.eval()
    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    with torch.no_grad():
        outputs = model(inputs)

    return outputs, targets
|
||
|
||
# Validate every `validate_every` epochs; log the loss every `log_interval` iterations.
validate_every = 1
log_interval = 100

trainer = Engine(train_step)
evaluator = Engine(validation_step)

# PSNR is computed on the evaluator's (prediction, target) output with data range 1.
psnr = PSNR(data_range=1)
psnr.attach(evaluator, "psnr")
|
||
|
||
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    """Print the epoch, the position within the epoch and the latest training loss.

    ``engine.state.iteration`` keeps counting across epochs, so it is converted
    to a 1-based index within the current epoch before being shown next to the
    number of iterations per epoch. Printing the raw counter would report
    progress beyond 100% from the second epoch onwards.
    """
    state = engine.state
    iteration_in_epoch = (state.iteration - 1) % state.epoch_length + 1
    print(
        "===> Epoch[{}]({}/{}): Loss: {:.4f}".format(
            state.epoch, iteration_in_epoch, state.epoch_length, state.output
        )
    )
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED(every=validate_every))
def log_validation():
    """Run the evaluator over the test loader and print the resulting PSNR."""
    evaluator.run(testing_data_loader)
    avg_psnr = evaluator.state.metrics["psnr"]
    current_epoch = trainer.state.epoch
    print(f"Epoch: {current_epoch}, Avg. PSNR: {avg_psnr} dB")
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
    """Print the time recorded by the trainer for the epoch that just finished."""
    state = trainer.state
    elapsed = state.times["EPOCH_COMPLETED"]
    print(f"Epoch {state.epoch}, Time Taken : {elapsed}")
|
||
|
||
@trainer.on(Events.COMPLETED)
def log_total_time():
    """Print the time recorded by the trainer for the whole run."""
    total = trainer.state.times["COMPLETED"]
    print(f"Total Time: {total}")
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED)
def checkpoint():
    """Save the whole model object to ``model_epoch_<epoch>.pth`` after each epoch."""
    epoch = trainer.state.epoch
    model_out_path = "model_epoch_{}.pth".format(epoch)
    # The full module (not just a state dict) is serialized.
    torch.save(model, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))
|
||
|
||
# Start training for the number of epochs requested on the command line.
trainer.run(training_data_loader, max_epochs=opt.n_epochs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch.nn as nn | ||
import torch.nn.init as init | ||
|
||
|
||
class Net(nn.Module):
    """Sub-pixel convolutional network for single-channel super-resolution.

    Four convolutions (the first three followed by ReLU) operate at the input
    resolution; the last one emits ``upscale_factor ** 2`` channels that a
    pixel-shuffle layer rearranges into one channel upscaled by
    ``upscale_factor`` along each spatial dimension.

    Args:
        upscale_factor: integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        super().__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Return the upscaled version of the single-channel batch ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        return self.pixel_shuffle(self.conv4(x))

    def _initialize_weights(self):
        """Orthogonally initialize all convolution weights.

        The three ReLU-activated layers use the ReLU gain; the last layer
        uses the default gain.
        """
        relu_gain = init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import argparse | ||
|
||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
from torchvision.transforms.functional import to_tensor | ||
|
||
# Inference settings
parser = argparse.ArgumentParser(description="PyTorch Super Res Example")
parser.add_argument("--input_image", type=str, required=True, help="input image to use")
parser.add_argument("--model", type=str, required=True, help="model file to use")
# Required: the result is always written to this path at the end of the script.
parser.add_argument("--output_filename", type=str, required=True, help="where to save the output image")
parser.add_argument("--cuda", action="store_true", help="use cuda")
opt = parser.parse_args()

print(opt)

if opt.cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")
device = torch.device("cuda" if opt.cuda else "cpu")

# Only the luminance channel goes through the network; the chroma channels
# are upscaled with bicubic interpolation further below.
img = Image.open(opt.input_image).convert("YCbCr")
y, cb, cr = img.split()

# The checkpoint is a whole pickled model object, so only load files you trust.
# map_location lets a checkpoint saved on another device type load on this one.
# NOTE(review): newer PyTorch releases may require weights_only=False to load a
# whole pickled model - confirm against the installed version.
model = torch.load(opt.model, map_location=device)
model = model.to(device)
input_tensor = to_tensor(y).view(1, -1, y.size[1], y.size[0]).to(device)

model.eval()
with torch.no_grad():
    out = model(input_tensor)
out = out.cpu()

# Scale the first image of the batch to the 0..255 range and convert it to 8 bit.
out_img_y = out[0].detach().numpy()
out_img_y *= 255.0
out_img_y = out_img_y.clip(0, 255)
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L")

# Bring the chroma channels to the output size and merge back to RGB.
out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB")

out_img.save(opt.output_filename)
print("output image saved to ", opt.output_filename)