-
Notifications
You must be signed in to change notification settings - Fork 59
/
Copy pathevaluate_text2sql_ckpts.py
110 lines (86 loc) · 4.4 KB
/
evaluate_text2sql_ckpts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import os
import json
from text2sql import _test
def parse_option(argv = None):
    """Parse the command line options used when evaluating saved text2sql checkpoints.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads the process command line, which is what
            callers that pass no argument have always got.

    Returns:
        argparse.Namespace holding one attribute per option declared below.

    Raises:
        SystemExit: raised by argparse on unknown/malformed options, or when
            --num_return_sequences is larger than --num_beams.
    """
    parser = argparse.ArgumentParser("command line arguments for selecting the best ckpt.")

    parser.add_argument('--batch_size', type = int, default = 8,
                        help = 'input batch size.')
    parser.add_argument('--device', type = str, default = "2",
                        help = 'the id of used GPU device.')
    parser.add_argument('--seed', type = int, default = 42,
                        help = 'random seed.')
    parser.add_argument('--save_path', type = str, default = "./models/text2sql",
                        help = 'save path of fine-tuned text2sql models.')
    parser.add_argument('--eval_results_path', type = str, default = "./eval_results/text2sql",
                        help = 'the evaluation results of fine-tuned text2sql models.')
    parser.add_argument('--mode', type = str, default = "eval",
                        help='eval.')
    parser.add_argument('--dev_filepath', type = str, default = "./data/pre-processing/resdsql_test.json",
                        help = 'file path of test2sql dev set.')
    parser.add_argument('--original_dev_filepath', type = str, default = "./data/spider/dev.json",
                        help = 'file path of the original dev set (for registing evaluator).')
    parser.add_argument('--db_path', type = str, default = "./data/spider/database",
                        help = 'file path of database.')
    parser.add_argument('--tables_for_natsql', type = str, default = "NatSQL/NatSQLv1_6/tables_for_natsql.json",
                        help = 'file path of tables_for_natsql.json.')
    parser.add_argument('--num_beams', type = int, default = 8,
                        help = 'beam size in model.generate() function.')
    parser.add_argument('--num_return_sequences', type = int, default = 8,
                        help = 'the number of returned sequences in model.generate() function (num_return_sequences <= num_beams).')
    parser.add_argument("--target_type", type = str, default = "sql",
                        help = "sql or natsql.")
    parser.add_argument("--output", type = str, default = "predicted_sql.txt")

    opt = parser.parse_args(argv)

    # The help text above documents this constraint; enforce it here so a bad
    # combination is reported immediately instead of after checkpoints load.
    if opt.num_return_sequences > opt.num_beams:
        parser.error(
            "--num_return_sequences ({}) must be <= --num_beams ({}).".format(
                opt.num_return_sequences, opt.num_beams
            )
        )

    return opt
def checkpoint_step(ckpt_name):
    """Return the integer step encoded in a checkpoint directory name.

    The step is read from the second "-"-separated field of the name (for a
    name shaped like "checkpoint-1200" that is 1200). It is parsed with int()
    rather than eval(), so directory names are never executed as code.

    Args:
        ckpt_name: a directory entry name.

    Returns:
        The step as an int, or None when the name has no second field or that
        field is not an integer.
    """
    fields = ckpt_name.split("-")
    if len(fields) < 2:
        return None
    try:
        return int(fields[1])
    except ValueError:
        return None


def sort_checkpoint_names(names):
    """Return the checkpoint names in ascending step order.

    Entries for which checkpoint_step() yields None (stray files, hidden
    directories, ...) are reported on stdout and left out instead of aborting
    the whole run.

    Args:
        names: iterable of directory entry names.

    Returns:
        A new list with the recognised names, sorted by step. The sort is
        stable, so names sharing a step keep their input order.
    """
    stepped = []
    for name in names:
        step = checkpoint_step(name)
        if step is None:
            print("Skip non-checkpoint entry: {}".format(name))
            continue
        stepped.append((step, name))
    stepped.sort(key = lambda pair: pair[0])
    return [name for _, name in stepped]


def select_best_checkpoints(eval_results):
    """Pick the best checkpoints by EM, by EXEC and by EM + EXEC.

    Args:
        eval_results: non-empty sequence of dicts with numeric "EM" and
            "EXEC" entries.

    Returns:
        A tuple (best_em_idx, best_exec_idx, best_em_plus_exec_idx) of
        indices into eval_results. Best EM breaks ties on EXEC, best EXEC
        breaks ties on EM; any remaining tie goes to the earliest entry
        because max() returns the first maximal element.

    Raises:
        ValueError: if eval_results is empty.
    """
    if not eval_results:
        raise ValueError("eval_results is empty; there is no checkpoint to select.")

    indices = range(len(eval_results))
    best_em_idx = max(
        indices, key = lambda i: (eval_results[i]["EM"], eval_results[i]["EXEC"])
    )
    best_exec_idx = max(
        indices, key = lambda i: (eval_results[i]["EXEC"], eval_results[i]["EM"])
    )
    best_em_plus_exec_idx = max(
        indices, key = lambda i: eval_results[i]["EM"] + eval_results[i]["EXEC"]
    )
    return best_em_idx, best_exec_idx, best_em_plus_exec_idx


def main():
    """Evaluate every checkpoint under --save_path and report the best ones.

    For each checkpoint directory, opt.save_path is pointed at it and
    _test(opt) is called; its result is unpacked as an (EM, EXEC) pair
    (presumably exact-match and execution accuracy -- confirm against
    text2sql._test). Each result is written as JSON to
    <eval_results_path>/<ckpt_name>.txt, then all results and the best
    checkpoints are printed.
    """
    opt = parse_option()

    # Remember the root directory: opt.save_path is overwritten per checkpoint
    # below because _test() receives the checkpoint location through opt.
    save_path = opt.save_path
    ckpt_names = sort_checkpoint_names(os.listdir(save_path))
    print("ckpt_names:", ckpt_names)

    os.makedirs(opt.eval_results_path, exist_ok = True)

    eval_results = []
    for ckpt_name in ckpt_names:
        print("Start evaluating ckpt: {}".format(ckpt_name))
        opt.save_path = save_path + "/{}".format(ckpt_name)
        em, exec_score = _test(opt)

        eval_result = {"ckpt": opt.save_path, "EM": em, "EXEC": exec_score}

        # ensure_ascii = False may emit non-ASCII characters, so the file
        # encoding is fixed instead of depending on the platform default.
        result_file = opt.eval_results_path + "/{}.txt".format(ckpt_name)
        with open(result_file, "w", encoding = "utf-8") as f:
            f.write(json.dumps(eval_result, indent = 2, ensure_ascii = False))

        eval_results.append(eval_result)

    for eval_result in eval_results:
        print("ckpt name:", eval_result["ckpt"])
        print("EM:", eval_result["EM"])
        print("EXEC:", eval_result["EXEC"])
        print("-----------")

    if not eval_results:
        # Nothing was evaluated; indexing eval_results below would fail.
        print("No checkpoints found in: {}".format(save_path))
        return

    best_em_idx, best_exec_idx, best_em_plus_exec_idx = select_best_checkpoints(eval_results)

    print("Best EM ckpt:", eval_results[best_em_idx])
    print("Best EXEC ckpt:", eval_results[best_exec_idx])
    print("Best EM+EXEC ckpt:", eval_results[best_em_plus_exec_idx])


if __name__ == "__main__":
    main()