-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path02-train.py
149 lines (116 loc) · 5.96 KB
/
02-train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Two training modes are available:
# - BCE: CNN training, NOT adversarial training. Only the generator network is learned.
# - SALGAN: adversarial training. Updates the weights of both the generator and the discriminator.
# Training uses data previously processed with "01-data_preprocessing.py".
import sys
import cPickle as pickle
import random
from tqdm import tqdm
from constants import *
from models.model_salgan import ModelSALGAN
from models.model_bce import ModelBCE
from utils import *
# Training mode selected by the first command-line argument ('salgan' or 'bce').
# Falls back to '' when no argument is given, so importing this module or running it
# without arguments no longer raises IndexError; train() reports any unknown mode.
flag = str(sys.argv[1]) if len(sys.argv) > 1 else ''
def bce_batch_iterator(model, train_data, validation_sample):
    """
    Train only the generator (BCE mode, no adversarial updates).

    Runs 301 epochs. Each epoch shuffles ``train_data`` in place, passes every full
    mini-batch to ``model.G_trainFunction`` and prints the mean cost per batch.
    On every 5th epoch (0, 5, 10, ...) the generator weights are saved under
    ``DIR_TO_SAVE`` and ``predict`` is run on ``validation_sample``.

    :param model: object providing ``batch_size``, ``G_trainFunction`` and ``net['output']``.
    :param train_data: list of samples exposing ``image.data`` and ``saliency.data``;
        it is shuffled in place.
    :param validation_sample: image forwarded to ``predict`` as ``image_stimuli``.
    :raises ValueError: if ``train_data`` holds fewer samples than one batch.
    :return: None
    """
    num_epochs = 301

    # Only full batches are trained on, so this is the divisor for the epoch average.
    nr_batches_train = len(train_data) // model.batch_size
    if nr_batches_train == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when averaging.
        raise ValueError('train_data has %d sample(s), fewer than one batch of %s'
                         % (len(train_data), model.batch_size))

    for current_epoch in tqdm(range(num_epochs), ncols=20):
        e_cost = 0.
        random.shuffle(train_data)

        for batch in chunks(train_data, model.batch_size):
            # Skip any chunk that is not exactly one batch long.
            if len(batch) != model.batch_size:
                continue

            # Axes are reordered with transpose(2, 0, 1) - presumably HWC -> CHW; confirm against the model.
            batch_input = np.asarray(
                [sample.image.data.astype(theano.config.floatX).transpose(2, 0, 1) for sample in batch],
                dtype=theano.config.floatX)

            # Saliency maps are divided by 255 and given a singleton axis at position 1.
            batch_output = np.asarray(
                [sample.saliency.data.astype(theano.config.floatX) / 255. for sample in batch],
                dtype=theano.config.floatX)
            batch_output = np.expand_dims(batch_output, axis=1)

            # BCE mode: every batch updates the generator; there is no discriminator step.
            e_cost += model.G_trainFunction(batch_input, batch_output)

        e_cost /= nr_batches_train

        # Single %-formatted argument: same output as the Python 2 print statement,
        # and also valid syntax on Python 3.
        print('Epoch: %s  train_loss-> %s' % (current_epoch, e_cost))

        if current_epoch % 5 == 0:
            np.savez('./' + DIR_TO_SAVE + '/gen_modelWeights{:04d}.npz'.format(current_epoch),
                     *lasagne.layers.get_all_param_values(model.net['output']))
            predict(model=model, image_stimuli=validation_sample, num_epoch=current_epoch,
                    path_output_maps=DIR_TO_SAVE)
