
Commit

Add files via upload
FanChiMao committed Feb 28, 2022
1 parent 4b33cd5 commit 67f8928
Showing 3 changed files with 304 additions and 0 deletions.
64 changes: 64 additions & 0 deletions generate_patches.py
@@ -0,0 +1,64 @@
from glob import glob
from tqdm import tqdm
import numpy as np
import os
from natsort import natsorted
import cv2
from joblib import Parallel, delayed
import argparse

parser = argparse.ArgumentParser(description='Generate patches from Full Resolution images')
parser.add_argument('--src_dir', default='D:/NCHU/Dataset/Deraindrop/train/', type=str, help='Directory for full resolution images')
parser.add_argument('--tar_dir', default='./datasets/train/RainDrop', type=str, help='Directory for image patches')
parser.add_argument('--ps', default=256, type=int, help='Image Patch Size')
parser.add_argument('--num_patches', default=10, type=int, help='Number of patches per image')
parser.add_argument('--num_cores', default=6, type=int, help='Number of CPU Cores')

args = parser.parse_args()

src = args.src_dir
tar = args.tar_dir
PS = args.ps
NUM_PATCHES = args.num_patches
NUM_CORES = args.num_cores

noisy_patchDir = os.path.join(tar, 'input')
clean_patchDir = os.path.join(tar, 'target')

if os.path.exists(tar):
    os.system("rm -r {}".format(tar))

os.makedirs(noisy_patchDir)
os.makedirs(clean_patchDir)

# Get sorted image file paths
files = natsorted(glob(os.path.join(src, '*', '*.PNG')))

noisy_files, clean_files = [], []
for file_ in files:
    filename = os.path.split(file_)[-1]
    if 'clean' in filename:
        clean_files.append(file_)
    if 'rain' in filename:
        noisy_files.append(file_)
    # if 'gt' in file_:
    #     clean_files.append(file_)
    # if 'data' in file_:
    #     noisy_files.append(file_)
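
# Note: the 'rain'/'clean' pairing above relies on natsorted keeping each degraded image
# aligned with its ground truth, and the random crops below assume every image is at least
# PS pixels in both dimensions.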
def save_files(i):
    noisy_file, clean_file = noisy_files[i], clean_files[i]
    noisy_img = cv2.imread(noisy_file)
    clean_img = cv2.imread(clean_file)

    H = noisy_img.shape[0]
    W = noisy_img.shape[1]
    for j in range(NUM_PATCHES):
        rr = np.random.randint(0, H - PS)
        cc = np.random.randint(0, W - PS)
        noisy_patch = noisy_img[rr:rr + PS, cc:cc + PS, :]
        clean_patch = clean_img[rr:rr + PS, cc:cc + PS, :]

        cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(i + 1, j + 1)), noisy_patch)
        cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(i + 1, j + 1)), clean_patch)

Parallel(n_jobs=NUM_CORES)(delayed(save_files)(i) for i in tqdm(range(len(noisy_files))))
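
# Example invocation (illustrative paths; adjust to your own dataset layout):
#   python generate_patches.py --src_dir /path/to/Deraindrop/train --tar_dir ./datasets/train/RainDrop --ps 256 --num_patches 10 --num_cores 6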
215 changes: 215 additions & 0 deletions train.py
@@ -0,0 +1,215 @@
import os
import torch
import yaml
from utils import network_parameters
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from tensorboardX import SummaryWriter
import time
import utils
import numpy as np
import random

from dataloader.data_RGB import get_training_data, get_validation_data
from utils.losses import PSNRLoss, SSIMLoss, CharbonnierLoss
from warmup_scheduler import GradualWarmupScheduler
from model_arch.SRMNet_SWFF import SRMNet_SWFF
from model_arch.SRMNet import SRMNet


## Set Seeds
torch.backends.cudnn.benchmark = True
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

## Load yaml configuration file
with open('training.yaml', 'r') as config:
    opt = yaml.safe_load(config)
Train = opt['TRAINING']
OPT = opt['OPTIM']

## Build Model
print('==> Build the model')
model_restored = SRMNet_SWFF(in_chn=3, wf=96, depth=4)
p_number = network_parameters(model_restored)
model_restored.cuda()

## Model save directory
mode = opt['MODEL']['MODE']

model_dir = os.path.join(Train['SAVE_DIR'], mode, 'models')
utils.mkdir(model_dir)
train_dir = Train['TRAIN_DIR']
val_dir = Train['VAL_DIR']

## GPU
gpus = ','.join([str(i) for i in opt['GPU']])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus
device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
    print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
if len(device_ids) > 1:
    model_restored = nn.DataParallel(model_restored, device_ids=device_ids)
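# Note: CUDA_VISIBLE_DEVICES is set after model_restored.cuda() has already initialised CUDA
# above, so the mask may not take effect for this process; exporting the variable before
# launching the script is the more reliable way to restrict the visible GPUs.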

## Optimizer
start_epoch = 1
new_lr = float(OPT['LR_INITIAL'])
optimizer = optim.Adam(model_restored.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8)

## Scheduler (Strategy)
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, OPT['EPOCHS'] - warmup_epochs,
                                                        eta_min=float(OPT['LR_MIN']))
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()
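# Learning-rate schedule (assuming the ildoonet warmup_scheduler package): the rate ramps up
# linearly to LR_INITIAL over the first `warmup_epochs` epochs, then follows cosine annealing
# down toward LR_MIN for the remaining epochs.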

## Resume (continue training from a pretrained checkpoint)
if Train['RESUME']:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restored, path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    for i in range(1, start_epoch):
        scheduler.step()
    new_lr = scheduler.get_lr()[0]
    print('------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------')

## Loss
Charbonnier = CharbonnierLoss()
PSNR_loss = PSNRLoss()
SSIM_loss = SSIMLoss()

## DataLoaders
print('==> Loading datasets')
train_dataset = get_training_data(train_dir, {'patch_size': Train['TRAIN_PS']})
train_loader = DataLoader(dataset=train_dataset, batch_size=OPT['BATCH'],
                          shuffle=True, num_workers=0, drop_last=False)
val_dataset = get_validation_data(val_dir, {'patch_size': Train['VAL_PS']})
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=0,
                        drop_last=False)
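
# num_workers=0 keeps data loading in the main process; a larger value (e.g. 4) usually
# speeds up training when spare CPU cores are available.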

# Show the training configuration
print(f'''==> Training details:
------------------------------------------------------------------
Restoration mode: {mode}
Train patches size: {str(Train['TRAIN_PS']) + 'x' + str(Train['TRAIN_PS'])}
Val patches size: {str(Train['VAL_PS']) + 'x' + str(Train['VAL_PS'])}
Model parameters: {p_number}
Start/End epochs: {str(start_epoch) + '~' + str(OPT['EPOCHS'])}
Batch sizes: {OPT['BATCH']}
Learning rate: {OPT['LR_INITIAL']}
GPU: {'GPU' + str(device_ids)}''')
print('------------------------------------------------------------------')

# Start training!
print('==> Training start: ')
best_psnr = 0
best_ssim = 0
best_epoch_psnr = 0
best_epoch_ssim = 0
total_start_time = time.time()

## Log
log_dir = os.path.join(Train['SAVE_DIR'], mode, 'log')
utils.mkdir(log_dir)
writer = SummaryWriter(log_dir=log_dir, filename_suffix=f'_{mode}')
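
# tensorboardX writes standard TensorBoard event files; they can be inspected with, e.g.,
# tensorboard --logdir ./checkpoints/<MODE>/log (the exact path follows SAVE_DIR and MODE
# in training.yaml).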

for epoch in range(start_epoch, OPT['EPOCHS'] + 1):
    epoch_start_time = time.time()
    epoch_loss = 0
    train_id = 1

    model_restored.train()
    for i, data in enumerate(tqdm(train_loader), 0):
        # Forward propagation
        for param in model_restored.parameters():
            param.grad = None
        target = data[0].cuda()
        input_ = data[1].cuda()
        restored = model_restored(input_)

        # Compute loss
        loss = SSIM_loss(restored, target) / (PSNR_loss(restored, target) + 0.005)
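        # The objective divides the SSIM-based loss by the PSNR-based loss; the small 0.005
        # term presumably keeps the denominator away from zero.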

        # Back propagation
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    ## Evaluation (Validation)
    if epoch % Train['VAL_AFTER_EVERY'] == 0:
        model_restored.eval()
        psnr_val_rgb = []
        ssim_val_rgb = []
        for ii, data_val in enumerate(val_loader, 0):
            target = data_val[0].cuda()
            input_ = data_val[1].cuda()
            with torch.no_grad():
                restored = model_restored(input_)

            for res, tar in zip(restored, target):
                psnr_val_rgb.append(utils.torchPSNR(res, tar))
            ssim_val_rgb.append(utils.torchSSIM(restored, target))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
        ssim_val_rgb = torch.stack(ssim_val_rgb).mean().item()

        # Save the best PSNR model of validation
        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch_psnr = epoch
            torch.save({'epoch': epoch,
                        'state_dict': model_restored.state_dict(),
                        'optimizer': optimizer.state_dict()
                        }, os.path.join(model_dir, "model_bestPSNR.pth"))
        print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (
            epoch, psnr_val_rgb, best_epoch_psnr, best_psnr))

        # Save the best SSIM model of validation
        if ssim_val_rgb > best_ssim:
            best_ssim = ssim_val_rgb
            best_epoch_ssim = epoch
            torch.save({'epoch': epoch,
                        'state_dict': model_restored.state_dict(),
                        'optimizer': optimizer.state_dict()
                        }, os.path.join(model_dir, "model_bestSSIM.pth"))
        print("[epoch %d SSIM: %.4f --- best_epoch %d Best_SSIM %.4f]" % (
            epoch, ssim_val_rgb, best_epoch_ssim, best_ssim))


        # Save every epoch's model
        torch.save({'epoch': epoch,
                    'state_dict': model_restored.state_dict(),
                    'optimizer': optimizer.state_dict()
                    }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))

        writer.add_scalar('val/PSNR', psnr_val_rgb, epoch)
        writer.add_scalar('val/SSIM', ssim_val_rgb, epoch)
    scheduler.step()

    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
                                                                              epoch_loss, scheduler.get_lr()[0]))
    print("------------------------------------------------------------------")

    # Save the last model
    torch.save({'epoch': epoch,
                'state_dict': model_restored.state_dict(),
                'optimizer': optimizer.state_dict()
                }, os.path.join(model_dir, "model_latest.pth"))

    writer.add_scalar('train/loss', epoch_loss, epoch)
    writer.add_scalar('train/lr', scheduler.get_lr()[0], epoch)

writer.close()

total_finish_time = (time.time() - total_start_time) # seconds
print('Total training time: {:.1f} hours'.format((total_finish_time / 60 / 60)))
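
# Typical launch (assuming training.yaml and the dataset paths it references are in place):
#   python train.py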
25 changes: 25 additions & 0 deletions training.yaml
@@ -0,0 +1,25 @@
# Training configuration
GPU: [0,1,2,3]

VERBOSE: False

MODEL:
  MODE: 'Enhancement'

# Optimization arguments.
OPTIM:
  BATCH: 2
  EPOCHS: 200
  # EPOCH_DECAY: [10]
  LR_INITIAL: 2e-4
  LR_MIN: 1e-5
  # BETA1: 0.9

TRAINING:
  VAL_AFTER_EVERY: 1
  RESUME: False
  TRAIN_PS: 256
  VAL_PS: 256
  TRAIN_DIR: './datasets/train/LOL/train'   # path to training data
  VAL_DIR: './datasets/train/LOL/test'      # path to validation data
  SAVE_DIR: './checkpoints'                 # path to save models and images
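
# Note: TRAIN_DIR and VAL_DIR are assumed to contain paired 'input' and 'target' subfolders,
# matching the layout produced by generate_patches.py.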
