-
Notifications
You must be signed in to change notification settings - Fork 37
/
test.py
66 lines (44 loc) · 2.03 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import torch
import torch.nn as nn
from model import VGG16
from vis_flux import vis_flux
from datasets import FluxSegmentationDataset
from torch.autograd import Variable
import scipy.io as sio
from torch.utils.data import Dataset, DataLoader
# Default values for the command-line options defined in get_arguments().
DATASET = 'PascalContext'
TEST_VIS_DIR = './test_pred_flux/'
SNAPSHOT_DIR = './snapshots/'


def get_arguments(argv=None):
    """Parse the command-line arguments for the test script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so a
            zero-argument call behaves exactly like the original.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset``,
        ``test_vis_dir`` and ``snapshot_dir``.
    """
    parser = argparse.ArgumentParser(description="Super-BPD Network")
    parser.add_argument("--dataset", type=str, default=DATASET,
                        help="Dataset to evaluate on.")
    parser.add_argument("--test-vis-dir", type=str, default=TEST_VIS_DIR,
                        help="Directory for saving vis results during testing.")
    parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
                        help="Directory containing the model snapshot to load.")
    return parser.parse_args(argv)
# Parsed once, at import time, from sys.argv; main() reads this module-level
# global. NOTE(review): importing this module therefore consumes the caller's
# command line (and exits on unknown options).
args = get_arguments()
def main():
    """Run the trained model over the test split and save its predicted flux.

    Reads the module-level ``args`` namespace (``dataset``, ``test_vis_dir``,
    ``snapshot_dir``). The snapshot ``<snapshot_dir>/<dataset>_400000.pth`` is
    loaded into ``VGG16``. For every test sample, ``vis_flux`` is called with
    the output directory, and the prediction is saved as
    ``<test_vis_dir>/<dataset>/<image_name>.mat`` under the key ``'flux'``.

    Requires a CUDA device: the model and each input batch are moved to GPU.
    """
    out_dir = os.path.join(args.test_vis_dir, args.dataset)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)

    model = VGG16()
    snapshot_path = os.path.join(args.snapshot_dir, args.dataset + '_400000.pth')
    # Load tensors onto the CPU first; the model is moved to the GPU below.
    # This also works when the snapshot was saved on a GPU index that is not
    # available on this machine.
    model.load_state_dict(torch.load(snapshot_path, map_location='cpu'))
    model.eval()
    model.cuda()

    dataloader = DataLoader(FluxSegmentationDataset(dataset=args.dataset, mode='test'),
                            batch_size=1, shuffle=False, num_workers=4)

    # Inference only: disable autograd so no graph is retained per batch.
    with torch.no_grad():
        for i_iter, batch_data in enumerate(dataloader):
            (input_image, vis_image, gt_mask, gt_flux,
             weight_matrix, dataset_length, image_name) = batch_data
            print(i_iter, dataset_length)

            pred_flux = model(input_image.cuda())
            # vis_flux receives the directory with a trailing '/', as before.
            vis_flux(vis_image, pred_flux, gt_flux, gt_mask, image_name, out_dir + '/')

            # Keep the first sample along axis 0 (the only one, since
            # batch_size=1) -- assumes axis 0 is the batch axis.
            pred_flux = pred_flux.detach().cpu().numpy()[0, ...]
            sio.savemat(os.path.join(out_dir, image_name[0] + '.mat'), {'flux': pred_flux})
if __name__ == '__main__':
    # Script entry point: run inference only when executed directly,
    # not when this module is imported.
    main()