-
Notifications
You must be signed in to change notification settings - Fork 0
/
meta_main.py
48 lines (33 loc) · 1.3 KB
/
meta_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
from cnn_model import load_config
from data_utils import load_rubtsova_datasets
from sentiment.cnn import SentimentCNN
def main(config_path='meta_config.yaml',
         n_filters_grid=(20, 50, 100, 150, 200),
         runs_per_setting=10):
    """Sweep the CNN ``n_filters`` setting and print averaged accuracy.

    Reads the config at ``config_path`` with ``load_config``. It loads the
    datasets with ``load_rubtsova_datasets`` using the ``positive``,
    ``negative`` and ``size`` entries of the config's ``datasets`` section.
    Then, for every value in ``n_filters_grid``, it trains a model
    ``runs_per_setting`` times via ``evaluate``. Each run's accuracy is
    printed, followed by the mean over the runs.

    Calling ``main()`` with no arguments reproduces the original sweep:
    ``meta_config.yaml``, filters 20/50/100/150/200, and 10 runs each.

    Args:
        config_path: Path handed to ``load_config``.
        n_filters_grid: ``n_filters`` values to try, in order.
        runs_per_setting: Training runs averaged per ``n_filters`` value;
            must be at least 1.

    Raises:
        ValueError: If ``runs_per_setting`` is less than 1 (the average
            would otherwise divide by zero).
    """
    # Validate before any config/dataset loading so bad input fails fast.
    if runs_per_setting < 1:
        raise ValueError(
            'runs_per_setting must be >= 1, got {}'.format(runs_per_setting))

    config = load_config(config_path)
    cnn_base_config = config['cnn']
    ds_config = config['datasets']
    datasets = load_rubtsova_datasets(ds_config['positive'],
                                      ds_config['negative'],
                                      ds_config['size'])

    for n_filters in n_filters_grid:
        # A shallow copy suffices: only the top-level 'n_filters' key is
        # overridden, so the shared base config is never mutated.
        cnn_config = cnn_base_config.copy()
        cnn_config['n_filters'] = n_filters

        accuracies = []
        for _ in range(runs_per_setting):
            _, accuracy = evaluate(cnn_config, datasets)
            accuracies.append(accuracy)
            print('n_filters: {}, accuracy: {}'.format(n_filters, accuracy))

        avg_accuracy = sum(accuracies) / len(accuracies)
        print('n_filters: {}, avg_accuracy: {}'.format(n_filters, avg_accuracy))
def evaluate(config, datasets):
    """Train a single ``SentimentCNN`` inside its own TensorFlow graph.

    A new ``tf.Graph`` is created for every call and made the default graph.
    A ``tf.Session`` bound to that graph is opened for the duration of
    training.

    Args:
        config: Mapping expanded as keyword arguments to ``SentimentCNN``
            (alongside ``session``).
        datasets: Sequence unpacked positionally into ``SentimentCNN.train``.

    Returns:
        The ``(loss, accuracy)`` pair that ``train`` yields.
    """
    fresh_graph = tf.Graph()
    with fresh_graph.as_default():
        with tf.Session(graph=fresh_graph) as sess:
            model = SentimentCNN(session=sess, **config)
            # Unpack explicitly so a wrong-arity result still raises here.
            run_loss, run_accuracy = model.train(*datasets)
    return run_loss, run_accuracy
if __name__ == '__main__':
    # Launch the sweep only when run as a script, never on import.
    main()