-
Notifications
You must be signed in to change notification settings - Fork 62
/
Copy patheval_pretrained.py
executable file
·232 lines (184 loc) · 8.18 KB
/
eval_pretrained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import argparse
import json
import logging
from typing import Tuple
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from eval_tools import Metrics, time_sync, write_results
from mivolo.data.dataset import build as build_data
from mivolo.model.mi_volo import MiVOLO
from timm.utils import setup_default_logging
# Progress is logged from validate() once every LOG_FREQUENCY batches.
LOG_FREQUENCY = 10
# Module logger, registered under the name "inference".
_logger = logging.getLogger(name="inference")
def get_parser():
    """Build the command-line parser for evaluating a pretrained MiVOLO checkpoint.

    ``--dataset_images``, ``--dataset_annotations``, ``--dataset_name`` and
    ``--checkpoint`` are mandatory; every other option has a default.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    cli = argparse.ArgumentParser(description="PyTorch MiVOLO Validation")
    add = cli.add_argument

    # --- dataset selection ---
    add("--dataset_images", type=str, default="", required=True, help="path to images")
    add("--dataset_annotations", type=str, default="", required=True, help="path to annotations")
    add(
        "--dataset_name",
        type=str,
        default=None,
        required=True,
        choices=["utk", "imdb", "lagenda", "fairface", "adience", "agedb", "cacd"],
        help="dataset name",
    )
    add("--split", default="validation", help="dataset splits separated by comma (default: validation)")

    # --- model and runtime ---
    add("--checkpoint", type=str, default="", required=True, help="path to mivolo checkpoint")
    add("--batch-size", type=int, default=64, help="batch size")
    add("--workers", type=int, default=4, metavar="N", help="number of data loading workers (default: 4)")
    add("--device", type=str, default="cuda", help="Device (accelerator) to use.")
    add("--l-for-cs", type=int, default=5, help="L for CS (cumulative score)")
    add("--half", action="store_true", default=False, help="use half-precision model")
    add(
        "--with-persons",
        action="store_true",
        default=False,
        help="If the model will run with persons, if available",
    )
    add(
        "--disable-faces",
        action="store_true",
        default=False,
        help="If the model will use only persons if available",
    )
    add("--draw-hist", action="store_true", help="Draws the hist of error by age")

    # --- reporting ---
    add(
        "--results-file",
        type=str,
        default="",
        metavar="FILENAME",
        help="Output csv file for validation results (summary)",
    )
    add(
        "--results-format",
        type=str,
        default="csv",
        help="Format for results file one of (csv, json) (default: csv).",
    )
    return cli
def process_batch(
    mivolo_model: MiVOLO,
    input: torch.tensor,
    target: torch.tensor,
    num_classes_gender: int = 2,
):
    """Run the model on one batch and split its output into age and gender parts.

    Args:
        mivolo_model: loaded MiVOLO wrapper. ``mivolo_model.meta.only_age``
            decides whether the output also contains gender columns.
        input: batch passed unchanged to ``mivolo_model.inference``.
        target: ground truth with age in column 0 and gender in column 1.
            A value of ``-1`` marks that label as not valid for the sample.
        num_classes_gender: number of leading output columns holding the
            gender prediction (only used when the model is not age-only).

    Returns:
        Tuple ``(age_out, age_target, gender_out, gender_target, process_time)``.
        ``age_target`` is ``target[:, 0]`` with an extra trailing dimension.
        ``gender_out``/``gender_target`` are ``None`` for age-only models.
        ``process_time`` is the ``time_sync()`` delta spanning inference and
        output slicing.

    Raises:
        ValueError: if every sample in the batch has both age and gender set to -1.
    """
    # The check runs before the clock starts so it is not counted as inference time.
    # Tensor reductions avoid the builtin ``all`` (an element-by-element Python loop).
    # A real exception survives ``python -O``, unlike ``assert``.
    age_invalid = target[:, 0] == -1
    gender_invalid = target[:, 1] == -1
    if bool(age_invalid.all()) and bool(gender_invalid.all()):
        raise ValueError("invalid batch: every sample has age == -1 and gender == -1")

    start = time_sync()
    output = mivolo_model.inference(input)

    if not mivolo_model.meta.only_age:
        # The first `num_classes_gender` columns are the gender head; the remainder is age.
        gender_out = output[:, :num_classes_gender]
        gender_target = target[:, 1]
        age_out = output[:, num_classes_gender:]
    else:
        age_out = output
        gender_out, gender_target = None, None

    # measure elapsed time
    process_time = time_sync() - start

    # Presumably kept 2-D so it lines up with the age output's column layout; verify with Metrics.
    age_target = target[:, 0].unsqueeze(1)

    return age_out, age_target, gender_out, gender_target, process_time
def _filter_invalid_target(out: torch.tensor, target: torch.tensor):
    """Drop samples whose ground truth is the ``-1`` "not valid" marker.

    Args:
        out: predictions aligned with ``target`` along the leading dimensions.
        target: ground-truth values; ``-1`` marks a missing/invalid label.

    Returns:
        ``(out[keep], target[keep])`` where ``keep`` is the boolean mask ``target != -1``.
    """
    keep = target.ne(-1)
    return out[keep], target[keep]
def postprocess_gender(gender_out: torch.tensor, gender_target: torch.tensor) -> Tuple[torch.tensor, torch.tensor]:
    """Remove samples with an invalid (``-1``) gender label from predictions and targets.

    When ``gender_target`` is ``None`` (``process_batch`` produces ``None`` for
    age-only models), both arguments are returned unchanged.

    Returns:
        ``(gender_out, gender_target)`` restricted to valid samples.
    """
    if gender_target is not None:
        return _filter_invalid_target(gender_out, gender_target)
    return gender_out, gender_target
def postprocess_age(age_out: torch.tensor, age_target: torch.tensor, dataset) -> Tuple[torch.tensor, torch.tensor]:
    """Map raw age predictions (and, for regression, targets) back to the dataset's label space.

    First, samples whose target is ``-1`` are dropped. Next, predictions are
    de-normalized with ``x * (max_age - min_age) + avg_age``. This reverts the
    dataset's ``_norm_age()``. Finally, predictions are clamped so they are never
    negative.

    * Classification (``dataset.age_classes`` is set): predictions are rounded
      and converted to interval indices with ``torch.searchsorted`` over
      ``dataset._intervals``. Targets are returned as they are.
    * Regression: targets are de-normalized with the same formula.

    Side effect: ``dataset._intervals`` is reassigned to a copy on the
    prediction's device whenever the two devices differ.

    Returns:
        ``(age_out, age_target)``.
    """
    age_out, age_target = _filter_invalid_target(age_out, age_target)

    age_span = dataset.max_age - dataset.min_age
    # Revert _norm_age(); raw outputs can map below zero, so floor them at 0.
    age_out = torch.clamp(age_out * age_span + dataset.avg_age, min=0)

    if dataset.age_classes is None:
        # Regression: targets were normalized the same way, so de-normalize them too.
        return age_out, age_target * age_span + dataset.avg_age

    # Classification: bucket the rounded age into the dataset's class intervals.
    if dataset._intervals.device != age_out.device:
        dataset._intervals = dataset._intervals.to(age_out.device)
    class_idx = torch.searchsorted(dataset._intervals, torch.round(age_out), side="right") - 1
    return class_idx, age_target
def validate(args):
    """Evaluate a MiVOLO checkpoint on one dataset split and collect the metrics.

    Args:
        args: namespace produced by ``get_parser().parse_args()``.

    Returns:
        dict: the run description comes first:

        * checkpoint path
        * dataset name
        * parameter count in millions, rounded to 2 decimals
        * input size
        * face/person usage
        * input channels
        * batch size

        It is followed by the entries returned by ``Metrics.get_result()``.
    """
    if torch.cuda.is_available():
        # Speed-oriented CUDA settings: TF32 matmuls and cuDNN autotuning.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    model = MiVOLO(
        args.checkpoint,
        args.device,
        half=args.half,
        use_persons=args.with_persons,
        disable_faces=args.disable_faces,
        verbose=True,
    )
    dataset, loader = build_data(
        name=args.dataset_name,
        images_path=args.dataset_images,
        annotations_path=args.dataset_annotations,
        split=args.split,
        mivolo_model=model,  # the data builder reads meta information from the model
        workers=args.workers,
        batch_size=args.batch_size,
    )
    metrics = Metrics(args.l_for_cs, args.draw_hist, dataset.age_classes)

    # Warm up so the first batch does not skew timing (notably torchscript vs. eager comparisons).
    model.warmup(args.batch_size)

    batch_start = time_sync()
    for step, (images, target) in enumerate(loader):
        # Time spent waiting on the loader since the previous batch finished.
        preprocess_time = time_sync() - batch_start
        n_samples = images.shape[0]

        age_out, age_target, gender_out, gender_target, process_time = process_batch(
            model, images, target, dataset.num_classes_gender
        )
        gender_out, gender_target = postprocess_gender(gender_out, gender_target)
        age_out, age_target = postprocess_age(age_out, age_target, dataset)

        metrics.update_gender_accuracy(gender_out, gender_target)
        update_age = metrics.update_regression_age_metrics if metrics.is_regression else metrics.update_age_accuracy
        update_age(age_out, age_target)
        metrics.update_time(process_time, preprocess_time, n_samples)

        if step % LOG_FREQUENCY == 0:
            _logger.info(
                "Test: [{0:>4d}/{1}] " "{2}".format(step, len(loader), metrics.get_info_str(n_samples))
            )
        batch_start = time_sync()

    results = {
        "model": args.checkpoint,
        "dataset_name": args.dataset_name,
        "param_count": round(model.param_count / 1e6, 2),
        "img_size": model.input_size,
        "use_faces": model.meta.use_face_crops,
        "use_persons": model.meta.use_persons,
        "in_chans": model.meta.in_chans,
        "batch": args.batch_size,
    }
    results.update(metrics.get_result())
    return results
def main():
    """Command-line entry point.

    Parses arguments and runs ``validate()``, then logs a one-line summary.
    Optionally saves an age-vs-error scatter plot and writes the results file.
    Finally, the full results are printed as JSON after a ``--result`` delimiter.
    """
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args()

    # validate() itself enables TF32 / cuDNN benchmarking when CUDA is available.
    results = validate(args)

    summary = " * Age Acc@1 {:.3f} ({:.3f})".format(results["agetop1"], results["agetop1_err"])
    if "gendertop1" in results:
        # Fixed typo: the label previously read "Gender Acc@1 1".
        summary += " Gender Acc@1 {:.3f} ({:.3f})".format(results["gendertop1"], results["gendertop1_err"])
    summary += " Mean inference time {:.3f} ms Mean preprocessing time {:.3f}".format(
        results["mean_inference_time"], results["mean_preprocessing_time"]
    )
    _logger.info(summary)

    if args.draw_hist and "per_age_error" in results:
        # One pass over the per-age error lists: x = age, y = mean error for that age.
        ages, err = [], []
        for age, errors in results["per_age_error"].items():
            ages.append(age)
            err.append(sum(errors) / len(errors))
        sns.scatterplot(x=ages, y=err, hue=err)
        plt.legend([], [], frameon=False)
        plt.xlabel("Age")
        plt.ylabel("MAE")
        plt.savefig("age_error.png", dpi=300)

    if args.results_file:
        write_results(args.results_file, results, format=args.results_format)

    # output results in JSON to stdout w/ delimiter for runner script
    print(f"--result\n{json.dumps(results, indent=4)}")


if __name__ == "__main__":
    main()