-
-
Notifications
You must be signed in to change notification settings - Fork 650
Add the example of super_resolution #2885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 11 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
a7829a9
Add the example for Super-Resolution
guptaaryan16 74602d4
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 1b0baf3
Made some changes
guptaaryan16 7ebee49
Made some changes
guptaaryan16 f6b5b41
Merge branch 'pytorch:master' into master
guptaaryan16 d810510
Merge branch 'pytorch:master' into master
guptaaryan16 3982d7b
Add the time profiling features
guptaaryan16 bc219c7
Merge branch 'pytorch:master' into master
guptaaryan16 982a0eb
Added torchvision dataset
guptaaryan16 51fe3df
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 0cd5c59
Changed the dataset used in README to cifar10
guptaaryan16 83f10e2
Merge branch 'pytorch:master' into master
guptaaryan16 7bcea2f
Used snake case in arguments
guptaaryan16 698d76f
Made some changes
guptaaryan16 51f47b4
Make some formatting changes
guptaaryan16 235c908
Make the formatting changes
guptaaryan16 3b2fde9
some changes
guptaaryan16 0e2f9a3
update the crop method
guptaaryan16 3d9dda7
Made the suggested changes
guptaaryan16 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Super-Resolution using an efficient sub-pixel convolutional neural network | ||
|
||
ported from [pytorch-examples](https://github.com/pytorch/examples/tree/main/super_resolution) | ||
|
||
This example illustrates how to use the efficient sub-pixel convolution layer described in ["Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network" - Shi et al.](https://arxiv.org/abs/1609.05158) for increasing spatial resolution within your network for tasks such as superresolution. | ||
|
||
``` | ||
usage: main.py [-h] --upscale_factor UPSCALE_FACTOR [--batchSize BATCHSIZE] | ||
[--testBatchSize TESTBATCHSIZE] [--nEpochs NEPOCHS] [--lr LR] | ||
[--cuda] [--threads THREADS] [--seed SEED] | ||
|
||
PyTorch Super Res Example | ||
|
||
optional arguments: | ||
-h, --help show this help message and exit | ||
--upscale_factor super resolution upscale factor | ||
--batchSize training batch size | ||
--testBatchSize testing batch size | ||
--nEpochs number of epochs to train for | ||
--lr Learning Rate. Default=0.01 | ||
--cuda use cuda | ||
--mps enable GPU on macOS | ||
--threads number of threads for data loader to use Default=4 | ||
--seed random seed to use. Default=123 | ||
``` | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
This example trains a super-resolution network on the [Cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). A snapshot of the model after every epoch with filename `model_epoch_<epoch_number>.pth` | ||
|
||
## Example Usage: | ||
|
||
### Train | ||
|
||
`python main.py --upscale_factor 3 --batchSize 4 --testBatchSize 100 --nEpochs 30 --lr 0.001` | ||
|
||
### Super Resolve | ||
|
||
`python super_resolve.py --input_image dataset/BSDS300/images/test/16077.jpg --model model_epoch_500.pth --output_filename out.png` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
import argparse | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from model import Net | ||
from torch.utils.data import DataLoader | ||
|
||
from ignite.engine import Engine, Events | ||
from ignite.metrics import PSNR | ||
|
||
# Training settings | ||
parser = argparse.ArgumentParser(description="PyTorch Super Res Example") | ||
parser.add_argument("--upscale_factor", type=int, required=True, help="super resolution upscale factor") | ||
parser.add_argument("--batchSize", type=int, default=64, help="training batch size") | ||
parser.add_argument("--testBatchSize", type=int, default=10, help="testing batch size") | ||
parser.add_argument("--nEpochs", type=int, default=2, help="number of epochs to train for") | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parser.add_argument("--lr", type=float, default=0.01, help="Learning Rate. Default=0.01") | ||
parser.add_argument("--cuda", action="store_true", help="use cuda?") | ||
parser.add_argument("--mps", action="store_true", default=False, help="enables macOS GPU training") | ||
parser.add_argument("--threads", type=int, default=4, help="number of threads for data loader to use") | ||
parser.add_argument("--seed", type=int, default=123, help="random seed to use. Default=123") | ||
opt = parser.parse_args() | ||
|
||
print(opt) | ||
|
||
if opt.cuda and not torch.cuda.is_available(): | ||
raise Exception("No GPU found, please run without --cuda") | ||
if not opt.mps and torch.backends.mps.is_available(): | ||
raise Exception("Found mps device, please run with --mps to enable macOS GPU") | ||
|
||
torch.manual_seed(opt.seed) | ||
use_mps = opt.mps and torch.backends.mps.is_available() | ||
|
||
if opt.cuda: | ||
device = torch.device("cuda") | ||
elif use_mps: | ||
device = torch.device("mps") | ||
else: | ||
device = torch.device("cpu") | ||
|
||
print("===> Loading datasets") | ||
|
||
|
||
class SRDataset(torch.utils.data.Dataset): | ||
def __init__(self, dataset, scale_factor): | ||
self.dataset = dataset | ||
self.transform = transforms.Resize( | ||
(len(dataset[0][0][0]) * scale_factor, len(dataset[0][0][0][0]) * scale_factor) | ||
) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __getitem__(self, index): | ||
lr_image, _ = self.dataset[index] | ||
hr_image = self.transform(lr_image) | ||
return lr_image, hr_image | ||
|
||
def __len__(self): | ||
return len(self.dataset) | ||
|
||
|
||
transform = transforms.Compose([transforms.ToTensor()]) | ||
|
||
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform) | ||
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform) | ||
|
||
trainset_sr = SRDataset(trainset, scale_factor=opt.upscale_factor) | ||
testset_sr = SRDataset(testset, scale_factor=opt.upscale_factor) | ||
|
||
training_data_loader = DataLoader(dataset=trainset_sr, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) | ||
testing_data_loader = DataLoader( | ||
dataset=testset_sr, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False | ||
) | ||
|
||
print("===> Building model") | ||
model = Net(upscale_factor=opt.upscale_factor).to(device) | ||
criterion = nn.MSELoss() | ||
|
||
optimizer = optim.Adam(model.parameters(), lr=opt.lr) | ||
|
||
|
||
def train_step(engine, batch): | ||
input, target = batch[0].to(device), batch[1].to(device) | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
optimizer.zero_grad() | ||
loss = criterion(model(input), target) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
return loss.item() | ||
|
||
|
||
def validation_step(engine, batch): | ||
model.eval() | ||
with torch.no_grad(): | ||
x, y = batch[0].to(device), batch[1].to(device) | ||
y_pred = model(x) | ||
|
||
return y_pred, y | ||
|
||
|
||
trainer = Engine(train_step) | ||
evaluator = Engine(validation_step) | ||
psnr = PSNR(data_range=1) | ||
psnr.attach(evaluator, "psnr") | ||
validate_every = 1 | ||
log_interval = 10 | ||
|
||
|
||
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval)) | ||
def log_training_loss(engine): | ||
print( | ||
"===> Epoch[{}]({}/{}): Loss: {:.4f}".format( | ||
engine.state.epoch, engine.state.iteration, len(training_data_loader), engine.state.output | ||
) | ||
) | ||
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) | ||
def run_validation(): | ||
evaluator.run(testing_data_loader) | ||
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED(every=validate_every)) | ||
def log_validation(): | ||
metrics = evaluator.state.metrics | ||
print(f"Epoch: {trainer.state.epoch}, Avg. PSNR: {metrics['psnr']} dB") | ||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED) | ||
def log_epoch_time(): | ||
print(f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}") | ||
|
||
|
||
@trainer.on(Events.COMPLETED) | ||
def log_total_time(): | ||
print(f"Total Time: {trainer.state.times['COMPLETED']}") | ||
|
||
|
||
@trainer.on(Events.EPOCH_COMPLETED) | ||
def checkpoint(): | ||
model_out_path = "model_epoch_{}.pth".format(trainer.state.epoch) | ||
torch.save(model, model_out_path) | ||
print("Checkpoint saved to {}".format(model_out_path)) | ||
|
||
|
||
trainer.run(training_data_loader, opt.nEpochs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import torch.nn as nn | ||
import torch.nn.init as init | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, upscale_factor): | ||
super(Net, self).__init__() | ||
|
||
self.relu = nn.ReLU() | ||
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) | ||
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) | ||
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) | ||
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) | ||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor) | ||
|
||
self._initialize_weights() | ||
|
||
def forward(self, x): | ||
x = self.relu(self.conv1(x)) | ||
x = self.relu(self.conv2(x)) | ||
x = self.relu(self.conv3(x)) | ||
x = self.pixel_shuffle(self.conv4(x)) | ||
return x | ||
|
||
def _initialize_weights(self): | ||
init.orthogonal_(self.conv1.weight, init.calculate_gain("relu")) | ||
init.orthogonal_(self.conv2.weight, init.calculate_gain("relu")) | ||
init.orthogonal_(self.conv3.weight, init.calculate_gain("relu")) | ||
init.orthogonal_(self.conv4.weight) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,42 @@ | ||||||||||
from __future__ import print_function | ||||||||||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
import argparse | ||||||||||
|
||||||||||
import numpy as np | ||||||||||
import torch | ||||||||||
from PIL import Image | ||||||||||
from torchvision.transforms import ToTensor | ||||||||||
|
||||||||||
# Training settings | ||||||||||
parser = argparse.ArgumentParser(description="PyTorch Super Res Example") | ||||||||||
parser.add_argument("--input_image", type=str, required=True, help="input image to use") | ||||||||||
parser.add_argument("--model", type=str, required=True, help="model file to use") | ||||||||||
parser.add_argument("--output_filename", type=str, help="where to save the output image") | ||||||||||
parser.add_argument("--cuda", action="store_true", help="use cuda") | ||||||||||
opt = parser.parse_args() | ||||||||||
|
||||||||||
print(opt) | ||||||||||
img = Image.open(opt.input_image).convert("YCbCr") | ||||||||||
y, cb, cr = img.split() | ||||||||||
|
||||||||||
model = torch.load(opt.model) | ||||||||||
img_to_tensor = ToTensor() | ||||||||||
input = img_to_tensor(y).view(1, -1, y.size[1], y.size[0]) | ||||||||||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
||||||||||
if opt.cuda: | ||||||||||
model = model.cuda() | ||||||||||
input = input.cuda() | ||||||||||
|
||||||||||
out = model(input) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
out = out.cpu() | ||||||||||
out_img_y = out[0].detach().numpy() | ||||||||||
out_img_y *= 255.0 | ||||||||||
out_img_y = out_img_y.clip(0, 255) | ||||||||||
out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode="L") | ||||||||||
|
||||||||||
out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC) | ||||||||||
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) | ||||||||||
out_img = Image.merge("YCbCr", [out_img_y, out_img_cb, out_img_cr]).convert("RGB") | ||||||||||
|
||||||||||
out_img.save(opt.output_filename) | ||||||||||
print("output image saved to ", opt.output_filename) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.