-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest.py
More file actions
executable file
·67 lines (55 loc) · 2.11 KB
/
test.py
File metadata and controls
executable file
·67 lines (55 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Author: Mark Gee
# Platform: keras
# Testing script for gaze tracker
from utils import ITrackerData
from utils.random_eraser import get_random_eraser
from models import ITrackerModel, ITrackerImprove, mobileIFT, SEITracker
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
import numpy as np
import argparse
# Command-line interface: which architecture to evaluate and which weights to load.
parser = argparse.ArgumentParser(description='Testing the gaze tracker')
parser.add_argument(
    '--model',
    required=True,
    # Restrict to the names the model-selection code below actually handles,
    # so a typo fails here with a usage message instead of a NameError later.
    choices=['baseline', 'improve', 'seresnet', 'semobile', 'mobileift', 'semobileift'],
    help="Model to use (baseline, improve, seresnet, semobile, mobileift, semobileift)",
)
parser.add_argument(
    '--weights',
    required=True,
    help="Path to the trained weights file to load for evaluation.",
)
args = parser.parse_args()
# Evaluation parameters
batch_size = 20  # Change if out of cuda memory

# Build the requested network. `mode` is passed to the data generator so that
# it matches the selected architecture.
if args.model == 'baseline':
    model = ITrackerModel.ITrackerModel()
    mode = 'baseline'
elif args.model == 'improve':
    model = ITrackerImprove.ITrackerModel()
    mode = 'baseline'
elif args.model == 'seresnet':
    model = SEITracker.SEITracker()
    mode = 'face'
elif args.model == 'semobile':
    model = SEITracker.SEITracker(type='mobile')
    mode = 'face'
elif args.model == 'mobileift':
    model = mobileIFT.MobileIFTracker()
    mode = 'landmarks'
elif args.model == 'semobileift':
    model = mobileIFT.MobileIFTracker(use_se=True)
    mode = 'landmarks'
else:
    # Without this branch an unrecognised name would leave `model` and `mode`
    # undefined and fail below with an unhelpful NameError.
    parser.error(
        "unknown --model %r; expected one of: baseline, improve, seresnet, "
        "semobile, mobileift, semobileift" % args.model
    )

model.summary()
# Create the ITrackerData generator that feeds the evaluation.
# `split` is set to 'val' here; switch it to the test split when needed.
val_data_generator = ITrackerData.ITrackerData(
    batch_size,
    imSize=(224, 224),
    split='val',
    mode=mode,
)
# An optimizer is only needed so the model can be compiled. Evaluation does not
# update weights, so the default-configured Adam is sufficient. (The previous
# `Adam(lr=lr)` referenced an `lr` that this script never defines.)
optimizer = Adam()

# MobileIFTracker variants have two named outputs ("gaze" and "lms"), so they
# are compiled with one loss per output and explicit loss weights.
if args.model in ('mobileift', 'semobileift'):
    losses = {
        "gaze": "mse",
        "lms": "mse",
    }
    lossWeights = {"gaze": 1.0, "lms": 1.0}
    model.compile(optimizer=optimizer, loss=losses, loss_weights=lossWeights)
else:
    model.compile(optimizer=optimizer, loss='mse')

# Restore the trained parameters given on the command line.
model.load_weights(args.weights)
# Run the compiled model over the evaluation generator and report what it returns.
results = model.evaluate_generator(
    val_data_generator,
    verbose=1,
)
print('Results: \n: ', results)