Test.py
import argparse
import os
import numpy as np
import random
import torch
import torch.nn as nn
from data_loader.msrs_data import MSRS_data
from models.Common import YCrCb2RGB, clamp
from models.Fusion import MBHFuse
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm


def init_seeds(seed=0):
    # Seed Python, NumPy and PyTorch RNGs for reproducible inference.
    # Note: relies on the module-level `args` parsed in the __main__ block below.
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)
    # seed == 0: force deterministic cuDNN; otherwise allow benchmark autotuning.
    cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch MBHFuse')
    parser.add_argument('--dataset_path', metavar='DIR', default=r'./test_data/TNO',
                        help='path to the test dataset (default: ./test_data/TNO)')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='fusion_model',
                        choices=['fusion_model'])
    parser.add_argument('--save_path', default='results/TNO',
                        help='directory in which the fused images are saved')
    parser.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                        help='number of data loading workers (default: 1)')
    parser.add_argument('--fusion_pretrained', default='pretrained/fusion_model_epoch_59.pth',
                        help='path to the pretrained fusion model checkpoint')
    parser.add_argument('--seed', default=3407, type=int,
                        help='seed for initializing the random number generators')
    # argparse's type=bool treats any non-empty string as True, so parse the flag explicitly.
    parser.add_argument('--cuda', default=True,
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        help='use GPU or not (e.g. --cuda True / --cuda False)')
    args = parser.parse_args()
    init_seeds(args.seed)

    test_dataset = MSRS_data(args.dataset_path)
    test_loader = DataLoader(
        test_dataset, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    if args.arch == 'fusion_model':
        # The model and all input tensors are moved to the GPU unconditionally,
        # so a CUDA-capable device is required to run this script.
        model = MBHFuse()
        model = model.cuda()
        if args.cuda and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.load_state_dict(torch.load(args.fusion_pretrained))
        model.eval()

        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Number of parameters: {num_params}")

        test_tqdm = tqdm(test_loader, total=len(test_loader))
        with torch.no_grad():
            for _, vis_y_image, cb, cr, inf_image, name in test_tqdm:
                vis_y_image = vis_y_image.cuda()
                cb = cb.cuda()
                cr = cr.cuda()
                inf_image = inf_image.cuda()

                # Fuse the visible Y channel with the infrared image, then clamp to the valid range.
                fused_image = model(vis_y_image, inf_image)
                fused_image = clamp(fused_image)

                # Recombine the fused Y channel with the original Cb/Cr channels and save as RGB.
                rgb_fused_image = YCrCb2RGB(fused_image[0], cb[0], cr[0])
                rgb_fused_image = transforms.ToPILImage()(rgb_fused_image)
                rgb_fused_image.save(f'{args.save_path}/{name[0]}')
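
# Example invocation (a sketch using the argparse defaults above; the actual
# dataset and checkpoint locations depend on your setup):
#   python Test.py --dataset_path ./test_data/TNO \
#       --fusion_pretrained pretrained/fusion_model_epoch_59.pth \
#       --save_path results/TNO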