generate_summary.py
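"""Generate an extractive summary with an RL agent (DeepTD) whose reward is a
reference-free SBERT similarity score, then optionally evaluate the summary
against reference summaries with ROUGE."""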
import sys
sys.path.append('../')  # make the project root importable (assumes this script sits one level below it)

import os
import numpy as np

from ref_free_metrics.sbert_score_metrics import get_rewards
from summariser.ngram_vector.vector_generator import Vectoriser
from summariser.deep_td import DeepTDAgent as RLAgent
from utils.data_reader import CorpusReader
from utils.evaluator import evaluate_summary_rouge, add_result

class RLSummarizer():
    def __init__(self, reward_type='top10-sbert-f1', reward_strict=5., rl_strict=5.,
                 train_episode=5000, base_length=200, sample_summ_num=5000):
        self.reward_strict = reward_strict
        self.rl_strict = rl_strict                # strictness parameter handed to the RL agent
        self.reward_type = reward_type            # reference-free reward, e.g. 'top10-sbert-f1'
        self.train_episode = train_episode        # number of RL training episodes
        self.base_length = base_length            # base length used by the state vectoriser
        self.sample_summ_num = sample_summ_num    # how many random summaries to sample for memory replay
    def get_sample_summaries(self, docs, summ_max_len=100):
        # randomly sample candidate summaries and score each one with the
        # reference-free SBERT reward; only the prefix of reward_type before
        # the first '-' is passed on (e.g. 'top10')
        vec = Vectoriser(docs, summ_max_len)
        summary_list = vec.sample_random_summaries(self.sample_summ_num)
        rewards = get_rewards(docs, summary_list, self.reward_type.split('-')[0])
        assert len(summary_list) == len(rewards)
        return summary_list, rewards
    def summarize(self, docs, summ_max_len=100):
        # generate sample summaries and their rewards; these (summary, reward)
        # pairs serve as the memory-replay data for training the TD agent
        summaries, rewards = self.get_sample_summaries(docs, summ_max_len)
        vec = Vectoriser(docs, base=self.base_length)
        rl_agent = RLAgent(vec, summaries, strict_para=self.rl_strict, train_round=self.train_episode)
        summary = rl_agent(rewards)
        return summary

if __name__ == '__main__':
    # read source documents
    reader = CorpusReader('data/topic_1')
    source_docs = reader()

    # generate a summary with a maximum length of 100 tokens
    rl_summarizer = RLSummarizer()
    summary = rl_summarizer.summarize(source_docs, summ_max_len=100)
    print('\n=====Generated Summary=====')
    print(summary)

    # (optional) evaluate the quality of the summary using ROUGE metrics
    if os.path.isdir('./rouge/ROUGE-RELEASE-1.5.5'):
        refs = reader.readReferences()  # make sure the references are in data/topic_1/references
        avg_rouge_score = {}
        for ref in refs:
            rouge_scores = evaluate_summary_rouge(summary, ref)
            add_result(avg_rouge_score, rouge_scores)
        print('\n=====ROUGE scores against {} references====='.format(len(refs)))
        for metric in avg_rouge_score:
            # average the accumulated scores across all references
            print('{}:\t{}'.format(metric, np.mean(avg_rouge_score[metric])))
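
# To summarise a different topic, point CorpusReader at another directory with
# the same layout (e.g. 'data/topic_2', a hypothetical path). ROUGE evaluation
# additionally expects reference summaries under '<topic_dir>/references' and
# the ROUGE toolkit under './rouge/ROUGE-RELEASE-1.5.5'.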