forked from datastaxdevs/workshop-ai-as-api
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loadTestModel.py
55 lines (46 loc) · 1.94 KB
/
loadTestModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
""" loadTestModel.py
Check that one can start predicting with just the files
found in the "trained model" directory.
"""
import sys
import json
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras import models
# Input artifacts written by the training step. All three live in one
# directory; the paths are relative, so they resolve against the current
# working directory.
_trainedModelDir = 'training/trained_model_v1'
trainedModelFile = _trainedModelDir + '/spam_model.h5'
trainedMetadataFile = _trainedModelDir + '/spam_metadata.json'
trainedTokenizerFile = _trainedModelDir + '/spam_tokenizer.json'
def predictSpamStatus(text, spamModel, pMaxSequence, pLabelLegendInverted, pTokenizer):
    """Classify a single text and return its score for every label.

    Args:
        text: raw input string to classify.
        spamModel: loaded Keras model; row 0 of its ``predict`` output is
            read as one score per label index.
        pMaxSequence: ``maxlen`` passed to ``pad_sequences``.
        pLabelLegendInverted: mapping from label index, given as a *string*
            key (as it comes out of the JSON metadata), to label name.
        pTokenizer: Keras tokenizer used to turn the text into a sequence.

    Returns:
        dict mapping each label name to its score as a plain Python float.
    """
    sequences = pTokenizer.texts_to_sequences([text])
    xInput = pad_sequences(sequences, maxlen=pMaxSequence)
    yOutput = spamModel.predict(xInput)
    # Only one text was submitted, so its scores are in the first row.
    preds = yOutput[0]
    # float() converts NumPy scalars to builtin floats so the dict prints
    # readably and can be serialized with json.
    return {pLabelLegendInverted[str(i)]: float(x) for i, x in enumerate(preds)}


if __name__ == '__main__':
    # Load tokenizer and metadata.
    # (From the metadata we need keys 'label_legend_inverted' and 'max_seq_length'.)
    # Context managers guarantee the files are closed once read.
    with open(trainedTokenizerFile, encoding='utf-8') as tokenizerFile:
        tokenizer = tokenizer_from_json(tokenizerFile.read())
    with open(trainedMetadataFile, encoding='utf-8') as metadataFile:
        metadata = json.load(metadataFile)
    # Load the model:
    model = models.load_model(trainedModelFile)

    cliArgs = sys.argv[1:]
    if cliArgs:
        # All command-line arguments are joined into one text to classify.
        sampleTexts = [
            ' '.join(cliArgs)
        ]
    else:
        # No arguments given: fall back to two built-in example texts.
        sampleTexts = [
            'This is a nice touch, adding a sense of belonging and coziness. Thank you so much.',
            'Click here to WIN A FREE IPHONE and this and that.',
        ]

    # simple test:
    print('\n\tMODEL TEST:')
    print('=' * 20)
    for st in sampleTexts:
        preds = predictSpamStatus(st, model, metadata['max_seq_length'], metadata['label_legend_inverted'], tokenizer)
        print('TEXT = %s' % st)
        print('PREDICTION = %s' % str(preds))
        print('*' * 20)