-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_accuracy.py
92 lines (83 loc) · 3.49 KB
/
get_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
import re
import argparse
import pandas as pd
from collections import Counter
def oreilly_choice_postprocess(text: str) -> str:
    """Extract the selected option letters from a free-form model answer.

    The text is trimmed to the first "answer ... <letter>" phrase when one is
    present, then to the first standalone capital letter. The run of capital
    letters that starts there (separated by whitespace, commas or "and") is
    collected, de-duplicated and sorted.

    Args:
        text: Raw model output.

    Returns:
        The chosen letters joined with commas in alphabetical order
        (e.g. ``"A,C"``), or an empty string when no choice is found.
    """
    # The trailing space lets a letter at the very end of the text satisfy
    # the "letter followed by a non-word character" patterns below.
    s = text + ' '

    # Search the original string case-insensitively. Searching a lower-cased
    # copy and reusing its offsets on `s` is unsafe because str.lower() can
    # change the length of the string (e.g. 'İ' becomes two code points).
    matched = re.search(r'answer.+?[^\w]([a-zA-Z][^\w])', s, re.IGNORECASE)
    if matched:
        s = s[matched.start():]

    # NOTE(review): this fallback scan is applied whether or not an "answer"
    # marker was found -- confirm against the intended nesting.
    if not re.match(r'[A-Z][^\w]', s):
        matched = re.search(r'[^\w][A-Z][^\w]', s)
        if matched:
            # +1 skips the non-word character that precedes the letter.
            s = s[matched.start() + 1:]

    s = s.strip() + ' '
    matched = re.match(r'([A-Z][\s,]+(and)?[\s,]*)*[A-Z][^\w]', s, re.S)
    prefix = matched.group() if matched else ''

    # Splitting on every non-capital leaves runs of capitals; only runs of
    # length one are option letters.
    letters = {c for c in re.split('[^A-Z]', prefix) if len(c) == 1}
    return ','.join(sorted(letters))
def gen_accuracy(input_path, out_path, sc=False, sample=False, sample_path=''):
    """Print overall accuracy and optionally write per-book accuracy to CSV.

    Every regular file in ``input_path`` is parsed as a JSON object whose
    values are result records. Each record is read through
    ``record['reference']['id']``, ``record['reference']['answer']`` and
    ``record['prediction']``. The "book" of a record is the part of its id
    before the first ``'-'``.

    Args:
        input_path: Directory containing the prediction JSON files.
        out_path: CSV path for the per-book accuracies (one row, one column
            per book, rounded to 4 decimals); nothing is written when falsy.
        sc: When true, ``record['prediction']`` is iterated as a collection
            of raw outputs and the most common post-processed choice is used.
        sample: When true, only records whose id appears in the JSON file at
            ``sample_path`` are scored.
        sample_path: Path of the JSON id collection; read only if ``sample``.

    Returns:
        None. Results are printed and, if requested, written to ``out_path``.
    """
    res_list = []
    for file_name in os.listdir(input_path):
        file_path = os.path.join(input_path, file_name)
        # Sub-directories cannot be opened as JSON files; skip them.
        if not os.path.isfile(file_path):
            continue
        with open(file_path, 'r', encoding='utf-8') as f:
            res_list.extend(json.load(f).values())
    print(f"All {len(res_list)} results.")

    sample_ids = set()
    if sample:
        with open(sample_path, 'r', encoding='utf-8') as f:
            # A set keeps the per-record membership test constant-time.
            sample_ids = set(json.load(f))

    book_all, book_correct = {}, {}
    correct_num = 0
    sample_num = 0
    for res in res_list:
        question_id = res['reference']['id']
        if sample:
            if question_id not in sample_ids:
                continue
            sample_num += 1

        if sc:
            votes = Counter(
                oreilly_choice_postprocess(pred) for pred in res['prediction']
            )
            # An empty list of outputs yields no extractable choice.
            prediction = votes.most_common(1)[0][0] if votes else ''
        else:
            prediction = oreilly_choice_postprocess(res['prediction'])

        cor = 1 if prediction == res['reference']['answer'] else 0
        correct_num += cor
        book = question_id.split('-')[0]
        book_all[book] = book_all.get(book, 0) + 1
        book_correct[book] = book_correct.get(book, 0) + cor

    if sample:
        print(f"Using sample, all {sample_num} samples.")
    evaluated = sample_num if sample else len(res_list)
    # Guard against an empty input directory or a sample that matches nothing.
    all_acc = correct_num / evaluated * 100 if evaluated else 0.0
    print(f"Accuracy: {all_acc}")

    if out_path:
        row = {
            book: [round(book_correct[book] / book_all[book] * 100, 4)]
            for book in sorted(book_all)
        }
        pd.DataFrame(row).to_csv(out_path, index=False)
if __name__ == '__main__':
    def _parse_flag(value):
        """Return True for 'true' or '1' (any letter case), False otherwise."""
        return value.lower() in ['true', '1']

    cli = argparse.ArgumentParser()
    cli.add_argument('--input', help='directory of prediction files', type=str, required=True)
    cli.add_argument("--out", help="output path", type=str, default=None)
    cli.add_argument("--sc", help="use self-consistency or not", type=_parse_flag, default=False)
    cli.add_argument("--sample", help="use sample or not", type=_parse_flag, default=False)
    # NOTE: the default below is an absolute path on one specific machine;
    # pass --sample_path explicitly when running elsewhere.
    cli.add_argument(
        "--sample_path",
        help="sample ids path",
        type=str,
        default='/home/v-xll22/OpsGPT/OpenCompass-OpsQA/chinese_id.json',
    )
    options = cli.parse_args()

    gen_accuracy(
        options.input,
        options.out,
        options.sc,
        options.sample,
        options.sample_path,
    )