forked from gitabcworld/MatchingNetworks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mainMiniImageNet.py
92 lines (78 loc) · 4.24 KB
/
mainMiniImageNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Albert Berenguel
## Computer Vision Center (CVC). Universitat Autonoma de Barcelona
## Email: aberenguel@cvc.uab.es
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
from datasets import miniImagenetOneShot
from option import Options
from experiments.OneShotMiniImageNetBuilder import miniImageNetBuilder
import tqdm
from logger import Logger
# ---------------------------------------------------------------------------
# Experiment setup
#
# batch_size        -- episodes per batch; each split below is sized to
#                      (number of batches) * batch_size episodes
# fce               -- boolean flag forwarded to build_experiment (presumably
#                      Full Context Embeddings -- confirm in miniImageNetBuilder)
# classes_per_set   -- number of classes in each episode (the "N" in N-way)
# samples_per_class -- samples per class in each episode (the "K" in K-shot)
# channels          -- forwarded to build_experiment; presumably the number of
#                      image channels (3 = RGB) -- TODO confirm
#
# Examples: a 20-way 1-shot task uses classes_per_set=20, samples_per_class=1;
#           a 5-way 10-shot task uses classes_per_set=5, samples_per_class=10.
# ---------------------------------------------------------------------------
batch_size, fce = 10, True
classes_per_set, samples_per_class = 5, 5
channels = 3

# Training schedule: number of epochs, plus the batch count for each split
# (used further down to size each split's episodic dataset).
total_epochs = 500
total_train_batches, total_val_batches, total_test_batches = 100, 100, 250

# Remaining options (at least dataroot and log_dir are used in this script).
args = Options().parse()

# One log directory per configuration, named after the settings above.
LOG_DIR = args.log_dir + '/' + (
    'miniImageNetOneShot_run-batchSize_{}-fce_{}'
    '-classes_per_set{}-samples_per_class{}-channels{}'
).format(batch_size, fce, classes_per_set, samples_per_class, channels)
# Metric logger writing under the run-specific LOG_DIR.
logger = Logger(LOG_DIR)


def _episodic_split(split, n_batches):
    """Create the miniImageNet episodic dataset for one split.

    split     -- split name forwarded as the dataset's ``type`` argument
                 ('train', 'val' or 'test')
    n_batches -- batch count for this split; the dataset is asked for
                 ``n_batches * batch_size`` episodes
    """
    return miniImagenetOneShot.miniImagenetOneShotDataset(
        dataroot=args.dataroot,
        type=split,
        nEpisodes=n_batches * batch_size,
        classes_per_set=classes_per_set,
        samples_per_class=samples_per_class)


# Built in the same order as before: train, then val, then test.
dataTrain = _episodic_split('train', total_train_batches)
dataVal = _episodic_split('val', total_val_batches)
dataTest = _episodic_split('test', total_test_batches)

# Wire the three splits into the experiment builder and construct the model.
obj_oneShotBuilder = miniImageNetBuilder(dataTrain, dataVal, dataTest)
obj_oneShotBuilder.build_experiment(batch_size, classes_per_set,
                                    samples_per_class, channels, fce)
# Main training loop: one training and one validation epoch per iteration.
# Test statistics are produced only when validation accuracy ties or beats
# the best validation accuracy seen so far (tracked in best_val).
best_val = 0.
with tqdm.tqdm(total=total_epochs) as pbar_e:
    for e in range(total_epochs):
        total_c_loss, total_accuracy = obj_oneShotBuilder.run_training_epoch()
        # tqdm.write (instead of print) clears and redraws the progress bar,
        # so per-epoch messages do not garble it.
        tqdm.tqdm.write("Epoch {}: train_loss: {}, train_accuracy: {}".format(
            e, total_c_loss, total_accuracy))

        total_val_c_loss, total_val_accuracy = obj_oneShotBuilder.run_validation_epoch()
        tqdm.tqdm.write("Epoch {}: val_loss: {}, val_accuracy: {}".format(
            e, total_val_c_loss, total_val_accuracy))

        logger.log_value('train_loss', total_c_loss)
        logger.log_value('train_acc', total_accuracy)
        logger.log_value('val_loss', total_val_c_loss)
        logger.log_value('val_acc', total_val_accuracy)

        # New best (or tied) validation accuracy -> produce test statistics.
        # Note: the comparison is >=, so the very first epoch always runs the
        # test epoch (best_val starts at 0.).
        if total_val_accuracy >= best_val:
            best_val = total_val_accuracy
            total_test_c_loss, total_test_accuracy = obj_oneShotBuilder.run_testing_epoch()
            tqdm.tqdm.write("Epoch {}: test_loss: {}, test_accuracy: {}".format(
                e, total_test_c_loss, total_test_accuracy))
            logger.log_value('test_loss', total_test_c_loss)
            logger.log_value('test_acc', total_test_accuracy)

        # Advance the progress bar and the logger's step once per epoch.
        pbar_e.update(1)
        logger.step()