-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathudc_model.py
97 lines (79 loc) · 2.95 KB
/
udc_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import tensorflow as tf
import sys
def get_id_feature(features, key, len_key, max_len):
    """Fetch an id-sequence feature together with its clipped length.

    Args:
        features: Mapping from feature name to tensor.
        key: Name of the id-sequence feature; it is returned unchanged.
        len_key: Name of the matching length feature. Axis 1 of this tensor
            is squeezed away, so that axis must have size 1 (presumably the
            feature is ``[batch, 1]`` -> ``[batch]`` -- TODO confirm).
        max_len: Upper bound applied element-wise to the lengths.

    Returns:
        A tuple ``(ids, ids_len)`` where ``ids_len`` is
        ``min(length, max_len)`` element-wise.
    """
    ids = features[key]
    ids_len = tf.squeeze(features[len_key], [1])
    # Build the bound in the length tensor's own dtype: tf.minimum requires
    # both operands to share a dtype, and a hard-coded int64 constant made
    # non-int64 (e.g. int32) length features fail. int64 inputs are unchanged.
    ids_len = tf.minimum(ids_len, tf.constant(max_len, dtype=ids_len.dtype))
    return ids, ids_len
def create_train_op(loss, hparams):
    """Create the training op that minimises ``loss``.

    Delegates to ``tf.contrib.layers.optimize_loss``, using the graph's
    current global step plus ``hparams.learning_rate`` and
    ``hparams.optimizer``. Gradients are clipped with ``clip_gradients=10.0``.

    Args:
        loss: Scalar loss tensor to optimise.
        hparams: Hyperparameter object providing ``learning_rate`` and
            ``optimizer``.

    Returns:
        The op returned by ``optimize_loss``.
    """
    step = tf.contrib.framework.get_global_step()
    # optimize_loss treats a float clip_gradients as a global-norm bound.
    max_gradient_norm = 10.0
    return tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=step,
        learning_rate=hparams.learning_rate,
        clip_gradients=max_gradient_norm,
        optimizer=hparams.optimizer,
    )
def create_model_fn(hparams, model_impl, num_distractors=9):
    """Return a ``model_fn(features, targets, mode)`` for a tf.contrib.learn Estimator.

    Args:
        hparams: Hyperparameter object.
            - This function reads ``max_context_len`` and
              ``max_utterance_len``.
            - ``create_train_op`` additionally reads ``learning_rate`` and
              ``optimizer``.
        model_impl: Callable invoked as ``model_impl(hparams, mode, context,
            context_len, utterance, utterance_len, targets)``. It must
            return a ``(probs, loss)`` pair.
        num_distractors: Number of distractor features
            (``distractor_0`` ... ``distractor_{n-1}`` and their ``_len``
            companions) per EVAL record. Defaults to 9, i.e. 10 candidates
            per record.

    Returns:
        A function ``model_fn(features, targets, mode)`` that returns a
        ``(predictions, loss, train_op)`` triple:

        - TRAIN: ``(probs, loss, train_op)``.
        - INFER: ``(probs, 0.0, None)``. ``model_impl`` is called with
          ``targets=None`` and its loss is discarded.
        - EVAL: the true utterance and every distractor are scored against
          the same context in one ``model_impl`` call. The per-candidate
          probabilities are placed side by side along axis 1, with the
          true utterance's chunk first. Presumably ``probs`` is
          ``[num_candidates * batch, 1]``, giving ``[batch, num_candidates]``
          -- TODO confirm against model_impl.

        Any other mode raises ``ValueError``.

    Raises:
        ValueError: If ``num_distractors`` is negative.
    """
    if num_distractors < 0:
        raise ValueError(
            "num_distractors must be >= 0, got {}".format(num_distractors))
    # One true utterance plus the distractors.
    num_candidates = num_distractors + 1

    def model_fn(features, targets, mode):
        """Build the graph for one Estimator mode.

        Returns a ``(predictions, loss, train_op)`` triple.
        """
        context, context_len = get_id_feature(
            features, "context", "context_len", hparams.max_context_len)
        utterance, utterance_len = get_id_feature(
            features, "utterance", "utterance_len", hparams.max_utterance_len)

        if mode == tf.contrib.learn.ModeKeys.TRAIN:
            probs, loss = model_impl(
                hparams,
                mode,
                context,
                context_len,
                utterance,
                utterance_len,
                targets)
            train_op = create_train_op(loss, hparams)
            return probs, loss, train_op

        if mode == tf.contrib.learn.ModeKeys.INFER:
            # There are no labels at inference time. model_impl's loss is
            # ignored and a constant 0.0 is reported instead.
            probs, _ = model_impl(
                hparams,
                mode,
                context,
                context_len,
                utterance,
                utterance_len,
                None)
            return probs, 0.0, None

        if mode == tf.contrib.learn.ModeKeys.EVAL:
            # Prefer the static batch size so downstream shapes stay fully
            # defined. When the static batch dimension is unknown it is
            # None, and tf.ones([None, 1]) would fail, so fall back to the
            # runtime size.
            batch_size = targets.get_shape().as_list()[0]
            if batch_size is None:
                batch_size = tf.shape(targets)[0]

            # Each record has one true utterance followed by
            # num_distractors distractors. Stack all candidates along
            # axis 0, each paired with the same context.
            utterances = [utterance]
            utterance_lens = [utterance_len]
            for i in range(num_distractors):
                distractor, distractor_len = get_id_feature(
                    features,
                    "distractor_{}".format(i),
                    "distractor_{}_len".format(i),
                    hparams.max_utterance_len)
                utterances.append(distractor)
                utterance_lens.append(distractor_len)

            # Label 1 for the true utterance, 0 for every distractor.
            labels = [tf.ones([batch_size, 1], dtype=tf.int64)]
            labels.extend(
                tf.zeros([batch_size, 1], dtype=tf.int64)
                for _ in range(num_distractors))

            probs, loss = model_impl(
                hparams,
                mode,
                tf.concat([context] * num_candidates, 0),
                tf.concat([context_len] * num_candidates, 0),
                tf.concat(utterances, 0),
                tf.concat(utterance_lens, 0),
                tf.concat(labels, 0))

            # Undo the stacking. Chunk k holds candidate k's probabilities
            # for the whole batch; placing the chunks side by side puts the
            # true utterance's chunk first.
            split_probs = tf.split(probs, num_candidates, 0)
            shaped_probs = tf.concat(split_probs, 1)

            # Add summaries.
            correct_probs = split_probs[0]
            tf.summary.histogram("eval_correct_probs_hist", correct_probs)
            tf.summary.scalar(
                "eval_correct_probs_average", tf.reduce_mean(correct_probs))
            if num_distractors > 0:
                # Summarise every distractor, not just distractor_0.
                incorrect_probs = tf.concat(split_probs[1:], 0)
                tf.summary.histogram(
                    "eval_incorrect_probs_hist", incorrect_probs)
                tf.summary.scalar(
                    "eval_incorrect_probs_average",
                    tf.reduce_mean(incorrect_probs))
            return shaped_probs, loss, None

        raise ValueError("Unsupported mode: {}".format(mode))

    return model_fn