forked from cbfinn/maml
-
Notifications
You must be signed in to change notification settings - Fork 2
/
maml.py
231 lines (202 loc) · 12.6 KB
/
maml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
""" Code for the MAML algorithm and network definitions. """
from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
# special_grads is not referenced by name in this file; it is presumably
# imported for the side effect of defining a MaxPoolGrad gradient. Per the
# warning text below, a KeyError means this TensorFlow version already defines
# it, in which case the failure is reported on stderr and otherwise ignored.
try:
    import special_grads
except KeyError as e:
    print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
          file=sys.stderr)
from tensorflow.python.platform import flags
from utils import mse, xent, conv_block, normalize
# Global flag values read throughout this module; no flags are defined in this
# file, so they must be defined/parsed elsewhere.
FLAGS = flags.FLAGS
class MAML:
    """Model-Agnostic Meta-Learning: network definitions plus the meta-training graph.

    Hyperparameters are read from the module-level ``FLAGS``. After
    construction, call ``construct_model()`` (once per data split) to build
    the inner-loop adaptation, losses, accuracies, optimizer ops and summaries.
    """

    def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
        """ must call construct_model() after initializing MAML!

        Args:
            dim_input: size of one flattened input example. For image data
                sources, img_size is derived as int(sqrt(dim_input / channels)).
            dim_output: number of network outputs.
            test_num_updates: inner gradient steps to unroll for evaluation;
                construct_model unrolls max(test_num_updates, FLAGS.num_updates).

        Raises:
            ValueError: if FLAGS.datasource is not 'sinusoid', 'omniglot' or
                'miniimagenet'.
        """
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.update_lr = FLAGS.update_lr
        # A placeholder with a default lets callers override the meta learning
        # rate through feed_dict without rebuilding the graph.
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.classification = False
        self.test_num_updates = test_num_updates
        if FLAGS.datasource == 'sinusoid':
            self.dim_hidden = [40, 40]
            self.loss_func = mse
            self.forward = self.forward_fc
            self.construct_weights = self.construct_fc_weights
        elif FLAGS.datasource in ('omniglot', 'miniimagenet'):
            self.loss_func = xent
            self.classification = True
            if FLAGS.conv:
                # For the conv net, dim_hidden is a single filter count.
                self.dim_hidden = FLAGS.num_filters
                self.forward = self.forward_conv
                self.construct_weights = self.construct_conv_weights
            else:
                self.dim_hidden = [256, 128, 64, 64]
                self.forward = self.forward_fc
                self.construct_weights = self.construct_fc_weights
            self.channels = 3 if FLAGS.datasource == 'miniimagenet' else 1
            self.img_size = int(np.sqrt(self.dim_input / self.channels))
        else:
            raise ValueError('Unrecognized data source.')

    def construct_model(self, input_tensors=None, prefix='metatrain_'):
        """Build the MAML graph (inner loop, meta-objective, summaries) for one split.

        The network weights are created on the first call and reused on every
        later call, so train and validation graphs share parameters.

        Args:
            input_tensors: optional dict with keys 'inputa', 'inputb',
                'labela', 'labelb'. When None, float32 placeholders are
                created instead. tf.map_fn iterates over the leading axis of
                each tensor, one element per task in the meta-batch.
            prefix: name prefix for summaries. If it contains 'train', the
                self.total_* attributes and the pretrain/metatrain ops are
                built; otherwise the self.metaval_total_* attributes are.
        """
        # a: training data for inner gradient, b: test data for meta gradient
        if input_tensors is None:
            self.inputa = tf.placeholder(tf.float32)
            self.inputb = tf.placeholder(tf.float32)
            self.labela = tf.placeholder(tf.float32)
            self.labelb = tf.placeholder(tf.float32)
        else:
            self.inputa = input_tensors['inputa']
            self.inputb = input_tensors['inputb']
            self.labela = input_tensors['labela']
            self.labelb = input_tensors['labelb']

        with tf.variable_scope('model', reuse=None) as training_scope:
            if hasattr(self, 'weights'):
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()

            # outputbs[i] and lossesb[i] are the output and loss after i+1 gradient updates
            num_updates = max(self.test_num_updates, FLAGS.num_updates)

            def task_metalearn(inp, reuse=True):
                """ Perform gradient descent for one task in the meta-batch. """
                inputa, inputb, labela, labelb = inp
                task_outputbs, task_lossesb = [], []

                # reuse=False only for the one-off call below that creates the
                # normalization variables; every call inside map_fn reuses them.
                task_outputa = self.forward(inputa, weights, reuse=reuse)
                task_lossa = self.loss_func(task_outputa, labela)

                # First inner step starts from the shared (meta) weights.
                fast_weights = self._inner_update(task_lossa, weights)
                output = self.forward(inputb, fast_weights, reuse=True)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb))

                # Remaining inner steps continue from the adapted weights.
                for _ in range(num_updates - 1):
                    loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
                    fast_weights = self._inner_update(loss, fast_weights)
                    output = self.forward(inputb, fast_weights, reuse=True)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb))

                task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
                if self.classification:
                    task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1))
                    task_accuraciesb = [
                        tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(outputb), 1), tf.argmax(labelb, 1))
                        for outputb in task_outputbs]
                    task_output.extend([task_accuracya, task_accuraciesb])
                return task_output

            # Compare by value: `is not` on a str literal tests identity and is
            # effectively always True for flag values parsed at runtime.
            if FLAGS.norm != 'None':
                # Build task 0 once outside map_fn with reuse=False so the
                # normalization variables get created; the result is unused
                # (might want to combine this, and not build idx 0 twice).
                task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

            out_dtype = [tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]
            if self.classification:
                out_dtype.extend([tf.float32, [tf.float32] * num_updates])
            result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
            if self.classification:
                outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
            else:
                outputas, outputbs, lossesa, lossesb = result

        ## Performance & Optimization
        def batch_mean(values):
            # Sum over tasks divided by the configured meta-batch size.
            return tf.reduce_sum(values) / tf.to_float(FLAGS.meta_batch_size)

        total_loss1 = batch_mean(lossesa)
        total_losses2 = [batch_mean(lossesb[j]) for j in range(num_updates)]
        if self.classification:
            total_accuracy1 = batch_mean(accuraciesa)
            total_accuracies2 = [batch_mean(accuraciesb[j]) for j in range(num_updates)]

        if 'train' in prefix:
            self.total_loss1 = total_loss1
            self.total_losses2 = total_losses2
            # after the map_fn
            self.outputas, self.outputbs = outputas, outputbs
            if self.classification:
                self.total_accuracy1 = total_accuracy1
                self.total_accuracies2 = total_accuracies2
            self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)

            if FLAGS.metatrain_iterations > 0:
                optimizer = tf.train.AdamOptimizer(self.meta_lr)
                # Meta-objective: post-update loss after FLAGS.num_updates inner steps.
                self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates - 1])
                if FLAGS.datasource == 'miniimagenet':
                    # compute_gradients yields None for variables the loss does
                    # not depend on; clip_by_value cannot take None, while
                    # apply_gradients skips it, so pass None through unchanged.
                    gvs = [(tf.clip_by_value(grad, -10, 10) if grad is not None else None, var)
                           for grad, var in gvs]
                self.metatrain_op = optimizer.apply_gradients(gvs)
        else:
            self.metaval_total_loss1 = total_loss1
            self.metaval_total_losses2 = total_losses2
            if self.classification:
                self.metaval_total_accuracy1 = total_accuracy1
                self.metaval_total_accuracies2 = total_accuracies2

        ## Summaries
        tf.summary.scalar(prefix + 'Pre-update loss', total_loss1)
        if self.classification:
            tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1)
        for j in range(num_updates):
            tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j])
            if self.classification:
                tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j])

    def _inner_update(self, loss, params):
        """Return ``params`` after one gradient-descent step on ``loss``.

        Args:
            loss: loss tensor that depends on the tensors in ``params``.
            params: dict mapping weight names to tensors being adapted.

        Returns:
            A new dict with the same keys (same order), each mapped to
            ``params[key] - self.update_lr * d(loss)/d(params[key])``.
        """
        grads = tf.gradients(loss, list(params.values()))
        if FLAGS.stop_grad:
            # Stop second-order terms from flowing back through the inner
            # gradients (first-order approximation of the meta-gradient).
            grads = [tf.stop_gradient(grad) for grad in grads]
        return {key: value - self.update_lr * grad
                for (key, value), grad in zip(params.items(), grads)}

    ### Network construction functions (fc networks and conv networks)
    def construct_fc_weights(self):
        """Create weights for the MLP dim_input -> dim_hidden[0] -> ... -> dim_output.

        Returns:
            dict with 'w1'..'wN' (truncated-normal, stddev 0.01) and 'b1'..'bN'
            (zeros) tf.Variables, where N = len(self.dim_hidden) + 1.
        """
        weights = {}
        layer_sizes = [self.dim_input] + list(self.dim_hidden) + [self.dim_output]
        for i in range(len(layer_sizes) - 1):
            weights['w' + str(i + 1)] = tf.Variable(tf.truncated_normal([layer_sizes[i], layer_sizes[i + 1]], stddev=0.01))
            weights['b' + str(i + 1)] = tf.Variable(tf.zeros([layer_sizes[i + 1]]))
        return weights

    def forward_fc(self, inp, weights, reuse=False):
        """Apply the MLP built by construct_fc_weights to ``inp``.

        Every hidden layer is matmul + bias followed by ``normalize`` with a
        ReLU; the output layer is linear.

        Args:
            inp: input tensor fed to the first matmul.
            weights: dict of tensors keyed like construct_fc_weights' output.
            reuse: forwarded to ``normalize`` for its variables.
        """
        # Normalization scopes are '0', '2', '3', ... (there is no '1'); kept
        # unchanged since renaming them would presumably rename the
        # normalization variables.
        hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
        for i in range(1, len(self.dim_hidden)):
            hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)], activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
        out_idx = str(len(self.dim_hidden) + 1)
        return tf.matmul(hidden, weights['w' + out_idx]) + weights['b' + out_idx]

    def construct_conv_weights(self):
        """Create weights for the 4-layer conv net used by forward_conv.

        Returns:
            dict with 'conv1'..'conv4' (3x3 kernels, Xavier init) and
            'b1'..'b4'. For miniimagenet it also holds a hidden FC layer
            'w5'/'b5' (32 units) and an output layer 'w6'/'b6'; otherwise
            'w5'/'b5' is the output layer.
        """
        weights = {}

        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        k = 3

        weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
        weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
        if FLAGS.datasource == 'miniimagenet':
            # assumes max pooling: the flattened conv output is sized as a
            # 5x5 spatial map with dim_hidden channels.
            weights['w5'] = tf.get_variable('w5', [self.dim_hidden * 5 * 5, 32], initializer=fc_initializer)
            weights['b5'] = tf.Variable(tf.zeros([32]), name='b5')
            weights['w6'] = tf.get_variable('w6', [32, self.dim_output], initializer=fc_initializer)
            weights['b6'] = tf.Variable(tf.zeros([self.dim_output]), name='b6')
        else:
            weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        return weights

    def forward_conv(self, inp, weights, reuse=False, scope=''):
        """Apply the conv net built by construct_conv_weights to ``inp``.

        Args:
            inp: flattened images, reshaped here to
                [-1, img_size, img_size, channels].
            weights: dict of tensors keyed like construct_conv_weights' output.
            reuse: reuse is for the normalization parameters.
            scope: prefix for the per-layer normalization scopes.
        """
        channels = self.channels
        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])

        hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope + '0')
        hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope + '1')
        hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope + '2')
        hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope + '3')
        if FLAGS.datasource == 'miniimagenet':
            # Flatten to a vector; construct_conv_weights sizes w5 for a 5x5 map.
            hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
        else:
            # Average over the two spatial axes.
            hidden4 = tf.reduce_mean(hidden4, [1, 2])

        hidden5 = tf.matmul(hidden4, weights['w5']) + weights['b5']
        if 'w6' not in weights:
            # Only the miniimagenet weights have the extra hidden FC layer;
            # otherwise w5/b5 is already the linear output layer.
            return hidden5
        hidden5 = normalize(hidden5, tf.nn.relu, reuse, scope + '4')
        return tf.matmul(hidden5, weights['w6']) + weights['b6']