-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
83 lines (60 loc) · 2.38 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import torch
from flask import Flask, request
from datetime import datetime
from transformers import BartTokenizer, BartForConditionalGeneration
# Load the BART summarization model and tokenizer, preferring a local copy in
# ./model and otherwise downloading 'facebook/bart-large-cnn' and saving it to
# ./model so later runs can load it locally.
try:
    print('Loading model...')
    tokenizer = BartTokenizer.from_pretrained('./model')
    model = BartForConditionalGeneration.from_pretrained('./model')
except (OSError, ValueError):
    # from_pretrained reports a missing or incomplete local model directory as
    # an OSError (EnvironmentError). ValueError is also caught defensively in
    # case the installed transformers/huggingface_hub version surfaces
    # path/repo-id validation that way. The previous bare `except:` would also
    # have swallowed KeyboardInterrupt, SystemExit and unrelated bugs.
    print('No model found, downloading...')
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
    tokenizer.save_pretrained("./model")
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
    model.save_pretrained("./model")
    print('Done!\nLoading model...')
def save_backup(text, directory='data/summary_text_backups'):
    """Write ``text`` to a timestamped backup file and return its path.

    The file is named ``summarize_backup_<YYYY-MM-DD-HH:MM:SS>.txt`` and is
    placed in ``directory``, which is created first if it does not exist.

    Args:
        text: The string to persist.
        directory: Destination folder. Defaults to the previously hard-coded
            ``data/summary_text_backups``.

    Returns:
        The path of the file that was written.
    """
    # NOTE(review): ':' is not a legal filename character on Windows, and two
    # calls within the same second overwrite each other. The name format is
    # kept unchanged in case something downstream relies on it.
    date = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
    file_name = f"summarize_backup_{date}.txt"
    # Without this, the first backup on a fresh checkout raised
    # FileNotFoundError because the folder was never created.
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, file_name)
    # Explicit UTF-8 so summaries containing non-ASCII characters do not fail
    # on platforms whose default locale encoding is not UTF-8.
    with open(path, "w", encoding="utf-8") as fp:
        fp.write(text)
    return path
def make_summary(text, num_beams=4, max_length=100):
    """Summarize ``text`` with the module-level BART ``model`` and ``tokenizer``.

    Input longer than the tokenizer's ``model_max_length`` is split into
    consecutive token chunks. Each chunk is summarized independently, and the
    partial summaries are joined with newlines.

    Args:
        text: The document to summarize.
        num_beams: Beam width passed to ``model.generate``.
        max_length: Maximum length, in tokens, of each chunk's summary.

    Returns:
        The newline-joined chunk summaries, or ``""`` if ``text`` produces no
        tokens.
    """
    # Tokenize without special tokens so that every chunk can be wrapped in
    # its own special-token pair below. Slicing one already-wrapped sequence
    # left only the first chunk with <s> and only the last with </s>.
    token_ids = tokenizer(text, add_special_tokens=False, truncation=False)['input_ids']
    # Reserve room for the special tokens added to each chunk, so a wrapped
    # chunk never exceeds the model's maximum input length.
    chunk_size = tokenizer.model_max_length - tokenizer.num_special_tokens_to_add()

    summaries = []
    # range() stops before len(token_ids), so no empty trailing chunk is
    # produced when the length is an exact multiple of chunk_size. The old
    # `<=` loop did produce one, and generate() failed on it.
    for start in range(0, len(token_ids), chunk_size):
        chunk = tokenizer.build_inputs_with_special_tokens(token_ids[start:start + chunk_size])
        input_ids = torch.tensor([chunk])  # batch of one sequence: shape (1, len(chunk))
        summary_ids = model.generate(
            input_ids, num_beams=num_beams, max_length=max_length, early_stopping=True
        )
        summaries.append(
            tokenizer.decode(
                summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        )
    return '\n'.join(summaries)
"""FLASK APP"""
# Single Flask application object; the routes below are registered on it.
app = Flask(__name__)


@app.route("/")
def index():
    """Liveness endpoint: answer every request with a fixed plain-text line."""
    status_line = "Request (200)\n"
    return status_line
@app.route('/summarize', methods=['POST', 'GET'])
def summarize():
    """Summarize the request's ``text`` parameter(s) and back up the result.

    ``text`` is read from the query string and, for POST requests, also from
    a form-encoded body. When several ``text`` values are supplied, each is
    summarized separately, and the joined partial summaries are summarized
    again into a single result.

    Returns:
        ``{'summary': <str>}`` on success, which Flask serializes to JSON, or
        ``({'error': ...}, 400)`` when no non-blank ``text`` was provided.
    """
    # request.args.get() always returns a single str (or None), so the old
    # isinstance(query, list) branch could never run. getlist() returns every
    # repeated ?text=... value, and request.values also covers POST form
    # bodies, which this route accepted but previously ignored.
    texts = [t for t in request.values.getlist('text') if t.strip()]
    if not texts:
        # Previously make_summary(None) raised inside the tokenizer, which
        # surfaced as an HTTP 500 instead of a client error.
        return {'error': "missing required parameter 'text'"}, 400
    if len(texts) == 1:
        query = texts[0]
    else:
        query = '\n\n'.join(make_summary(t) for t in texts)
    sum_text = make_summary(query)
    save_backup(sum_text)
    return {'summary': sum_text}
if __name__ == '__main__':
    # Run Flask's built-in server with its default settings when this file
    # is executed directly as a script.
    app.run()