-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
151 lines (108 loc) · 4.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import tensorflow as tf
import numpy as np
import pickle
import os
import time
from sklearn.utils import shuffle
from model import siamese_network
# Build the network from the project's `model` module; (105, 105, 1) is
# presumably the input image shape (height, width, channels) -- confirm
# against siamese_network's signature.
model = siamese_network((105, 105, 1))
# NOTE(review): assumes `model` is a Keras model (it is later driven through
# train_on_batch / predict / save_weights). Keras' summary() prints the table
# itself and returns None, so it must not be wrapped in print(), which would
# append a stray "None" line to the output.
model.summary()

save_path = './data/'

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted preprocessing step.
with open(os.path.join(save_path, "train.pickle"), "rb") as f:
    # The pickle holds a 2-tuple: the image array and a dict keyed by alphabet.
    (Xtrain, train_classes) = pickle.load(f)

print("Training alphabets: \n")
print(list(train_classes.keys()))

with open(os.path.join(save_path, "val.pickle"), "rb") as f:
    # Same layout as the training pickle, for the validation split.
    (Xval, val_classes) = pickle.load(f)

print("Validation alphabets:", end="\n\n")
print(list(val_classes.keys()))
def get_batch(batch_size, s="train"):
    """Sample a batch of image pairs with same/different-class targets.

    The first half of the batch pairs an image with an image from a
    *different* class (target 0); the second half pairs it with an image
    from the *same* class (target 1).

    Args:
        batch_size: Number of pairs to build. The anchor classes are drawn
            without replacement, so this must not exceed the number of
            classes in the selected split.
        s: ``"train"`` samples from the module-level ``Xtrain``; any other
            value samples from ``Xval``.

    Returns:
        A tuple ``(pairs, targets)``. ``pairs`` is a list of two float arrays
        of shape ``(batch_size, rows, cols, 1)``, where ``rows`` and ``cols``
        are the last two axes of the data array. ``targets`` is an array of
        shape ``(batch_size,)`` holding 0.0 for different-class pairs and
        1.0 for same-class pairs.

    Raises:
        ValueError: Propagated from NumPy when ``batch_size`` is larger than
            the number of classes, or when the split has a single class (no
            "different" class can be drawn).
    """
    X = Xtrain if s == 'train' else Xval
    n_classes, n_examples, rows, cols = X.shape

    # One distinct anchor class per pair.
    anchor_classes = np.random.choice(n_classes, size=(batch_size,), replace=False)

    # Allocate the buffers with the same axis order as the stored images so
    # that non-square images fit as well.
    pairs = [np.zeros((batch_size, rows, cols, 1)) for _ in range(2)]

    half = batch_size // 2
    targets = np.zeros((batch_size,))
    targets[half:] = 1

    for i, category in enumerate(anchor_classes):
        idx_1 = np.random.randint(0, n_examples)
        pairs[0][i, :, :, :] = X[category, idx_1].reshape(rows, cols, 1)
        idx_2 = np.random.randint(0, n_examples)
        if i >= half:
            # Second half of the batch: partner comes from the same class.
            category_2 = category
        else:
            # A non-zero offset taken modulo n_classes always lands on a
            # different class than the anchor.
            category_2 = (category + np.random.randint(1, n_classes)) % n_classes
        pairs[1][i, :, :, :] = X[category_2, idx_2].reshape(rows, cols, 1)

    return pairs, targets
def generate(batch_size, s="train"):
    """Yield ``(pairs, targets)`` batches forever.

    Every iteration draws a fresh batch through
    ``get_batch(batch_size, s)``; the generator never terminates on its own.
    """
    while True:
        batch_pairs, batch_targets = get_batch(batch_size, s)
        yield batch_pairs, batch_targets
