-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtweet_trainer.py
111 lines (97 loc) · 3.69 KB
/
tweet_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import html
import os
import re
import time

import click
import tensorflow as tf
import ujson as json
from textgenrnn import textgenrnn
def process_tweet_text(text):
    """Clean raw tweet text before it is used as a training sample.

    URLs and @-mentions are removed, leading/trailing spaces are stripped,
    runs of spaces are collapsed to a single space, and HTML entities such
    as ``&lt;``, ``&gt;`` and ``&amp;`` are decoded to their characters.

    Args:
        text: The raw tweet text.

    Returns:
        The cleaned text. This is an empty string when the tweet consisted
        only of URLs, mentions and spaces.
    """
    text = re.sub(r'http\S+', '', text)  # Remove URLs
    text = re.sub(r'@[a-zA-Z0-9_]+', '', text)  # Remove @ mentions
    text = text.strip(" ")  # Remove whitespace resulting from above
    text = re.sub(r' +', ' ', text)  # Remove redundant spaces
    # Decode HTML entities. html.unescape works in a single pass, so a
    # double-escaped sequence such as '&amp;lt;' becomes the literal text
    # '&lt;' rather than '<'.
    return html.unescape(text)
def _load_training_sample(infile, size):
    """Collect up to *size* cleaned tweets and their authors from *infile*.

    Each line of *infile* is parsed as one JSON document. Lines that are not
    JSON objects, have no 'text' key, or whose text is empty after cleaning
    are skipped.

    Args:
        infile: Path of the line-delimited JSON file to read.
        size: Stop reading once this many tweets have been collected.

    Returns:
        A ``(texts, context_labels)`` pair of equal-length lists, where
        ``context_labels[i]`` is the screen name of the author of
        ``texts[i]``.
    """
    texts = []
    context_labels = []
    with open(infile, 'r') as f:
        for line in f:
            try:
                record = json.loads(line)
            except ValueError:
                # Printed for any line that fails to parse, not only at the
                # real end of the file; the loop simply moves on.
                print('Reached end of file!')
                continue
            if not isinstance(record, dict) or 'text' not in record:
                continue
            tweet_text = process_tweet_text(record['text'])
            if not tweet_text:
                continue
            # Read the label before appending so both lists always have the
            # same length, even if the lookup raises.
            screen_name = record['user']['screen_name']
            texts.append(tweet_text)
            context_labels.append(screen_name)
            if len(texts) == size:
                break
    return texts, context_labels


def train_model(infile, size, epoch):
    """Load a sample of tweets and train the textgenrnn model on it.

    Args:
        infile: Path of the line-delimited JSON file holding the tweets.
        size: Maximum number of tweets to use for training.
        epoch: Number of training epochs, passed to textgenrnn as
            ``num_epochs``.
    """
    cfg = {
        'num_epochs': epoch,
        'gen_epochs': 1,
        'batch_size': 128,
        'train_size': 1.0,
        # False selects train_on_texts below; set to True to use
        # train_new_model with the 'model_config' values instead.
        'new_model': False,
        'model_config': {
            'rnn_layers': 2,
            'rnn_size': 128,
            'rnn_bidirectional': False,
            'max_length': 40,
            'dim_embeddings': 100,
            'word_level': False,
        },
    }

    print('Loading training sample from file...')
    start_time = time.time()
    texts, context_labels = _load_training_sample(infile, size)
    print("Load time: {} seconds".format(time.time() - start_time))
    print('Actual sample size:', len(texts))

    textgen = textgenrnn(name='./weights/twitter_general')

    if cfg['new_model']:
        model_cfg = cfg['model_config']
        textgen.train_new_model(
            texts,
            context_labels=context_labels,
            num_epochs=cfg['num_epochs'],
            gen_epochs=cfg['gen_epochs'],
            batch_size=cfg['batch_size'],
            train_size=cfg['train_size'],
            rnn_layers=model_cfg['rnn_layers'],
            rnn_size=model_cfg['rnn_size'],
            rnn_bidirectional=model_cfg['rnn_bidirectional'],
            max_length=model_cfg['max_length'],
            dim_embeddings=model_cfg['dim_embeddings'],
            word_level=model_cfg['word_level'])
    else:
        textgen.train_on_texts(
            texts,
            context_labels=context_labels,
            num_epochs=cfg['num_epochs'],
            gen_epochs=cfg['gen_epochs'],
            train_size=cfg['train_size'],
            batch_size=cfg['batch_size'])
# NOTE: click shows the function docstring as the command description in
# --help, so it is kept to a single user-facing sentence.
@click.command()
@click.option('--infile', '-i',
              required=True,
              help='Enter the json file storing the original tweets (e.g. tweets.json).')
@click.option('--size', '-k', type=click.INT,
              required=True,
              help='Enter the training sample size.')
@click.option('--epoch', '-e', type=click.INT,
              required=True,
              help='Enter the training epoch.')
def main(infile, size, epoch):
    """Train the general tweet model on a sample of tweets."""
    # silence tensorflow
    # tf.logging is absent in TensorFlow 2.x; fall back to the compat module
    # so the script starts on both major versions.
    tf_logging = getattr(tf, 'logging', None) or tf.compat.v1.logging
    tf_logging.set_verbosity(tf_logging.ERROR)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # training general tweets
    print('Training general tweets with sample size k = {}...'.format(size))
    start_time = time.time()
    try:
        train_model(infile, size, epoch)
    except ValueError as err:
        # Still non-fatal, but say why training stopped instead of hiding it.
        click.echo('Training aborted: {}'.format(err), err=True)
    print("Training time: {} seconds".format(time.time() - start_time))


# Invoke the click command only when the file is executed as a script.
if __name__ == '__main__':
    main()