-
Notifications
You must be signed in to change notification settings - Fork 5
/
caption_testing.py
62 lines (50 loc) · 1.82 KB
/
caption_testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
import pandas as pd
import h5py
from scipy.sparse import save_npz, load_npz
from sklearn.model_selection import train_test_split
from keras.models import load_model
import tools.generic as tg
import tools.keras as tk
import tools.text as tt
from models.NRC import NRC
# Fixed seed so the train/test partition below is reproducible across runs.
seed = 1234

# Load the sparsified syndromic records and split them by row index.
# NOTE(review): SPARSE_RECORDS_NPZ_FILE is neither defined nor imported in
# this file; it must be provided (e.g. from a config module) before running.
records = load_npz(SPARSE_RECORDS_NPZ_FILE)
train_indices, test_indices = train_test_split(
    range(records.shape[0]),
    random_state=seed,
)
train_recs, test_recs = records[train_indices], records[test_indices]
# Load the pretrained autoencoder and keep the layer at index 1, which this
# script treats as the encoder.
# NOTE(review): AE_ENCODER_HDF5_FILE is not defined in this file, and
# ae_encoder is not referenced anywhere below; confirm it is still needed.
ae = load_model(AE_ENCODER_HDF5_FILE)
ae_encoder = ae.get_layer(index=1)
# Load the integer-encoded sentences and their targets into memory.
# The context manager closes the HDF5 file once every dataset has been read;
# `dataset[()]` reads an entire h5py dataset into a NumPy array.
# NOTE(review): SENTENCE_INTEGERS_HDF5_FILE is not defined in this file, and
# y_train / y_test are not used below; confirm they are still needed.
with h5py.File(SENTENCE_INTEGERS_HDF5_FILE, mode='r') as sents:
    train_sents = sents['X_train'][()]
    y_train = sents['y_train'][()]
    test_sents = sents['X_test'][()]
    y_test = sents['y_test'][()]
# Build the word -> integer lookup from the dictionary CSV's 'word' and
# 'value' columns, then fetch the integer assigned to 'end_string'
# (used as the end-of-sequence marker).
# NOTE(review): WORD_DICTIONARY is not defined in this file, and eos_val is
# not referenced below.
vocab_df = pd.read_csv(WORD_DICTIONARY)
vocab = {
    word: value
    for word, value in zip(vocab_df['word'], vocab_df['value'])
}
eos_val = vocab['end_string']
# Model dimensions, taken from the loaded data where possible.
n, sparse_size = records.shape  # record count, sparse feature width
hidden_size = embedding_size = 128
# One slot more than the number of vocabulary entries
# (presumably a reserved/padding index; confirm against NRC).
vocab_size = len(vocab) + 1
# Sentence length is the second dimension of the training sentence array.
max_length = train_sents.shape[1]
# Construct the NRC captioning model with the dimensions computed above,
# then load the weights saved from training.
# NOTE(review): the original comment says this also builds the inference
# model; whether load_training_model does so is not visible from here.
nrc = NRC(
    vocab_size=vocab_size,
    max_length=max_length,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    sparse_size=sparse_size,
)
nrc.load_training_model('nrc_training.hdf5')
# Pick up to 100 distinct test records at random to caption.
# np.random.choice(k, size=m, replace=False) draws m unique indices from
# range(k); capping m at the test-set size avoids a ValueError when fewer
# than 100 test records are available.
n_test = test_recs.shape[0]
to_caption = np.random.choice(n_test, size=min(100, n_test), replace=False)
# Caption the selected records using the 'beam' decoding method.
nrc.caption(test_recs[to_caption], vocab, method='beam')