-
Notifications
You must be signed in to change notification settings - Fork 179
/
eval.py
98 lines (82 loc) · 3.27 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from semilearn.core.utils import get_net_builder, get_dataset
def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``, so boolean flags are parsed through this
    helper instead.

    Args:
        value: The raw string given on the command line (a bool is returned
            unchanged).

    Returns:
        ``True`` for ``true/t/yes/y/1`` and ``False`` for ``false/f/no/n/0``
        (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    import argparse

    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    Only the exact ``module.`` prefix is stripped, so a parameter legitimately
    named e.g. ``module_list.0.weight`` is left intact. (The prefix presumably
    comes from saving a DataParallel/DDP-wrapped model -- confirm against the
    training code.)

    Args:
        state_dict: Mapping of parameter name -> tensor, as read from the
            checkpoint.

    Returns:
        A new ``dict`` with the same values and de-prefixed keys.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): item
        for key, item in state_dict.items()
    }


if __name__ == "__main__":
    # Evaluate the 'ema_model' weights of a checkpoint on the dataset's
    # 'eval' split and print top-1 accuracy.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--load_path', type=str, required=True)

    # Backbone net configuration
    parser.add_argument('--net', type=str, default='wrn_28_2')
    parser.add_argument('--net_from_name', type=_str2bool, default=False)

    # Data configuration
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--img_size', type=int, default=32)
    # crop_ratio is a fraction, so it must be parsed as a float.
    parser.add_argument('--crop_ratio', type=float, default=0.875)
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--max_length_seconds', type=float, default=4.0)
    parser.add_argument('--sample_rate', type=int, default=16000)
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()

    # One device for both the network and the inputs, so the script also
    # runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    checkpoint_path = args.load_path
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. map_location='cpu' allows a GPU-saved checkpoint to be
    # read without CUDA; the network is moved to `device` after loading.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    load_state_dict = _strip_module_prefix(checkpoint['ema_model'])

    args.save_dir = os.path.dirname(checkpoint_path)
    args.save_name = ''

    net = get_net_builder(args.net, args.net_from_name)(num_classes=args.num_classes)
    # Default strict loading raises if any key is missing or unexpected.
    net.load_state_dict(load_state_dict)
    net.to(device)
    net.eval()

    # Specify these arguments manually; they are set on `args` before
    # get_dataset is called (presumably it reads them -- they should match
    # the run that produced the checkpoint).
    args.num_labels = 40
    args.ulb_num_labels = 49600
    args.lb_imb_ratio = 1
    args.ulb_imb_ratio = 1
    args.seed = 0
    args.epoch = 1
    args.num_train_iter = 1024

    dataset_dict = get_dataset(args, 'fixmatch', args.dataset, args.num_labels,
                               args.num_classes, args.data_dir, False)
    eval_dset = dataset_dict['eval']
    eval_loader = DataLoader(eval_dset, batch_size=args.batch_size, drop_last=False,
                             shuffle=False, num_workers=args.num_workers)

    acc = 0.0
    # Per-sample outputs are collected but not used below; presumably kept
    # for inspection after the run (e.g. under `python -i`).
    test_feats = []
    test_preds = []
    test_probs = []
    test_labels = []
    with torch.no_grad():
        for data in eval_loader:
            image = data['x_lb'].float().to(device)
            target = data['y_lb']

            # Features and classifier head are evaluated separately so the
            # penultimate features can be collected as well.
            feat = net(image, only_feat=True)
            logit = net(feat, only_fc=True)
            prob = logit.softmax(dim=-1)
            pred = prob.argmax(1)

            pred_cpu = pred.cpu()
            acc += pred_cpu.eq(target).numpy().sum()
            test_feats.append(feat.cpu().numpy())
            test_preds.append(pred_cpu.numpy())
            test_probs.append(prob.cpu().numpy())
            test_labels.append(target.cpu().numpy())

    if not test_labels:
        raise SystemExit("Evaluation set is empty; nothing to evaluate.")

    test_feats = np.concatenate(test_feats)
    test_preds = np.concatenate(test_preds)
    test_probs = np.concatenate(test_probs)
    test_labels = np.concatenate(test_labels)
    print(f"Test Accuracy: {acc/len(eval_dset)}")