-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from glob import glob | ||
from tqdm import tqdm | ||
import numpy as np | ||
import os | ||
from natsort import natsorted | ||
import cv2 | ||
from joblib import Parallel, delayed | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description='Generate patches from Full Resolution images') | ||
parser.add_argument('--src_dir', default='D:/NCHU/Dataset/Deraindrop/train/', type=str, help='Directory for full resolution images') | ||
parser.add_argument('--tar_dir', default='./datasets/train/RainDrop',type=str, help='Directory for image patches') | ||
parser.add_argument('--ps', default=256, type=int, help='Image Patch Size') | ||
parser.add_argument('--num_patches', default=10, type=int, help='Number of patches per image') | ||
parser.add_argument('--num_cores', default=6, type=int, help='Number of CPU Cores') | ||
|
||
args = parser.parse_args() | ||
|
||
src = args.src_dir | ||
tar = args.tar_dir | ||
PS = args.ps | ||
NUM_PATCHES = args.num_patches | ||
NUM_CORES = args.num_cores | ||
|
||
noisy_patchDir = os.path.join(tar, 'input') | ||
clean_patchDir = os.path.join(tar, 'target') | ||
|
||
if os.path.exists(tar): | ||
os.system("rm -r {}".format(tar)) | ||
|
||
os.makedirs(noisy_patchDir) | ||
os.makedirs(clean_patchDir) | ||
|
||
#get sorted folders | ||
files = natsorted(glob(os.path.join(src, '*', '*.PNG'))) | ||
|
||
noisy_files, clean_files = [], [] | ||
for file_ in files: | ||
filename = os.path.split(file_)[-1] | ||
if 'clean' in filename: | ||
clean_files.append(file_) | ||
if 'rain' in filename: | ||
noisy_files.append(file_) | ||
#if 'gt' in file_: | ||
# clean_files.append(file_) | ||
#if 'data' in file_: | ||
# noisy_files.append(file_) | ||
def save_files(i): | ||
noisy_file, clean_file = noisy_files[i], clean_files[i] | ||
noisy_img = cv2.imread(noisy_file) | ||
clean_img = cv2.imread(clean_file) | ||
|
||
H = noisy_img.shape[0] | ||
W = noisy_img.shape[1] | ||
for j in range(NUM_PATCHES): | ||
rr = np.random.randint(0, H - PS) | ||
cc = np.random.randint(0, W - PS) | ||
noisy_patch = noisy_img[rr:rr + PS, cc:cc + PS, :] | ||
clean_patch = clean_img[rr:rr + PS, cc:cc + PS, :] | ||
|
||
cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(i+1, j+1)), noisy_patch) | ||
cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(i+1, j+1)), clean_patch) | ||
|
||
Parallel(n_jobs=NUM_CORES)(delayed(save_files)(i) for i in tqdm(range(len(noisy_files)))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
import os | ||
import torch | ||
import yaml | ||
from utils import network_parameters | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
from tensorboardX import SummaryWriter | ||
import time | ||
import utils | ||
import numpy as np | ||
import random | ||
|
||
from dataloader.data_RGB import get_training_data, get_validation_data | ||
from utils.losses import PSNRLoss, SSIMLoss, CharbonnierLoss | ||
from warmup_scheduler import GradualWarmupScheduler | ||
from model_arch.SRMNet_SWFF import SRMNet_SWFF | ||
from model_arch.SRMNet import SRMNet | ||
|
||
|
||
## Set Seeds | ||
torch.backends.cudnn.benchmark = True | ||
random.seed(1234) | ||
np.random.seed(1234) | ||
torch.manual_seed(1234) | ||
torch.cuda.manual_seed_all(1234) | ||
|
||
## Load yaml configuration file | ||
with open('training.yaml', 'r') as config: | ||
opt = yaml.safe_load(config) | ||
Train = opt['TRAINING'] | ||
OPT = opt['OPTIM'] | ||
|
||
## Build Model | ||
print('==> Build the model') | ||
model_restored = SRMNet_SWFF(in_chn=3, wf=96, depth=4) | ||
p_number = network_parameters(model_restored) | ||
model_restored.cuda() | ||
|
||
## Training model path direction | ||
mode = opt['MODEL']['MODE'] | ||
|
||
model_dir = os.path.join(Train['SAVE_DIR'], mode, 'models') | ||
utils.mkdir(model_dir) | ||
train_dir = Train['TRAIN_DIR'] | ||
val_dir = Train['VAL_DIR'] | ||
|
||
## GPU | ||
gpus = ','.join([str(i) for i in opt['GPU']]) | ||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | ||
os.environ["CUDA_VISIBLE_DEVICES"] = gpus | ||
device_ids = [i for i in range(torch.cuda.device_count())] | ||
if torch.cuda.device_count() > 1: | ||
print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") | ||
if len(device_ids) > 1: | ||
model_restored = nn.DataParallel(model_restored, device_ids=device_ids) | ||
|
||
## Optimizer | ||
start_epoch = 1 | ||
new_lr = float(OPT['LR_INITIAL']) | ||
optimizer = optim.Adam(model_restored.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8) | ||
|
||
## Scheduler (Strategy) | ||
warmup_epochs = 3 | ||
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, OPT['EPOCHS'] - warmup_epochs, | ||
eta_min=float(OPT['LR_MIN'])) | ||
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) | ||
scheduler.step() | ||
|
||
## Resume (Continue training by a pretrained model) | ||
if Train['RESUME']: | ||
path_chk_rest = utils.get_last_path(model_dir, '_latest.pth') | ||
utils.load_checkpoint(model_restored, path_chk_rest) | ||
start_epoch = utils.load_start_epoch(path_chk_rest) + 1 | ||
utils.load_optim(optimizer, path_chk_rest) | ||
|
||
for i in range(1, start_epoch): | ||
scheduler.step() | ||
new_lr = scheduler.get_lr()[0] | ||
print('------------------------------------------------------------------') | ||
print("==> Resuming Training with learning rate:", new_lr) | ||
print('------------------------------------------------------------------') | ||
|
||
## Loss | ||
Charbonnier = CharbonnierLoss() | ||
PSNR_loss = PSNRLoss() | ||
SSIM_loss = SSIMLoss() | ||
|
||
## DataLoaders | ||
print('==> Loading datasets') | ||
train_dataset = get_training_data(train_dir, {'patch_size': Train['TRAIN_PS']}) | ||
train_loader = DataLoader(dataset=train_dataset, batch_size=OPT['BATCH'], | ||
shuffle=True, num_workers=0, drop_last=False) | ||
val_dataset = get_validation_data(val_dir, {'patch_size': Train['VAL_PS']}) | ||
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=0, | ||
drop_last=False) | ||
|
||
# Show the training configuration | ||
print(f'''==> Training details: | ||
------------------------------------------------------------------ | ||
Restoration mode: {mode} | ||
Train patches size: {str(Train['TRAIN_PS']) + 'x' + str(Train['TRAIN_PS'])} | ||
Val patches size: {str(Train['VAL_PS']) + 'x' + str(Train['VAL_PS'])} | ||
Model parameters: {p_number} | ||
Start/End epochs: {str(start_epoch) + '~' + str(OPT['EPOCHS'])} | ||
Batch sizes: {OPT['BATCH']} | ||
Learning rate: {OPT['LR_INITIAL']} | ||
GPU: {'GPU' + str(device_ids)}''') | ||
print('------------------------------------------------------------------') | ||
|
||
# Start training! | ||
print('==> Training start: ') | ||
best_psnr = 0 | ||
best_ssim = 0 | ||
best_epoch_psnr = 0 | ||
best_epoch_ssim = 0 | ||
total_start_time = time.time() | ||
|
||
## Log | ||
log_dir = os.path.join(Train['SAVE_DIR'], mode, 'log') | ||
utils.mkdir(log_dir) | ||
writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_{mode}') | ||
|
||
for epoch in range(start_epoch, OPT['EPOCHS'] + 1): | ||
epoch_start_time = time.time() | ||
epoch_loss = 0 | ||
train_id = 1 | ||
|
||
model_restored.train() | ||
for i, data in enumerate(tqdm(train_loader), 0): | ||
# Forward propagation | ||
for param in model_restored.parameters(): | ||
param.grad = None | ||
target = data[0].cuda() | ||
input_ = data[1].cuda() | ||
restored = model_restored(input_) | ||
|
||
# Compute loss | ||
loss = SSIM_loss(restored, target) / (PSNR_loss(restored, target) + 0.005) | ||
|
||
# Back propagation | ||
loss.backward() | ||
optimizer.step() | ||
epoch_loss += loss.item() | ||
|
||
## Evaluation (Validation) | ||
if epoch % Train['VAL_AFTER_EVERY'] == 0: | ||
model_restored.eval() | ||
psnr_val_rgb = [] | ||
ssim_val_rgb = [] | ||
for ii, data_val in enumerate(val_loader, 0): | ||
target = data_val[0].cuda() | ||
input_ = data_val[1].cuda() | ||
with torch.no_grad(): | ||
restored = model_restored(input_) | ||
|
||
for res, tar in zip(restored, target): | ||
psnr_val_rgb.append(utils.torchPSNR(res, tar)) | ||
ssim_val_rgb.append(utils.torchSSIM(restored, target)) | ||
|
||
psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() | ||
ssim_val_rgb = torch.stack(ssim_val_rgb).mean().item() | ||
|
||
# Save the best PSNR model of validation | ||
if psnr_val_rgb > best_psnr: | ||
best_psnr = psnr_val_rgb | ||
best_epoch_psnr = epoch | ||
torch.save({'epoch': epoch, | ||
'state_dict': model_restored.state_dict(), | ||
'optimizer': optimizer.state_dict() | ||
}, os.path.join(model_dir, "model_bestPSNR.pth")) | ||
print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % ( | ||
epoch, psnr_val_rgb, best_epoch_psnr, best_psnr)) | ||
|
||
# Save the best SSIM model of validation | ||
if ssim_val_rgb > best_ssim: | ||
best_ssim = ssim_val_rgb | ||
best_epoch_ssim = epoch | ||
torch.save({'epoch': epoch, | ||
'state_dict': model_restored.state_dict(), | ||
'optimizer': optimizer.state_dict() | ||
}, os.path.join(model_dir, "model_bestSSIM.pth")) | ||
print("[epoch %d SSIM: %.4f --- best_epoch %d Best_SSIM %.4f]" % ( | ||
epoch, ssim_val_rgb, best_epoch_ssim, best_ssim)) | ||
|
||
|
||
# Save evey epochs of model | ||
torch.save({'epoch': epoch, | ||
'state_dict': model_restored.state_dict(), | ||
'optimizer': optimizer.state_dict() | ||
}, os.path.join(model_dir, f"model_epoch_{epoch}.pth")) | ||
|
||
|
||
writer.add_scalar('val/PSNR', psnr_val_rgb, epoch) | ||
writer.add_scalar('val/SSIM', ssim_val_rgb, epoch) | ||
scheduler.step() | ||
|
||
print("------------------------------------------------------------------") | ||
print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time, | ||
epoch_loss, scheduler.get_lr()[0])) | ||
print("------------------------------------------------------------------") | ||
|
||
# Save the last model | ||
torch.save({'epoch': epoch, | ||
'state_dict': model_restored.state_dict(), | ||
'optimizer': optimizer.state_dict() | ||
}, os.path.join(model_dir, "model_latest.pth")) | ||
|
||
writer.add_scalar('train/loss', epoch_loss, epoch) | ||
writer.add_scalar('train/lr', scheduler.get_lr()[0], epoch) | ||
writer.close() | ||
|
||
total_finish_time = (time.time() - total_start_time) # seconds | ||
print('Total training time: {:.1f} hours'.format((total_finish_time / 60 / 60))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Training configuration | ||
GPU: [0,1,2,3] | ||
|
||
VERBOSE: False | ||
|
||
MODEL: | ||
MODE: 'Enhancement' | ||
|
||
# Optimization arguments. | ||
OPTIM: | ||
BATCH: 2 | ||
EPOCHS: 200 | ||
# EPOCH_DECAY: [10] | ||
LR_INITIAL: 2e-4 | ||
LR_MIN: 1e-5 | ||
# BETA1: 0.9 | ||
|
||
TRAINING: | ||
VAL_AFTER_EVERY: 1 | ||
RESUME: False | ||
TRAIN_PS: 256 | ||
VAL_PS: 256 | ||
TRAIN_DIR: './datasets/train/LOL/train' # path to training data | ||
VAL_DIR: './datasets/train/LOL/test' # path to validation data | ||
SAVE_DIR: './checkpoints' # path to save models and images |