-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval.py
119 lines (100 loc) · 4.31 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import argparse
import math
from tqdm import tqdm
import torch
from torch.backends import cudnn
from torch.utils import data
# Make the project root (two directories up from this script) importable so
# the local `model`, `data`, and `utils` packages resolve when the script is
# run directly rather than as an installed package.
cur_path = os.path.dirname(__file__)
sys.path.insert(0, os.path.join(cur_path, '../..'))
from model.get_model import get_model
from data import transforms_cv
from data.helper import make_data_sampler
from data.imagenet import ImageNet
from utils.helper import synchronize, accumulate_metric, is_main_process
from utils.metric_cls import Accuracy, TopKAccuracy
def parse_args():
    """Parse command-line options for ImageNet evaluation.

    Returns:
        argparse.Namespace with model/name, input geometry, data-loading,
        checkpoint, and device/distributed settings.
    """
    p = argparse.ArgumentParser(description='Eval ImageNet networks.')
    # Model and input geometry.
    p.add_argument('--model', type=str, default='darknet53',
                   help="Base network name")
    p.add_argument('--input-size', type=int, default=224,
                   help='size of the input image size. default is 224')
    p.add_argument('--crop-ratio', type=float, default=0.875,
                   help='Crop ratio during validation. default is 0.875')
    # Data loading.
    p.add_argument('--batch-size', type=int, default=16,
                   help='Testing batch size')
    p.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                   default=4, help='Number of data workers')
    p.add_argument('--data-dir', type=str, default=os.path.expanduser('~/.torch/datasets/imagenet'),
                   help='default data root')
    p.add_argument('--pretrained', type=str, default=None,
                   help='Load weights from previously saved parameters.')
    # device
    p.add_argument('--cuda', action='store_true', default=True,
                   help='Evaluate with GPUs.')
    p.add_argument('--local_rank', type=int, default=0)
    p.add_argument('--init-method', type=str, default="env://")
    return p.parse_args()
def get_dataloader(opt, distributed):
    """Build the ImageNet validation DataLoader.

    Resizes each image so that a center crop of ``opt.input_size`` covers
    ``crop_ratio`` of the resized side, then normalizes with the standard
    ImageNet mean/std.

    Args:
        opt: parsed CLI namespace (input_size, crop_ratio, data_dir,
            batch_size, num_workers).
        distributed: whether to shard the dataset across processes.

    Returns:
        torch DataLoader over the validation split.
    """
    # Guard against a non-positive crop ratio by falling back to 0.875.
    ratio = opt.crop_ratio if opt.crop_ratio > 0 else 0.875
    side = int(math.ceil(opt.input_size / ratio))
    eval_transform = transforms_cv.Compose([
        transforms_cv.Resize((side, side)),
        transforms_cv.CenterCrop(opt.input_size),
        transforms_cv.ToTensor(),
        transforms_cv.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = ImageNet(opt.data_dir, train=False, transform=eval_transform)
    # Sequential (non-shuffled) sampler; sharded per-rank when distributed.
    base_sampler = make_data_sampler(dataset, False, distributed)
    batched = data.BatchSampler(sampler=base_sampler,
                                batch_size=opt.batch_size,
                                drop_last=False)
    return data.DataLoader(dataset, batch_sampler=batched,
                           num_workers=opt.num_workers)
def validate(net, val_data, device, acc_top1, acc_top5):
    """Run one evaluation pass and accumulate top-1/top-5 accuracy.

    Args:
        net: model to evaluate (switched to eval mode here).
        val_data: iterable of (images, labels) batches.
        device: device the model lives on; batches are moved there.
        acc_top1: Accuracy metric, reset then updated per batch.
        acc_top5: TopKAccuracy metric, reset then updated per batch.

    Returns:
        The same (acc_top1, acc_top5) metric objects, updated in place.
    """
    net.eval()
    acc_top1.reset()
    acc_top5.reset()
    cpu_device = torch.device("cpu")
    # Enter no_grad once for the whole pass instead of per batch; the
    # loop variable is named `batch` (not `data`) so it no longer shadows
    # the module-level `from torch.utils import data` import, and the
    # unused enumerate index is dropped.
    with torch.no_grad():
        for batch, label in tqdm(val_data):
            batch = batch.to(device)
            # Metrics expect CPU tensors, so move outputs back.
            outputs = net(batch).to(cpu_device)
            acc_top1.update(label, outputs)
            acc_top5.update(label, outputs)
    return acc_top1, acc_top5
if __name__ == '__main__':
    args = parse_args()
    # Default checkpoint path: ~/.torch/models/<model>.pth
    if args.pretrained is None:
        args.pretrained = os.path.join(os.path.expanduser('~/.torch/models'), args.model.lower() + '.pth')
    # device
    device = torch.device('cpu')
    # WORLD_SIZE is set by torch.distributed launchers; >1 means this script
    # is one of several per-GPU evaluation processes.
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1
    if args.cuda and torch.cuda.is_available():
        # Input size is fixed, so letting cuDNN autotune conv algorithms pays off.
        cudnn.benchmark = True
        device = torch.device('cuda')
    else:
        # Distributed eval is only supported on GPU here.
        distributed = False
    if distributed:
        # Bind this process to its assigned GPU before joining the NCCL group.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method=args.init_method)
    # Load Model
    model_name = args.model
    kwargs = {'classes': 1000, 'pretrained': args.pretrained, }
    net = get_model(model_name, **kwargs)
    net.to(device)
    # testing data
    acc_top1 = Accuracy()
    acc_top5 = TopKAccuracy(5)
    val_data = get_dataloader(args, distributed)
    # testing
    acc_top1, acc_top5 = validate(net, val_data, device, acc_top1, acc_top5)
    # Barrier, then merge per-rank metrics; only rank 0 prints the result.
    synchronize()
    name1, top1 = accumulate_metric(acc_top1)
    name5, top5 = accumulate_metric(acc_top5)
    if is_main_process():
        print('%s: %f, %s: %f' % (name1, top1, name5, top5))