-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
135 lines (108 loc) · 6.99 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import argparse
import torch
import torch.nn as nn
from torch.utils import data
from model import Net
from data_load import ACE2005Dataset, pad, all_triggers, all_entities, all_postags, idx2trigger, all_arguments, word_x_2d,all_words
from utils import calc_metric, find_triggers
def eval(model, iterator, fname, epoch):
    """Evaluate trigger and argument extraction on the batches in `iterator`.

    Runs the model in inference mode, collects gold and predicted triggers and
    arguments, computes P/R/F1 for trigger/argument identification and
    classification, and writes a per-token dump plus the metric summary to a
    file named ``fname + ".P..._R..._F..."``.

    Args:
        model: a (DataParallel-wrapped) network exposing
            ``module.predict_triggers_LSTM`` and ``module.predict_arguments``.
        iterator: DataLoader yielding padded ACE2005 batches (see ``pad``).
        fname: prefix for the output report file.
        epoch: current epoch number, forwarded to ``calc_metric`` for logging.

    Returns:
        tuple: (metric summary string, trigger classification F1).
    """
    model.eval()

    # Triggers and arguments are collected separately; arguments are keyed by
    # their trigger, so argument prediction only runs when triggers were found.
    words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all = [], [], [], [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            tokens_x_2d, entities_x_3d, postags_x_2d, word_lstm_x_2d, triggers_y_2d, arguments_2d, seqlens_1d, head_indexes_2d, words_2d, triggers_2d = batch

            trigger_logits, triggers_y_2d, trigger_hat_2d, argument_hidden, argument_keys = model.module.predict_triggers_LSTM(
                tokens_x_2d=tokens_x_2d, entities_x_3d=entities_x_3d,
                postags_x_2d=postags_x_2d, head_indexes_2d=head_indexes_2d,
                triggers_y_2d=triggers_y_2d, arguments_2d=arguments_2d,
                words_lstm_x_2d=word_lstm_x_2d)

            words_all.extend(words_2d)
            triggers_all.extend(triggers_2d)
            triggers_hat_all.extend(trigger_hat_2d.cpu().numpy().tolist())
            arguments_all.extend(arguments_2d)

            if len(argument_keys) > 0:
                argument_logits, arguments_y_1d, argument_hat_1d, argument_hat_2d = model.module.predict_arguments(argument_hidden, argument_keys, arguments_2d)
                arguments_hat_all.extend(argument_hat_2d)
            else:
                # No candidate triggers in this batch -> no arguments to predict;
                # keep list lengths aligned with the other *_all accumulators.
                batch_size = len(arguments_2d)
                argument_hat_2d = [{'events': {}} for _ in range(batch_size)]
                arguments_hat_all.extend(argument_hat_2d)

    triggers_true, triggers_pred, arguments_true, arguments_pred = [], [], [], []
    with open('temp', 'w') as fout:
        for i, (words, triggers, triggers_hat, arguments, arguments_hat) in enumerate(zip(words_all, triggers_all, triggers_hat_all, arguments_all, arguments_hat_all)):
            # Predictions may be padded past the sentence length; trim first.
            triggers_hat = triggers_hat[:len(words)]
            triggers_hat = [idx2trigger[hat] for hat in triggers_hat]

            # [(ith sentence, t_start, t_end, t_type_str)]
            triggers_true.extend([(i, *item) for item in find_triggers(triggers)])
            triggers_pred.extend([(i, *item) for item in find_triggers(triggers_hat)])

            # Every argument is tied to its trigger.
            # [(ith sentence, t_start, t_end, t_type_str, a_start, a_end, a_type_idx)]
            for trigger in arguments['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_true.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))

            for trigger in arguments_hat['events']:
                t_start, t_end, t_type_str = trigger
                for argument in arguments_hat['events'][trigger]:
                    a_start, a_end, a_type_idx = argument
                    arguments_pred.append((i, t_start, t_end, t_type_str, a_start, a_end, a_type_idx))

            # words[1:-1] skips the sequence boundary tokens added at encoding
            # time -- presumably [CLS]/[SEP]; TODO confirm against data_load.
            for w, t, t_h in zip(words[1:-1], triggers, triggers_hat):
                fout.write('{}\t{}\t{}\n'.format(w, t, t_h))
            fout.write('#arguments#{}\n'.format(arguments['events']))
            fout.write('#arguments_hat#{}\n'.format(arguments_hat['events']))
            fout.write("\n")

    # print(classification_report([idx2trigger[idx] for idx in y_true], [idx2trigger[idx] for idx in y_pred]))
    print('[trigger classification]')
    trigger_p, trigger_r, trigger_f1 = calc_metric(triggers_true, triggers_pred, epoch, True)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p, trigger_r, trigger_f1))

    print('[argument classification]')
    argument_p, argument_r, argument_f1 = calc_metric(arguments_true, arguments_pred, epoch, False)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p, argument_r, argument_f1))

    # Identification ignores the type label: drop t_type_str / a_type_idx
    # before recomputing the metrics.
    print('[trigger identification]')
    triggers_true = [(item[0], item[1], item[2]) for item in triggers_true]
    triggers_pred = [(item[0], item[1], item[2]) for item in triggers_pred]
    trigger_p_, trigger_r_, trigger_f1_ = calc_metric(triggers_true, triggers_pred, epoch, True)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(trigger_p_, trigger_r_, trigger_f1_))

    print('[argument identification]')
    arguments_true = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_true]
    arguments_pred = [(item[0], item[1], item[2], item[3], item[4], item[5]) for item in arguments_pred]
    argument_p_, argument_r_, argument_f1_ = calc_metric(arguments_true, arguments_pred, epoch, False)
    print('P={:.3f}\tR={:.3f}\tF1={:.3f}'.format(argument_p_, argument_r_, argument_f1_))

    metric = '[trigger classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p, trigger_r, trigger_f1)
    metric += '[argument classification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p, argument_r, argument_f1)
    metric += '[trigger identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(trigger_p_, trigger_r_, trigger_f1_)
    metric += '[argument identification]\tP={:.3f}\tR={:.3f}\tF1={:.3f}\n'.format(argument_p_, argument_r_, argument_f1_)

    final = fname + ".P%.2f_R%.2f_F%.2f" % (trigger_p, trigger_r, trigger_f1)
    with open(final, 'w') as fout:
        # Fix: the temp file was previously opened via open(...).read() and
        # never closed (handle leak); use a context manager instead.
        with open("temp", "r") as fin:
            result = fin.read()
        fout.write("{}\n".format(result))
        fout.write(metric)
    os.remove("temp")

    return metric, trigger_f1
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--logdir", type=str, default="logdir")
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--testset", type=str, default="data/test.json")
    parser.add_argument("--model_path", type=str, default="best_model_LSTM.pt")
    hp = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if not os.path.exists(hp.model_path):
        # Fail fast with a clear message: previously only a warning was
        # printed and torch.load still ran, crashing with a less helpful error.
        raise FileNotFoundError(
            'There is no model on the path: {}. Please check the model_path parameter'.format(hp.model_path))

    model = torch.load(hp.model_path)
    if device == 'cuda':
        model = model.cuda()

    test_dataset = ACE2005Dataset(hp.testset)
    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=hp.batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=pad)

    if not os.path.exists(hp.logdir):
        os.makedirs(hp.logdir)

    print("=========eval test=========")
    # Bug fix: eval() requires an `epoch` argument; the old call passed only
    # three arguments and raised TypeError. Epoch 0 is used for standalone
    # evaluation (the value is only forwarded to calc_metric for logging).
    eval(model, test_iter, 'eval_test', 0)