def salgan_batch_iterator(model, train_data, validation_sample):
    """
    Adversarial (SALGAN) training: alternate generator and discriminator updates.

    Runs 301 epochs. Each epoch shuffles ``train_data`` in place and, for every full
    mini-batch, calls either ``model.G_trainFunction`` or ``model.D_trainFunction``,
    switching on every batch. Both functions are expected to return a
    ``(G_obj, D_obj, G_cost)`` triple; the three values are averaged per epoch and printed.
    On every 3rd epoch (0, 3, 6, ...) generator and discriminator weights are saved
    under ``DIR_TO_SAVE`` and ``predict`` is run on ``validation_sample``.

    :param model: object providing ``batch_size``, ``G_trainFunction``, ``D_trainFunction``,
        ``net['output']`` and ``discriminator['prob']``.
    :param train_data: list of samples exposing ``image.data`` and ``saliency.data``;
        it is shuffled in place.
    :param validation_sample: image forwarded to ``predict`` as ``image_stimuli``.
    :raises ValueError: if ``train_data`` holds fewer samples than one batch.
    :return: None
    """
    num_epochs = 301

    # Only full batches are trained on, so this is the divisor for the epoch averages.
    nr_batches_train = len(train_data) // model.batch_size
    if nr_batches_train == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when averaging.
        raise ValueError('train_data has %d sample(s), fewer than one batch of %s'
                         % (len(train_data), model.batch_size))

    # Counts processed batches across ALL epochs; its parity selects the network to update.
    n_updates = 1

    for current_epoch in tqdm(range(num_epochs), ncols=20):
        g_cost = 0.
        d_cost = 0.
        e_cost = 0.
        random.shuffle(train_data)

        for batch in chunks(train_data, model.batch_size):
            # Skip any chunk that is not exactly one batch long.
            if len(batch) != model.batch_size:
                continue

            # Axes are reordered with transpose(2, 0, 1) - presumably HWC -> CHW; confirm against the model.
            batch_input = np.asarray(
                [sample.image.data.astype(theano.config.floatX).transpose(2, 0, 1) for sample in batch],
                dtype=theano.config.floatX)

            # Saliency maps are divided by 255 and given a singleton axis at position 1.
            batch_output = np.asarray(
                [sample.saliency.data.astype(theano.config.floatX) / 255. for sample in batch],
                dtype=theano.config.floatX)
            batch_output = np.expand_dims(batch_output, axis=1)

            # Even counter -> generator step, odd counter -> discriminator step.
            # The counter starts at 1, so the very first batch updates the discriminator.
            if n_updates % 2 == 0:
                train_step = model.G_trainFunction
            else:
                train_step = model.D_trainFunction

            G_obj, D_obj, G_cost = train_step(batch_input, batch_output)
            d_cost += D_obj
            g_cost += G_obj
            e_cost += G_cost

            n_updates += 1

        g_cost /= nr_batches_train
        d_cost /= nr_batches_train
        e_cost /= nr_batches_train

        # Save weights and a validation prediction every 3 epochs.
        if current_epoch % 3 == 0:
            np.savez('./' + DIR_TO_SAVE + '/gen_modelWeights{:04d}.npz'.format(current_epoch),
                     *lasagne.layers.get_all_param_values(model.net['output']))
            np.savez('./' + DIR_TO_SAVE + '/disrim_modelWeights{:04d}.npz'.format(current_epoch),
                     *lasagne.layers.get_all_param_values(model.discriminator['prob']))
            # Keyword names match the predict() call in bce_batch_iterator
            # (num_epoch / path_output_maps rather than numEpoch / pathOutputMaps).
            predict(model=model, image_stimuli=validation_sample, num_epoch=current_epoch,
                    path_output_maps=DIR_TO_SAVE)

        # Single %-formatted argument: same output as the Python 2 print statement,
        # and also valid syntax on Python 3.
        print('Epoch: %s  train_loss-> %s' % (current_epoch, (g_cost, d_cost, e_cost)))
def train():
    """
    Load the pickled data sets and train the model selected by the module-level ``flag``.

    ``'salgan'`` runs adversarial training of generator and discriminator;
    ``'bce'`` trains the generator only. Any other value prints
    "Invalid input argument." and returns before any data is loaded.

    One randomly chosen validation sample is written to ``DIR_TO_SAVE`` (ground-truth
    saliency and image) and its image is used to monitor training.
    :return: None
    """
    # Validate the mode first so a typo does not cost a full data load.
    if flag not in ('salgan', 'bce'):
        print("Invalid input argument.")
        return

    # NOTE: pickle.load can execute arbitrary code; only load files produced by
    # this project's own preprocessing step.
    print('Loading training data...')
    with open(pathToPickle + 'trainData.pickle', 'rb') as f:
        train_data = pickle.load(f)
    print('-->done!')

    print('Loading validation data...')
    with open(pathToPickle + 'validationData.pickle', 'rb') as f:
        validation_data = pickle.load(f)
    print('-->done!')

    # Choose a random sample to monitor the training
    validation_sample = random.choice(validation_data)
    cv2.imwrite('./' + DIR_TO_SAVE + '/validationRandomSaliencyGT.png', validation_sample.saliency.data)
    # The image is converted RGB -> BGR before cv2.imwrite.
    cv2.imwrite('./' + DIR_TO_SAVE + '/validationRandomImage.png',
                cv2.cvtColor(validation_sample.image.data, cv2.COLOR_RGB2BGR))

    # Create network
    if flag == 'salgan':
        model = ModelSALGAN(INPUT_SIZE[0], INPUT_SIZE[1])
        # To resume from a pre-trained model, load its weights here, e.g.:
        # load_weights(net=model.net['output'], path="nss/gen_", epochtoload=15)
        # load_weights(net=model.discriminator['prob'], path="test_dialted/disrim_", epochtoload=54)
        salgan_batch_iterator(model, train_data, validation_sample.image.data)
    else:
        model = ModelBCE(INPUT_SIZE[0], INPUT_SIZE[1])
        # To resume from a pre-trained model, load its weights here, e.g.:
        # load_weights(net=model.net['output'], path='test/gen_', epochtoload=15)
        bce_batch_iterator(model, train_data, validation_sample.image.data)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()