def make_oneshot_task(N, s="val", language=None):
    """Build one N-way one-shot task: a test image against N support images.

    N distinct classes are drawn; the first one is the "true" class. The test
    image (repeated N times) is one example of the true class, and the
    support set holds one example per drawn class, with the true class
    represented by a *different* example than the test image. All three
    outputs are then shuffled consistently.

    Args:
        N: Number of classes (and support images) in the task; at least 1.
        s: ``"train"`` uses ``Xtrain`` / ``train_classes``; any other value
            uses ``Xval`` / ``val_classes``.
        language: Optional key into the class dictionary. Its entry is
            unpacked as ``(low, high)`` and classes are drawn from
            ``range(low, high)``. When ``None``, classes are drawn from the
            whole split.

    Returns:
        A tuple ``(pairs, targets)``. ``pairs`` is ``[test_image,
        support_set]``, each of shape ``(N, w, h, 1)``. ``targets`` has shape
        ``(N,)`` and contains a single 1 at the position of the support image
        that shares the test image's class.

    Raises:
        ValueError: If ``N`` is smaller than 1, larger than the number of
            available classes, or larger than ``high - low`` for the given
            ``language``. NumPy also raises it when the split has fewer than
            two examples per class.
    """
    if N < 1:
        raise ValueError("N must be at least 1, got {}".format(N))

    if s == 'train':
        X, class_ranges = Xtrain, train_classes
    else:
        X, class_ranges = Xval, val_classes
    n_classes, n_examples, w, h = X.shape

    # One random example index per support class (slot 0 is replaced below).
    indices = np.random.randint(0, n_examples, size=(N,))

    if language is not None:  # restrict the task to one language's characters
        low, high = class_ranges[language]
        if N > high - low:
            raise ValueError("This language ({}) has less than {} letters".format(language, N))
        categories = np.random.choice(range(low, high), size=(N,), replace=False)
    else:
        if N > n_classes:
            raise ValueError(
                "Cannot build a {}-way task from only {} classes".format(N, n_classes)
            )
        categories = np.random.choice(range(n_classes), size=(N,), replace=False)

    true_category = categories[0]
    # Two distinct examples of the true class: ex1 is the query image, ex2 is
    # its match inside the support set.
    ex1, ex2 = np.random.choice(n_examples, replace=False, size=(2,))

    # Repeat the query image N times so it lines up with the support set.
    query = X[true_category, ex1, :, :]
    test_image = np.repeat(query[np.newaxis, :, :, np.newaxis], N, axis=0)

    # Index arrays make this a copy, so writing into slot 0 leaves X untouched.
    support_set = X[categories, indices, :, :]
    support_set[0, :, :] = X[true_category, ex2]
    support_set = support_set.reshape(N, w, h, 1)

    targets = np.zeros((N,))
    targets[0] = 1

    # Shuffle all three with the same permutation so the match is not always first.
    targets, test_image, support_set = shuffle(targets, test_image, support_set)
    pairs = [test_image, support_set]
    return pairs, targets
def test_oneshot(model, N, k, s = "val", verbose = 0):
    """Measure N-way one-shot accuracy of a siamese net, averaged over k random tasks.

    For each task a fresh ``make_oneshot_task(N, s)`` is generated and scored
    with ``model.predict``; the task counts as solved when the highest-scoring
    pair is the one flagged in the task's targets.

    Args:
        model: Object exposing ``predict(inputs)``.
        N: Number of classes per one-shot task.
        k: Number of tasks to evaluate; ``k == 0`` raises ZeroDivisionError.
        s: Data split forwarded to ``make_oneshot_task``.
        verbose: When truthy, print a header and the final accuracy.

    Returns:
        The percentage (0-100, float) of tasks that were solved.
    """
    if verbose:
        print("Evaluating model on {} random {} way one-shot learning tasks ... \n".format(k,N))

    hits = 0
    for _ in range(k):
        task_inputs, task_targets = make_oneshot_task(N, s)
        scores = model.predict(task_inputs)
        predicted_slot = np.argmax(scores)
        expected_slot = np.argmax(task_targets)
        if predicted_slot == expected_slot:
            hits += 1

    percent_correct = 100.0 * hits / k
    if verbose:
        print("Got an average of {}% {} way one-shot learning accuracy \n".format(percent_correct,N))
    return percent_correct
# Training schedule and evaluation settings.
evaluate_every = 200  # validate and save weights every this many iterations
batch_size = 32  # pairs per training batch
n_iter = 20000  # total number of training iterations
N_way = 20  # classes per one-shot validation task
n_val = 250  # one-shot tasks per validation run
best = -1  # best validation accuracy so far; -1 so the first evaluation registers
model_path = './weights/'

# Weight files are written into model_path during training; create the
# directory up front so the first save cannot fail on a missing folder after
# training time has already been spent.
os.makedirs(model_path, exist_ok=True)

print("Starting training process!")
print("-------------------------------------")
t_start = time.time()
for i in range(1, n_iter+1):
    (inputs,targets) = get_batch(batch_size)
    loss = model.train_on_batch(inputs, targets)
    if i % evaluate_every == 0:
        print("\n ------------- \n")
        print("Time for {0} iterations: {1} mins".format(i, (time.time()-t_start)/60.0))
        print("Train Loss: {0}".format(loss))
        val_acc = test_oneshot(model, N_way, n_val, verbose=True)
        # Weights are saved at every evaluation, not only on improvement.
        model.save_weights(os.path.join(model_path, 'weights.{}.h5'.format(i)))
        if val_acc >= best:
            print("Current best: {0}, previous best: {1}".format(val_acc, best))
            best = val_acc