-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_multilabel_semantic.py
82 lines (65 loc) · 3.5 KB
/
eval_multilabel_semantic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import json
import os
import sys
from constants import DATA_DIR, REYMULTICLASSIFIER
from config_train import config as train_config
from config_eval import config as cfg_eval
import hyperparameters_multilabel
from src.training.train_utils import Logger
from src.models import get_classifier
from src.dataloaders.semantic_transforms_dataset import TF_BRIGHTNESS, TF_PERSPECTIVE, TF_CONTRAST, TF_ROTATION
from src.evaluate import SemanticMultilabelEvaluator
# ---------------------------------------------------------------------------
# Command-line interface. Arguments are parsed at import time into the
# module-level ``args`` that ``main()`` reads.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# Required: main() joins this path with 'args.json' and writes its log file
# there, so a missing value would otherwise surface as an opaque TypeError
# from os.path.join(None, ...).
parser.add_argument('--results-dir', type=str, required=True,
                    help="directory holding the trained model's args.json; "
                         'evaluation output is written here as well')
parser.add_argument('--batch-size', default=100, type=int)
parser.add_argument('--workers', default=8, type=int)
parser.add_argument('--tta', action='store_true')

# transformation to apply, plus one strength parameter per transformation kind
parser.add_argument('--transform', type=str, default=TF_ROTATION,
                    choices=[TF_BRIGHTNESS, TF_PERSPECTIVE, TF_CONTRAST, TF_ROTATION])
parser.add_argument('--angles', nargs='+', type=float, default=[0, 5],
                    help='absolute value (min, max) rotation angles')
parser.add_argument('--distortion', type=float, help='amount of distortion; ranges from 0 to 1')
parser.add_argument('--brightness', type=float, help='0 = black image, 1 = original image, 2 increases the brightness')
parser.add_argument('--contrast', type=float, help='0 = gray image, 1 = original image, 2 increases the contrast')


def _check_transform_args(cli_parser, cli_args):
    """Exit with a usage error if the selected transform's parameter is missing.

    Only ``--angles`` has a default. ``--contrast``, ``--brightness`` and
    ``--distortion`` default to None, which would otherwise be forwarded to the
    evaluator and end up in the output file names (e.g. ``contrast_None``).
    """
    option_for_transform = {
        TF_CONTRAST: 'contrast',
        TF_BRIGHTNESS: 'brightness',
        TF_PERSPECTIVE: 'distortion',
    }
    option = option_for_transform.get(cli_args.transform)
    if option is not None and getattr(cli_args, option) is None:
        # parser.error prints usage to stderr and exits with status 2
        cli_parser.error(f'--{option} is required when --transform is {cli_args.transform!r}')


args = parser.parse_args()
_check_transform_args(parser, args)
def main():
    """Evaluate a trained multilabel classifier under one semantic transform.

    Reads the training arguments saved in ``<results_dir>/args.json`` to rebuild
    the classifier and locate the matching dataset. Then it redirects stdout to a
    log file inside ``results_dir`` and runs ``SemanticMultilabelEvaluator`` with
    the transform selected on the command line (module-level ``args``).

    Raises:
        ValueError: if ``args.transform`` is not a supported transform
            (argparse ``choices`` normally prevents this).
    """
    # load the arguments the model was trained with
    with open(os.path.join(args.results_dir, 'args.json'), 'r', encoding='utf-8') as f:
        train_args = json.load(f)
    num_classes = train_args['n_classes']

    # config dicts are keyed by the image size rendered as space-separated values
    image_size_str = " ".join(str(s) for s in train_args['image_size'])
    data_dir = os.path.join(DATA_DIR, train_config['data_root'][image_size_str])
    print(f'--> evaluating model from {args.results_dir}')
    print(f'--> using data from {data_dir}')

    # read parameters from hyperparameters_multilabel.py
    hyperparams = hyperparameters_multilabel.train_params[image_size_str]

    # prefix naming the log file and saved results: "<transform>_<strength>"
    transform_params = {
        TF_ROTATION: ('rotation', args.angles),
        TF_CONTRAST: ('contrast', args.contrast),
        TF_BRIGHTNESS: ('brightness', args.brightness),
        TF_PERSPECTIVE: ('perspective', args.distortion),
    }
    if args.transform not in transform_params:
        raise ValueError(f'unsupported transform: {args.transform!r}')
    transform_name, strength = transform_params[args.transform]
    prefix = f'{transform_name}_{strength}'

    # save terminal output to a log file in results_dir (Logger replaces sys.stdout)
    log_file = "semantic_eval_out_" + prefix + ".txt"
    sys.stdout = Logger(print_fp=os.path.join(args.results_dir, log_file))

    model = get_classifier(arch=REYMULTICLASSIFIER, num_classes=num_classes)
    evaluator = SemanticMultilabelEvaluator(
        model=model,
        image_size=hyperparams['image_size'],
        results_dir=args.results_dir,
        data_dir=data_dir,
        # loader settings come from the CLI (--batch-size / --workers)
        batch_size=args.batch_size,
        workers=args.workers,
        transform=args.transform,
        rotation_angles=args.angles,
        distortion_scale=args.distortion,
        brightness_factor=args.brightness,
        contrast_factor=args.contrast,
        num_classes=num_classes,
        tta=args.tta,
        # angles used when --tta is set, taken from the eval config
        angles=cfg_eval[REYMULTICLASSIFIER]['angles'])
    evaluator.run_eval(save=True, prefix=prefix)


if __name__ == '__main__':
    main()