-
Notifications
You must be signed in to change notification settings - Fork 0
/
cnn_model.py
117 lines (82 loc) · 3.27 KB
/
cnn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
import yaml
import tensorflow as tf
from data_utils import load_dataset, text2words, load_rubtsova_datasets
from sentiment.cnn import SentimentCNN
def train(interactive=False, config_file=None):
    """Build the sentiment model, train it, and hand it back.

    Thin wrapper around ``_initialize`` that selects the ``'train'`` mode.

    Args:
        interactive: Forwarded to ``_initialize``; when it is exactly
            ``True`` the interactive prompt runs after training.
        config_file: Path of the YAML configuration file, forwarded as-is.

    Returns:
        The model object returned by ``_initialize``.
    """
    return _initialize(
        mode='train',
        interactive=interactive,
        config_file=config_file,
    )
def load_model(interactive=False, config_file=None):
    """Build the sentiment model and restore it instead of training.

    Thin wrapper around ``_initialize`` that selects the ``'load'`` mode.

    Args:
        interactive: Forwarded to ``_initialize``; when it is exactly
            ``True`` the interactive prompt runs after the model is restored.
        config_file: Path of the YAML configuration file, forwarded as-is.

    Returns:
        The model object returned by ``_initialize``.
    """
    return _initialize(
        mode='load',
        interactive=interactive,
        config_file=config_file,
    )
def _initialize(mode, interactive, config_file):
if mode not in ['train', 'load']:
raise Exception('mode should be one of \'train\' or \'load\'')
config = load_config(config_file)
cnn_config = config['cnn']
ds_config = config['datasets']
train_dataset, train_labels, valid_dataset, valid_labels = None, None, None, None
if mode == 'train':
train_dataset, train_labels, valid_dataset, valid_labels = \
load_rubtsova_datasets(ds_config['positive'],
ds_config['negative'],
ds_config['size'])
max_len = max(map(len, train_dataset))
print('Maximum sentence length: {}'.format(max_len))
with tf.Graph().as_default() as graph:
session = tf.Session(graph=graph)
cnn = SentimentCNN(
session=session,
**cnn_config
)
if mode == 'train':
cnn.train(train_dataset=train_dataset, train_labels=train_labels,
valid_dataset=valid_dataset, valid_labels=valid_labels)
else:
cnn.restore()
if interactive is True:
run_interactive(cnn)
return cnn
def load_config(file_name):
    """Read a YAML configuration file and return the parsed document.

    Args:
        file_name: Path of the YAML file to read.

    Returns:
        The parsed YAML content. Callers in this module index it as a
        mapping with ``'cnn'`` and ``'datasets'`` keys.
    """
    with open(file_name, 'r') as config_file:
        # safe_load instead of load: calling yaml.load without an explicit
        # Loader is deprecated in PyYAML 5.1+, is an error in PyYAML 6, and
        # the full loader can construct arbitrary Python objects from tags.
        # NOTE(review): if a config relies on python-specific YAML tags this
        # will now reject it -- confirm configs are plain data.
        return yaml.safe_load(config_file)
def predict(cnn, sentence):
    """Score one sentence with the model.

    The sentence is split with ``text2words`` and the result is passed to
    ``cnn.predict``.

    Args:
        cnn: Model object exposing ``predict(words)``.
        sentence: Raw text to classify.

    Returns:
        A ``(negative, positive)`` tuple taken from elements 0 and 1 of the
        model's prediction.
    """
    scores = cnn.predict(text2words(sentence))
    negative, positive = scores[0], scores[1]
    return negative, positive
def run_interactive(cnn):
    """Prompt for text in a loop and print the model's sentiment scores.

    The loop ends on Ctrl-C or when standard input reaches end-of-file. Any
    other error raised while handling a line is reported and the prompt is
    shown again.

    Args:
        cnn: Model object passed through to ``predict``.
    """
    while True:
        try:
            text = input('Text: ')
            n, p = predict(cnn, text)
            print('Negative: {:g}. Positive: {:g}'.format(n, p))
            print()
        except (KeyboardInterrupt, EOFError):
            # EOF has to end the loop: once stdin is exhausted input() raises
            # again immediately, so swallowing it would spin forever.
            return
        except Exception as error:
            # Stay best-effort (one bad input must not end the session), but
            # report the failure instead of hiding it. SystemExit is no
            # longer intercepted.
            print('Could not classify text: {}'.format(error))
def _str_to_bool(s):
"""Convert string to bool (in argparse context)."""
if s.lower() not in ['true', 'false']:
raise ValueError('Need bool; got %r' % s)
return {'true': True, 'false': False}[s.lower()]
def add_boolean_argument(parser, name, default=False):
    """Add a ``--<name>`` / ``--no<name>`` boolean switch to *parser*.

    ``--<name>`` on its own stores ``True``, ``--<name> true|false`` stores
    the parsed value, and ``--no<name>`` stores ``False``. The two options
    are mutually exclusive and write to the same namespace attribute.

    Args:
        parser: ``argparse.ArgumentParser`` to extend.
        name: Option name without the leading dashes.
        default: Value used when neither option is given.
    """
    # argparse derives the destination of '--<name>' by turning '-' into '_'.
    # Use that same destination for '--no<name>' so a hyphenated name does
    # not make the negative flag write to a different attribute.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        '--' + name, dest=dest, nargs='?', default=default, const=True,
        type=_str_to_bool)
    group.add_argument('--no' + name, dest=dest, action='store_false')
def cli_main():
    """Command-line entry point: parse arguments and train or load the model.

    Arguments:
        M: ``train`` or ``load``.
        --interactive / --nointeractive: whether to start the interactive
            prompt afterwards (enabled by default here).
        -c / --meta_config_file: path of the YAML configuration (required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', metavar='M',
                        # Let argparse reject unknown modes with a usage
                        # message instead of a traceback later on.
                        choices=['train', 'load'],
                        help="'train' to train the model, "
                             "'load' to restore it",
                        type=str)
    add_boolean_argument(parser, 'interactive', True)
    parser.add_argument('-c', '--meta_config_file',
                        help='path to meta_config.yml',
                        type=str,
                        required=True)
    args = parser.parse_args()
    _initialize(args.mode,
                interactive=args.interactive,
                config_file=args.meta_config_file)
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    cli_main()