-
Notifications
You must be signed in to change notification settings - Fork 5
/
evaluate.py
68 lines (54 loc) · 2.16 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
from argparse import ArgumentParser
import tensorflow as tf
from train import (
SpacingModel,
string_to_example,
sparse_categorical_accuracy_with_ignore,
SparseCategoricalCrossentropyWithIgnore,
)
# Command-line interface for evaluation. Every flag is mandatory; the two
# probabilities control the synthetic space add/delete corruption applied
# to the test sentences by string_to_example.
parser = ArgumentParser()
for string_flag in ("--char-file", "--model-file", "--training-config", "--test-file"):
    parser.add_argument(string_flag, type=str, required=True)
for float_flag in ("--add-prob", "--delete-prob"):
    parser.add_argument(float_flag, type=float, required=True)
def main():
    """Evaluate a trained SpacingModel on a noised test set.

    Reads the training config and character vocabulary, rebuilds the model
    architecture described by the config, restores the saved weights, and
    runs ``model.evaluate`` over the test file with random space
    add/delete corruption applied by ``string_to_example``.
    """
    args = parser.parse_args()

    with open(args.training_config) as f:
        config = json.load(f)

    # Vocabulary: four special tokens first, then one id per character in
    # the char file. Any unseen character falls back to <unk> (id 3).
    # NOTE(review): list(content) includes any trailing newline as a vocab
    # entry — presumably this matches how training built the table; verify
    # against train.py.
    with open(args.char_file) as f:
        content = f.read()
    keys = ["<pad>", "<s>", "</s>", "<unk>"] + list(content)
    values = list(range(len(keys)))
    vocab_initializer = tf.lookup.KeyValueTensorInitializer(keys, values, key_dtype=tf.string, value_dtype=tf.int32)
    vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=3)

    # One noised, batched pass over the test file. The shuffle does not
    # change the aggregate metrics, only the order examples are seen in.
    test_dataset = (
        tf.data.TextLineDataset(tf.constant([args.test_file]))
        .shuffle(10000)
        .map(
            string_to_example(vocab_table, delete_prob=args.delete_prob, add_prob=args.add_prob),
            num_parallel_calls=tf.data.experimental.AUTOTUNE,
        )
        .batch(config["val_batch_size"])
    )

    model = SpacingModel(
        config["vocab_size"],
        config["hidden_size"],
        conv_activation=config["conv_activation"],
        dense_activation=config["dense_activation"],
        conv_kernel_and_filter_sizes=config["conv_kernel_and_filter_sizes"],
        dropout_rate=config["dropout_rate"],
    )
    model.compile(
        optimizer=tf.optimizers.Adam(learning_rate=config["learning_rate"]),
        loss=SparseCategoricalCrossentropyWithIgnore(from_logits=True, ignore_id=-1),
        metrics=[sparse_categorical_accuracy_with_ignore],
    )

    # Build the model's variables BEFORE restoring weights: calling
    # load_weights on an unbuilt subclassed model fails outright for
    # HDF5 checkpoints and only defers restoration for TF-format ones.
    # Building first works correctly for both formats.
    model(tf.keras.Input([None], dtype=tf.int32))
    model.load_weights(args.model_file)
    model.summary()

    model.evaluate(test_dataset)


if __name__ == "__main__":
    main()