Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #13

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
[flake8]
# Prefix key: B = bugbear, C = comprehensions, E/W = pycodestyle,
# F = pyflakes, P = string-format plugin, T4 = type-hint plugin,
# B9 = bugbear's opinionated checks (includes B950 line-length)
select = B,C,E,F,P,T4,W,B9
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F403,F405,F821,F999,W503,W504,C408,E302,W291,E303,W605,E722,E266,E265
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
# Paths never linted: editor metadata, vendored/generated code,
# notebooks, experiment outputs, and data directories.
exclude =
.history,
.vscode,
tests/,
third_party,
build,
*.pyi,
.git,
.ipynb*,
__init__.py,
experiments_*,
data/


12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# python
# byte-compiled module caches
__pycache__

# IDE
# editor settings / local history (VS Code), plus local experiment
# output folders and datasets
.vscode
.history
experiments_*
data/

# project
# model checkpoints and evaluation result artifacts
*.pth
results
5 changes: 3 additions & 2 deletions codes/data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import cv2
import lmdb

import numpy as np
from torch.utils.data import Dataset

Expand Down Expand Up @@ -62,7 +62,8 @@ def parse_lmdb_key(key):
def read_lmdb_frame(env, key, size):
# Fetch one frame from an LMDB store by key and return it as a numpy image.
# NOTE(review): this span is a rendered diff — the line ending in
# `.reshape(*size)` is the REMOVED (pre-change) version, and the two lines
# after it are the ADDED (post-change) version that decodes a compressed
# buffer instead of reinterpreting raw bytes; only one variant exists in
# the real file.
with env.begin(write=False) as txn:
# read-only transaction; keys are stored as ascii-encoded strings
buf = txn.get(key.encode('ascii'))
frm = np.frombuffer(buf, dtype=np.uint8).reshape(*size)
frm = np.frombuffer(buf, dtype=np.uint8)
# decode the compressed image bytes into a color image
# (cv2.IMREAD_COLOR yields an HxWx3 uint8 array, BGR channel order)
frm = cv2.imdecode(frm, cv2.IMREAD_COLOR)
return frm

