-
Notifications
You must be signed in to change notification settings - Fork 6
/
evaluate.py
114 lines (92 loc) · 3.97 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
__author__ = 'Daan van Stigt'
import os
import subprocess
import re
import pickle
from tqdm import tqdm
from parser import DependencyParser
from features import get_feature_opts
from utils import UD_LANG, UD_SPLIT
def predict(model, lines):
    """Parse every sentence in `lines` and render the predictions as text.

    For each sentence, `model.parse` is called and the first of its three
    return values is indexed by token position to overwrite `token.head` on
    every token after the first one (the tokens are mutated in place).

    Returns:
        A list with one string per sentence: the `str()` of each token after
        the first, joined by newlines. Every sentence except the last one
        ends with an extra newline, so joining the list with newlines
        separates consecutive sentences by an empty line.
    """
    rendered = []
    for sent_idx, sentence in enumerate(tqdm(lines)):
        has_successor = sent_idx != len(lines) - 1
        heads, _, _ = model.parse(sentence)
        rows = []
        # Position 0 holds Root, which is discarded from the output.
        for position, token in enumerate(sentence[1:], 1):
            token.head = heads[position]
            rows.append(str(token))
        if has_successor:
            rows.append('')  # Empty line between sentences.
        rendered.append('\n'.join(rows))
    return rendered
def call_conllx_eval_script(gold_path, pred_path):
    """Run the perl CoNLL-X evaluation script and return its three scores.

    Invokes `perl scripts/eval.pl -g <gold_path> -s <pred_path> -q` (relative
    to the current working directory) and extracts every percentage written
    as `<digits>.<two digits> %` from the script's UTF-8 decoded output.

    Args:
        gold_path: Path of the gold-standard file passed via `-g`.
        pred_path: Path of the predicted file passed via `-s`.

    Returns:
        A `(las, uas, lab_acc)` tuple of floats, taken in the order in which
        the percentages appear in the script output.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
        ValueError: If the output does not contain exactly three percentages.
    """
    output = subprocess.check_output(
        ['perl', 'scripts/eval.pl', '-g', gold_path, '-s', pred_path, '-q']
    ).decode('utf-8')
    # Raw string: `\.` is not a valid escape in a normal string literal.
    # `[0-9]+` (instead of at most two digits) makes a perfect `100.00 %`
    # match in full rather than being truncated to `00.00`.
    scores = re.findall(r'([0-9]+\.[0-9][0-9]) %', output)
    if len(scores) != 3:
        raise ValueError(
            f'expected 3 scores in the output of scripts/eval.pl, '
            f'found {len(scores)}: {scores}')
    las, uas, lab_acc = (float(score) for score in scores)
    return las, uas, lab_acc
def call_conllu_eval_script(gold_path, pred_path):
    """Run `scripts/conll18_ud_eval.py` and return what it prints.

    The script is executed directly (relative to the current working
    directory) with the `-v` flag followed by the gold and predicted paths.

    Returns:
        The script's standard output decoded as UTF-8.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    command = ['scripts/conll18_ud_eval.py', '-v', gold_path, pred_path]
    raw_output = subprocess.check_output(command)
    return raw_output.decode('utf-8')
def _format_conllx_scores(las, uas, lab_acc):
    """Render the three CoNLL-X scores as aligned `name value` lines."""
    named_scores = (('LAS', las), ('UAS', uas), ('Lab-acc', lab_acc))
    return '\n'.join(f'{name:<10} {score:3.2f}' for name, score in named_scores)


def _write_result_file(result_path, gold_path, pred_path, results):
    """Write the gold/predicted paths followed by `results` to `result_path`."""
    with open(result_path, 'w') as f:
        print(f'Gold file: {gold_path}', file=f)
        print(f'Predicted file: {pred_path}', file=f)
        print(file=f)
        print(results, file=f)


def evaluate(args):
    """Parse the dev and test sets with a saved model and report the scores.

    Steps:
      1. Load the data via `main.get_data(args)` and the model from
         `args.model`.
      2. Parse the development and test sets with `predict`.
      3. Write the predictions to `dev.pred.<ext>` / `test.pred.<ext>` inside
         `args.out`, where `<ext>` is `conll` if `args.use_ptb` else `conllu`.
      4. Score them with `call_conllx_eval_script` (PTB) or
         `call_conllu_eval_script` (otherwise) against the gold files found
         under `args.data`.
      5. Print the results and write them to `dev.<ext>.result` /
         `test.<ext>.result` inside `args.out`.

    Args:
        args: Object providing the attributes `data`, `model`, `features`,
            `decoder`, `use_ptb`, `out` and, when `use_ptb` is false, `lang`.
    """
    from main import get_data  # TODO: cannot import at top... Some circular dependency?

    print(f'Loading data from `{args.data}`...')
    _, dev_dataset, test_dataset = get_data(args)

    print(f'Loading model from `{args.model}`...')
    feature_opts = get_feature_opts(args.features)
    model = DependencyParser(feature_opts, args.decoder)
    model.load(args.model)

    print('Parsing development set...')
    dev_pred = predict(model, dev_dataset.tokens)
    print('Parsing test set...')
    test_pred = predict(model, test_dataset.tokens)

    ext = 'conll' if args.use_ptb else 'conllu'
    print(f'Writing out predictions in {ext} format to `{args.out}`...')
    dev_pred_path = os.path.join(args.out, f'dev.pred.{ext}')
    test_pred_path = os.path.join(args.out, f'test.pred.{ext}')
    with open(dev_pred_path, 'w') as f:
        print('\n'.join(dev_pred), file=f)
    with open(test_pred_path, 'w') as f:
        print('\n'.join(test_pred), file=f)

    print('Evaluating results...')
    data_dir = os.path.expanduser(args.data)
    if args.use_ptb:
        dev_gold_path = os.path.join(data_dir, 'dev.conll')
        test_gold_path = os.path.join(data_dir, 'test.conll')
        dev_results = _format_conllx_scores(
            *call_conllx_eval_script(dev_gold_path, dev_pred_path))
        test_results = _format_conllx_scores(
            *call_conllx_eval_script(test_gold_path, test_pred_path))
    else:
        data_path = os.path.join(data_dir, UD_LANG[args.lang])
        dev_gold_path = data_path + UD_SPLIT['dev']
        test_gold_path = data_path + UD_SPLIT['test']
        dev_results = call_conllu_eval_script(dev_gold_path, dev_pred_path)
        test_results = call_conllu_eval_script(test_gold_path, test_pred_path)

    # Print results to terminal.
    print('Development results:')
    print(dev_results)
    print()
    print('Test results:')
    print(test_results)  # Previously printed the development results again.
    print()

    # Print results to file. Each split gets its own file listing its own
    # gold and predicted paths; the test report used to overwrite the
    # development one and repeat the development paths.
    dev_result_path = os.path.join(args.out, f'dev.{ext}.result')
    _write_result_file(dev_result_path, dev_gold_path, dev_pred_path, dev_results)
    test_result_path = os.path.join(args.out, f'test.{ext}.result')
    _write_result_file(test_result_path, test_gold_path, test_pred_path, test_results)