Skip to content

Commit e24eccb

Browse files
oktieperlitz
andauthored
Text2sql execution accuracy metric updates (#1604)
* revised text2sql execution accuracy metric Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * sql execution accuracy catalog card Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * get_sql processor card update Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * adding text2sql non-execution accuracy metric Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * ruff fix Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * adding sqlparse dependency for tests Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * moving text2sql metrics functions to sql_utils Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * fixing sql_utils ruff issues Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * text2sql metrics: empty string is not SQL Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * text2sql metric: check if dfs are non-empty before comparison Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * text2sql metrics: fixing is_subset for output dfs Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * sql_utils SQL API fixes Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * text2sql metric: returning gold sql exec error Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> * cleanup, change info message Signed-off-by: Yotam Perlitz <y.perlitz@ibm.com> --------- Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com> Signed-off-by: Yotam Perlitz <y.perlitz@ibm.com> Co-authored-by: Yotam Perlitz <y.perlitz@ibm.com>
1 parent b47f4fe commit e24eccb

File tree

20 files changed

+1035
-254
lines changed

20 files changed

+1035
-254
lines changed

examples/evaluate_text2sql.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,14 @@
2323

2424
print_dict(
2525
evaluated_dataset[0],
26-
keys_to_print=[
27-
"source",
28-
"prediction",
29-
"subset",
30-
],
26+
keys_to_print=["source", "prediction", "subset"],
3127
)
3228
print_dict(
3329
evaluated_dataset[0]["score"]["global"],
3430
)
3531

3632
assert (
37-
evaluated_dataset[0]["score"]["global"]["score"] >= 0.44
33+
evaluated_dataset[0]["score"]["global"]["score"] >= 0.43
3834
), "results have been degraded, something is wrong with the metric"
3935

4036
# with llama-3-3-70b-instruct
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
from unitxt.catalog import add_to_catalog
2+
from unitxt.metrics import SQLExecutionAccuracy, SQLNonExecutionAccuracy
3+
from unitxt.test_utils.metrics import test_metric
4+
5+
sql_execution_accuracy_metric = SQLExecutionAccuracy()
6+
7+
predictions = [
8+
"SELECT nme FROM employees WHERE department = 'Sales'",
9+
"SELECT name FROM employees WHERE department = 'Sales'",
10+
"SELECT name FROM employees WHERE department = 'Engineering'",
11+
"SELECT id, name FROM employees WHERE department = 'Sales'",
12+
"SELECT name FROM employees WHERE department = 'Non-Existent'",
13+
"Garbage SELECT *",
14+
] # Incorrect column name 'nme'
15+
references = [
16+
["SELECT name FROM employees WHERE department = 'Sales';"],
17+
["SELECT name FROM employees WHERE department = 'Sales';"],
18+
["SELECT name FROM employees WHERE department = 'Sales';"],
19+
["SELECT name FROM employees WHERE department = 'Sales';"],
20+
["SELECT name FROM employees WHERE department = 'Non-Existent';"],
21+
["SELECT name FROM employees WHERE department = 'Sales';"],
22+
]
23+
task_data = [
24+
{
25+
"db": {
26+
"db_id": "mock_db",
27+
"db_type": "in_memory",
28+
"data": {
29+
"employees": {
30+
"columns": ["id", "name", "department", "salary"],
31+
"rows": [
32+
(1, "Alice", "Sales", 50000),
33+
(2, "Bob", "Engineering", 60000),
34+
(3, "Charlie", "Sales", 55000),
35+
],
36+
}
37+
},
38+
}
39+
}
40+
] * 6
41+
42+
instance_targets = [
43+
{
44+
"error_message": "Error executing SQL: no such column: nme",
45+
"execution_accuracy": 0.0,
46+
"gold_df_json": "",
47+
"gold_error": 0.0,
48+
"non_empty_execution_accuracy": 0.0,
49+
"non_empty_gold_df": 0.0,
50+
"predicted_df_json": "",
51+
"predicted_error": 1.0,
52+
"score": 0.0,
53+
"score_name": "non_empty_execution_accuracy",
54+
"subset_non_empty_execution_result": 0.0,
55+
},
56+
{
57+
"error_message": "",
58+
"execution_accuracy": 1.0,
59+
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
60+
"gold_error": 1.0,
61+
"non_empty_execution_accuracy": 1.0,
62+
"non_empty_gold_df": 1.0,
63+
"predicted_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
64+
"predicted_error": 0.0,
65+
"score": 1.0,
66+
"score_name": "non_empty_execution_accuracy",
67+
"subset_non_empty_execution_result": 1.0,
68+
},
69+
{
70+
"error_message": "None",
71+
"execution_accuracy": 0.0,
72+
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
73+
"gold_error": 0.0,
74+
"non_empty_execution_accuracy": 0.0,
75+
"non_empty_gold_df": 1.0,
76+
"predicted_df_json": '{"0":{"0":"Bob"}}',
77+
"predicted_error": 0.0,
78+
"score": 0.0,
79+
"score_name": "non_empty_execution_accuracy",
80+
"subset_non_empty_execution_result": 0.0,
81+
},
82+
{
83+
"error_message": "None",
84+
"execution_accuracy": 0.0,
85+
"gold_df_json": '{"0":{"0":"Alice","1":"Charlie"}}',
86+
"gold_error": 0.0,
87+
"non_empty_execution_accuracy": 0.0,
88+
"non_empty_gold_df": 1.0,
89+
"predicted_df_json": '{"0":{"0":1,"1":3},"1":{"0":"Alice","1":"Charlie"}}',
90+
"predicted_error": 0.0,
91+
"score": 0.0,
92+
"score_name": "non_empty_execution_accuracy",
93+
"subset_non_empty_execution_result": 1.0,
94+
},
95+
{
96+
"error_message": "",
97+
"execution_accuracy": 1.0,
98+
"gold_df_json": "{}",
99+
"gold_error": 1.0,
100+
"non_empty_execution_accuracy": 0.0,
101+
"non_empty_gold_df": 0.0,
102+
"predicted_df_json": "{}",
103+
"predicted_error": 0.0,
104+
"score": 0.0,
105+
"score_name": "non_empty_execution_accuracy",
106+
"subset_non_empty_execution_result": 0.0,
107+
},
108+
{
109+
"error_message": "Error executing SQL: no tables specified",
110+
"execution_accuracy": 0.0,
111+
"gold_df_json": "",
112+
"gold_error": 0.0,
113+
"non_empty_execution_accuracy": 0.0,
114+
"non_empty_gold_df": 0.0,
115+
"predicted_df_json": "",
116+
"predicted_error": 1.0,
117+
"score": 0.0,
118+
"score_name": "non_empty_execution_accuracy",
119+
"subset_non_empty_execution_result": 0.0,
120+
},
121+
]
122+
123+
124+
global_target = {
125+
"execution_accuracy": 0.33,
126+
"execution_accuracy_ci_high": 0.83,
127+
"execution_accuracy_ci_low": 0.0,
128+
"gold_error": 0.33,
129+
"gold_sql_runtime_ci_high": 0.0,
130+
"gold_sql_runtime_ci_low": 0.0,
131+
"non_empty_execution_accuracy": 0.17,
132+
"non_empty_execution_accuracy_ci_high": 0.67,
133+
"non_empty_execution_accuracy_ci_low": 0.0,
134+
"non_empty_gold_df": 0.5,
135+
"num_of_instances": 6,
136+
"predicted_error": 0.33,
137+
"predicted_sql_runtime_ci_high": 0.0,
138+
"predicted_sql_runtime_ci_low": 0.0,
139+
"score": 0.17,
140+
"score_ci_high": 0.67,
141+
"score_ci_low": 0.0,
142+
"score_name": "non_empty_execution_accuracy",
143+
"subset_non_empty_execution_result": 0.33,
144+
"subset_non_empty_execution_result_ci_high": 0.83,
145+
"subset_non_empty_execution_result_ci_low": 0.0,
146+
}
147+
148+
outputs = test_metric(
149+
metric=sql_execution_accuracy_metric,
150+
predictions=predictions,
151+
references=references,
152+
instance_targets=instance_targets,
153+
global_target=global_target,
154+
task_data=task_data,
155+
score_keys_to_ignore=[
156+
"predicted_sql_runtime",
157+
"gold_sql_runtime",
158+
"pred_to_gold_runtime_ratio",
159+
],
160+
)
161+
162+
add_to_catalog(
163+
sql_execution_accuracy_metric, "metrics.text2sql.execution_accuracy", overwrite=True
164+
)
165+
166+
sql_non_execution_accuracy_metric = SQLNonExecutionAccuracy()
167+
168+
instance_targets = [
169+
{
170+
"score": 0.0,
171+
"score_name": "sqlglot_equivalence",
172+
"sql_exact_match": 0.0,
173+
"sqlglot_equivalence": 0.0,
174+
"sqlglot_optimized_equivalence": 0.0,
175+
"sqlglot_validity": 1.0,
176+
"sqlparse_equivalence": 0.0,
177+
"sqlparse_validity": 1.0,
178+
},
179+
{
180+
"score": 1.0,
181+
"score_name": "sqlglot_equivalence",
182+
"sql_exact_match": 1.0,
183+
"sqlglot_equivalence": 1.0,
184+
"sqlglot_optimized_equivalence": 1.0,
185+
"sqlglot_validity": 1.0,
186+
"sqlparse_equivalence": 0.0,
187+
"sqlparse_validity": 1.0,
188+
},
189+
{
190+
"score": 0.0,
191+
"score_name": "sqlglot_equivalence",
192+
"sql_exact_match": 0.0,
193+
"sqlglot_equivalence": 0.0,
194+
"sqlglot_optimized_equivalence": 0.0,
195+
"sqlglot_validity": 1.0,
196+
"sqlparse_equivalence": 0.0,
197+
"sqlparse_validity": 1.0,
198+
},
199+
{
200+
"score": 0.0,
201+
"score_name": "sqlglot_equivalence",
202+
"sql_exact_match": 0.0,
203+
"sqlglot_equivalence": 0.0,
204+
"sqlglot_optimized_equivalence": 0.0,
205+
"sqlglot_validity": 1.0,
206+
"sqlparse_equivalence": 0.0,
207+
"sqlparse_validity": 1.0,
208+
},
209+
{
210+
"score": 1.0,
211+
"score_name": "sqlglot_equivalence",
212+
"sql_exact_match": 1.0,
213+
"sqlglot_equivalence": 1.0,
214+
"sqlglot_optimized_equivalence": 1.0,
215+
"sqlglot_validity": 1.0,
216+
"sqlparse_equivalence": 0.0,
217+
"sqlparse_validity": 1.0,
218+
},
219+
{
220+
"score": 0.0,
221+
"score_name": "sqlglot_equivalence",
222+
"sql_exact_match": 0.0,
223+
"sqlglot_equivalence": 0.0,
224+
"sqlglot_optimized_equivalence": 0.0,
225+
"sqlglot_validity": 1.0,
226+
"sqlparse_equivalence": 0.0,
227+
"sqlparse_validity": 1.0,
228+
},
229+
]
230+
231+
232+
global_target = {
233+
"num_of_instances": 6,
234+
"score": 0.33,
235+
"score_ci_high": 0.83,
236+
"score_ci_low": 0.0,
237+
"score_name": "sqlglot_equivalence",
238+
"sql_exact_match": 0.33,
239+
"sql_exact_match_ci_high": 0.83,
240+
"sql_exact_match_ci_low": 0.0,
241+
"sqlglot_equivalence": 0.33,
242+
"sqlglot_equivalence_ci_high": 0.83,
243+
"sqlglot_equivalence_ci_low": 0.0,
244+
"sqlglot_optimized_equivalence": 0.33,
245+
"sqlglot_optimized_equivalence_ci_high": 0.83,
246+
"sqlglot_optimized_equivalence_ci_low": 0.0,
247+
"sqlglot_validity": 1.0,
248+
"sqlparse_equivalence": 0.0,
249+
"sqlparse_validity": 1.0,
250+
}
251+
252+
outputs = test_metric(
253+
metric=sql_non_execution_accuracy_metric,
254+
predictions=predictions,
255+
references=references,
256+
instance_targets=instance_targets,
257+
global_target=global_target,
258+
task_data=task_data,
259+
)
260+
261+
add_to_catalog(
262+
sql_non_execution_accuracy_metric,
263+
"metrics.text2sql.non_execution_accuracy",
264+
overwrite=True,
265+
)

prepare/metrics/text2sql_execution_accuracy.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

prepare/processors/text2sql.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from unitxt import add_to_catalog
22
from unitxt.operator import SequentialOperator
3-
from unitxt.processors import GetSQL
3+
from unitxt.processors import AddPrefix, GetSQL
44

55
add_to_catalog(
66
SequentialOperator(
77
steps=[
8+
AddPrefix(field="prediction", prefix="SELECT "),
89
GetSQL(field="prediction"),
910
]
1011
),

prepare/tasks/text2sql.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
},
1515
reference_fields={"query": str},
1616
prediction_type=str,
17-
metrics=["metrics.text2sql.execution_accuracy", "metrics.anls"],
17+
metrics=[
18+
"metrics.text2sql.execution_accuracy",
19+
"metrics.text2sql.non_execution_accuracy",
20+
"metrics.anls",
21+
],
1822
),
1923
"tasks.text2sql",
2024
overwrite=True,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ tests = [
108108
"func_timeout==4.3.5",
109109
"Wikipedia-API",
110110
"sqlglot",
111+
"sqlparse",
111112
]
112113
ui = [
113114
"gradio",
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"__type__": "execution_accuracy"
2+
"__type__": "sql_execution_accuracy"
33
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"__type__": "sql_non_execution_accuracy"
3+
}

0 commit comments

Comments
 (0)