def crop_sequence(self, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions codes/data/paired_lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pickle
import random

import cv2
import numpy as np
import torch

Expand All @@ -29,7 +28,7 @@ def __init__(self, data_opt, **kwargs):
# use partial videos
if hasattr(self, 'filter_file') and self.filter_file is not None:
with open(self.filter_file, 'r') as f:
sel_seqs = { line.strip() for line in f }
sel_seqs = {line.strip() for line in f}
self.gt_lr_keys = list(filter(
lambda x: self.parse_lmdb_key(x[0])[0] in sel_seqs, self.gt_lr_keys))

Expand Down
10 changes: 7 additions & 3 deletions codes/data/unpaired_lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ def __init__(self, data_opt, **kwargs):
# use partial videos
if hasattr(self, 'filter_file') and self.filter_file is not None:
with open(self.filter_file, 'r') as f:
sel_seqs = { line.strip() for line in f }
sel_seqs = {line.strip() for line in f}
self.keys = list(filter(
lambda x: self.parse_lmdb_key(x)[0] in sel_seqs, self.keys))

# register parameters
self.env = None
self.env = self.init_lmdb(self.seq_dir)

def __len__(self):
return len(self.keys)
Expand Down Expand Up @@ -64,7 +64,7 @@ def __getitem__(self, item):
frms.append(frm[:, top: top + c_h, left: left + c_w].copy())
else:
# read frames
for i in range(cur_frm, cur_frm + self.tempo_extent):
def get_frames(i):
if i >= tot_frm:
# reflect temporal padding, e.g., (0,1,2) -> (0,1,2,1,0)
key = '{}_{}x{}x{}_{:04d}'.format(
Expand All @@ -75,6 +75,10 @@ def __getitem__(self, item):

frm = self.read_lmdb_frame(self.env, key, size=(h, w, c))
frm = frm.transpose(2, 0, 1) # chw|rgb|uint8
return frm

for i in range(cur_frm, cur_frm + self.tempo_extent):
frm = get_frames(i)
frms.append(frm)

frms = np.stack(frms) # tchw|rgb|uint8
Expand Down
37 changes: 24 additions & 13 deletions codes/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
import os.path as osp
import math
import os.path as osp
import time
from datetime import datetime

import torch

from data import create_dataloader
from metrics.metric_calculator import MetricCalculator
from models import define_model
from models.networks import define_generator
from metrics.metric_calculator import MetricCalculator
from utils import dist_utils, base_utils, data_utils
from utils import base_utils, data_utils, dist_utils


def train(opt):
Expand All @@ -36,6 +36,7 @@ def train(opt):
base_utils.log_info(f'Total epochs needed: {total_epoch} for {total_iter} iterations')

# train
start_time = time.time()
for epoch in range(total_epoch):
if opt['dist']:
train_loader.sampler.set_epoch(epoch)
Expand All @@ -44,30 +45,39 @@ def train(opt):
# update iter
iter += 1
curr_iter = start_iter + iter
if iter > total_iter: break

# update learning rate
model.update_learning_rate()
if iter > total_iter:
break

# prepare data
model.prepare_data(data)

# train a mini-batch
model.train()

# update learning rate
model.update_learning_rate()

# update running log
model.reduce_log() # for distributed training
model.update_running_log()

# print messages
if log_freq > 0 and curr_iter % log_freq == 0:
msg = model.get_format_msg(epoch, curr_iter)
base_utils.log_info(msg)
ckpt_time = time.time() - start_time
eta = int(ckpt_time * (total_iter - curr_iter) / curr_iter)
eta = datetime.fromtimestamp(eta) - datetime.fromtimestamp(0)

msg = model.get_format_msg(epoch, curr_iter, total_epoch, total_iter)
base_utils.log_info(f"{msg} | eta: {str(eta)}")

# save model
if ckpt_freq > 0 and curr_iter % ckpt_freq == 0:
model.save(curr_iter)

filename = f'G_iter{curr_iter}.pth'
save_path = osp.join(model.ckpt_dir, filename)
base_utils.log_info(f"save model in {save_path}")

# evaluate model
if test_freq > 0 and curr_iter % test_freq == 0:
# set model index
Expand All @@ -76,7 +86,8 @@ def train(opt):
# for each testset
for dataset_idx in sorted(opt['dataset'].keys()):
# select test dataset
if 'test' not in dataset_idx: continue
if 'test' not in dataset_idx:
continue

ds_name = opt['dataset'][dataset_idx]['name']
base_utils.log_info(f'Testing on {ds_name} dataset')
Expand Down Expand Up @@ -135,7 +146,7 @@ def test(opt):
for load_path in opt['model']['generator']['load_path_lst']:
# set model index
model_idx = osp.splitext(osp.split(load_path)[-1])[0]

# log
base_utils.log_info(f'\n{"="*40}')
base_utils.log_info(f'Testing model: {model_idx}')
Expand Down Expand Up @@ -175,7 +186,7 @@ def test(opt):
data_utils.save_sequence(
res_seq_dir, hr_seq, data['frm_idx'], to_bgr=True)

base_utils.log_info('-'*40)
base_utils.log_info('-' * 40)

# logging
base_utils.log_info(f'Finish testing\n{"="*40}')
Expand Down
18 changes: 7 additions & 11 deletions codes/metrics/LPIPS/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os

import numpy as np
import torch
from torch.autograd import Variable
from pdb import set_trace as st


class BaseModel():
def __init__(self):
# No per-instance state is set up here; see initialize() elsewhere in
# this class for actual setup.
# NOTE(review): diff view — `pass;` is the removed line and `pass` the
# added one; the real file contains only one of them.
pass;
pass

def name(self):
# Human-readable identifier for this model class.
return 'BaseModel'

Expand All @@ -18,9 +18,6 @@ def initialize(self, use_gpu=True, gpu_ids=[0]):
def forward(self):
pass

def get_image_paths(self):
pass

def optimize_parameters(self):
pass

Expand All @@ -43,7 +40,7 @@ def save_network(self, network, path, network_label, epoch_label):
def load_network(self, network, network_label, epoch_label):
# Restore `network`'s weights from
# `<self.save_dir>/<epoch_label>_net_<network_label>.pth`.
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
# NOTE(review): diff view — the unspaced print is the removed line and
# the PEP8-spaced print is the added one; only one exists in the real file.
print('Loading network from %s'%save_path)
print('Loading network from %s' % save_path)
network.load_state_dict(torch.load(save_path))

def update_learning_rate():
Expand All @@ -53,6 +50,5 @@ def get_image_paths(self):
return self.image_paths

def save_done(self, flag=False):
# Persist a completion flag under `self.save_dir` twice: as a binary
# .npy file (np.save) and as a plain-text integer file (np.savetxt).
# NOTE(review): diff view — the first np.save/np.savetxt pair is the
# removed version and the pair after the blank line is the added
# (whitespace-fixed) version; the real file contains only one pair.
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')

np.save(os.path.join(self.save_dir, 'done_flag'), flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')
Loading