-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththroughput.py
116 lines (86 loc) · 3.48 KB
/
throughput.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
(Testing FPS)
Pixel Difference Networks for Efficient Edge Detection
"""
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import os
import argparse
import time
import models
from utils import *
from edge_dataloader import BSDS_VOCLoader, Dataloader_BSDS500
from torch.utils.data import DataLoader
import torch
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
parser = argparse.ArgumentParser(description='PyTorch Diff Convolutional Networks (Train)')
parser.add_argument('--datadir', type=str, default='../data',
help='dir to the dataset')
parser.add_argument('--dataset', type=str, default='BSDS',
help='data settings for BSDS, Multicue and NYUD datasets')
parser.add_argument('--model', type=str, default='baseline',
help='model to train the dataset')
parser.add_argument('--sa', action='store_true',
help='use attention in diffnet')
parser.add_argument('--dil', action='store_true',
help='use dilation in diffnet')
parser.add_argument('--config', type=str, default='nas-all',
help='model configurations, please refer to models/config.py for possible configurations')
parser.add_argument('--seed', type=int, default=None,
help='random seed (default: None)')
parser.add_argument('--gpu', type=str, default='',
help='gpus available')
parser.add_argument('--epochs', type=int, default=150,
help='number of total epochs to run')
parser.add_argument('-j', '--workers', type=int, default=4,
help='number of data loading workers')
parser.add_argument('--eta', type=float, default=0.3,
help='threshold to determine the ground truth')
args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
def main():
    """Prepare ``args``, build the model and test data, then measure throughput.

    Reads the module-level ``args`` namespace produced by ``parser.parse_args()``
    and updates it in place (``seed``, ``use_cuda``, ``dataset``) before passing
    it to the model constructor and to ``test``.

    Raises:
        ValueError: if ``args.dataset`` is not a recognised setting, or is a
            recognised setting this script has no loader for.
    """
    # --- Refine args -------------------------------------------------------
    if args.seed is None:
        args.seed = int(time.time())
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    args.use_cuda = torch.cuda.is_available()

    dataset_setting_choices = ['BSDS', 'Custom']
    if not isinstance(args.dataset, list):
        # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
        if args.dataset not in dataset_setting_choices:
            raise ValueError(
                'unrecognized data setting %s, please choose from %s'
                % (str(args.dataset), str(dataset_setting_choices)))
        args.dataset = list(args.dataset.strip().split('-'))
    print(args)

    # --- Create model ------------------------------------------------------
    # The constructor is looked up by name in the project's ``models`` package.
    model = getattr(models, args.model)(args)

    # --- Transfer to cuda devices ------------------------------------------
    if args.use_cuda:
        model = torch.nn.DataParallel(model).cuda()
        print('cuda is used, with %d gpu devices' % torch.cuda.device_count())
    else:
        print('cuda is not used, the running might be slow')

    # --- Load data ---------------------------------------------------------
    # 'Custom' passes the validation above but has no loader in this script,
    # so report that explicitly instead of calling it "unrecognized".
    if args.dataset[0] != 'BSDS':
        raise ValueError(
            'unsupported dataset setting %s: only BSDS is supported for throughput testing'
            % args.dataset[0])
    test_dataset = BSDS_VOCLoader(root=args.datadir, split="test", threshold=args.eta)
    test_loader = DataLoader(
        test_dataset, batch_size=1, num_workers=args.workers, shuffle=False)

    test(test_loader, model, args)
def test(test_loader, model, args):
model.eval()
end = time.perf_counter()
torch.cuda.synchronize()
for idx, (image, img_name) in enumerate(test_loader):
with torch.no_grad():
image = image.cuda() if args.use_cuda else image
_, _, H, W = image.shape
results = model(image)
torch.cuda.synchronize()
end = time.perf_counter() - end
print('fps: %f' % (len(test_loader) / end))
if __name__ == "__main__":
    # Script entry point: run the benchmark, then confirm completion on stdout.
    main()
    print("done")