Stephen yin01/spider analysis #11

Open · wants to merge 49 commits into base: main
49 commits
3e0e32a
Update sql2pandas with better error checking + cleanup
troyfeng116 Mar 3, 2022
89e8511
Sandbox scripts + testing
troyfeng116 Mar 9, 2022
39fb73e
Create json file with converted SQL queries
troyfeng116 Mar 15, 2022
60cdba2
Initial preprocessing: cleanup + nested SELECT framework
troyfeng116 Mar 15, 2022
26b46ac
Support for complex non-convertible queries: ProcessedSQLQueryNode fr…
troyfeng116 Mar 20, 2022
383078f
Cleanup
troyfeng116 Mar 20, 2022
d3cc6e3
Refine ProcessedSQLQueryTree structure
troyfeng116 Mar 20, 2022
5057c7e
gitignore
troyfeng116 Mar 20, 2022
7b7b8b7
Merge main (rebase)
troyfeng116 Mar 20, 2022
5b18a95
Fix l_to_r_keys
troyfeng116 Mar 20, 2022
cb1cc6f
Cleanup + internal/external symbol keys
troyfeng116 Mar 20, 2022
cdab109
Docstrings + cleanup
troyfeng116 Mar 20, 2022
e0fd20f
Update TODO
troyfeng116 Mar 20, 2022
6ea91e4
Remove extra commits for pre-untracked files
troyfeng116 Mar 20, 2022
895263a
Clean up gitignore conflicts
troyfeng116 Mar 20, 2022
e184279
Rm ast_sandbox from cache
troyfeng116 Mar 21, 2022
6e1ea8f
Rm sandbox_bad_queries.txt from git cache
troyfeng116 Mar 21, 2022
c54c725
Expand processed tree functionality to UNION/INTERSECT/EXCEPT
troyfeng116 Mar 21, 2022
cd40d80
Add output txt test file to gitignore
troyfeng116 Mar 21, 2022
473d315
Naive function to extract entire table from SELECT query
troyfeng116 Mar 21, 2022
82b9dcc
Clean up table extraction + separate table alias parsing
troyfeng116 Mar 21, 2022
4acfd59
Add table alias + substitute symbol to leaf nodes
troyfeng116 Mar 21, 2022
e607727
DFS to extract code snippets
troyfeng116 Mar 21, 2022
daba852
Cleanup some naming
troyfeng116 Mar 21, 2022
3ff0c81
Cleanup + indexing nit
troyfeng116 Mar 21, 2022
9e5a4ba
Add script to assert validity of SQL tree
troyfeng116 Mar 22, 2022
80cbfb4
Update gitignore for output files
troyfeng116 Mar 22, 2022
3ddc969
Clean up file structure + add table_expr class + debug symbol generation
troyfeng116 Mar 22, 2022
83d2c40
Docstrings + cleanup DFS
troyfeng116 Mar 22, 2022
17328bc
Refine pandas generation: initial setup to handle INTERSECT/UNION/EXCEPT
troyfeng116 Mar 22, 2022
af7570c
Move single node to pandas helpers to separate file, clean up node init
troyfeng116 Mar 22, 2022
9146a03
Docstrings + handle multiple JOIN ONs in table expression
troyfeng116 Mar 22, 2022
def0726
Update TODO
troyfeng116 Mar 22, 2022
b988ad6
Table alias removal + aliased_table_expr field
troyfeng116 Apr 1, 2022
624015f
Cleanup
troyfeng116 Apr 1, 2022
823340f
Gitignore
troyfeng116 Apr 1, 2022
51798d8
JOIN tables -> pandas
troyfeng116 Apr 2, 2022
3f79033
Cleanup + docstrings
troyfeng116 Apr 2, 2022
cdb359d
Add quick dirty fixes for spider execution
chenyx512 Apr 4, 2022
394c115
try to fix the intersect and union
niansong1996 Apr 6, 2022
f4739ad
Improve to 75% correct
chenyx512 Apr 11, 2022
98b5af6
unofficial files, playground progress for squall dataset
StephenYin01 Apr 11, 2022
4a295dd
Merge branch 'main' into yuxuan/sql2pandas
niansong1996 Apr 15, 2022
87c3760
Merge branch 'main' of github.com:Yale-LILY/NLP4Code into StephenYin0…
StephenYin01 Apr 18, 2022
3496afb
Merge branch 'yuxuan/sql2pandas' of github.com:Yale-LILY/NLP4Code int…
StephenYin01 Apr 18, 2022
b1d4cea
put WIP on remote github
StephenYin01 Jul 7, 2022
8d67d34
finished ten correction of failed conversions for codex
StephenYin01 Jul 11, 2022
2bcb5da
Merge branch 'main' of github.com:Yale-LILY/NLP4Code into StephenYin0…
StephenYin01 Jul 11, 2022
7ff994e
annotated 10 more examples
StephenYin01 Aug 3, 2022
7 changes: 7 additions & 0 deletions .gitignore
@@ -128,9 +128,16 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
+# sandboxes
+**sandbox.py
+parsing/sandbox*
+
 # defined by Ansong
 .venv
 data/
+debug-tmp/
+wandb/
+results/
 
 # defined by Troy
 .DS_Store
94 changes: 74 additions & 20 deletions execution/spider_execution.py
@@ -1,6 +1,9 @@
 import sqlite3
 import pandas as pd
 import numpy as np
+import re
+import keyword
+import math
 
 from typing import List, Dict, Any, Union, Tuple
 
@@ -29,8 +32,11 @@ def spider_execution_sql(sql: str, conn: sqlite3.Connection, return_error_msg: b
 def db_to_df_dict(conn: sqlite3.Connection) -> Dict[str, pd.DataFrame]:
     df_dict = {}
     for table_name in conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall():
-        df_dict[table_name[0]] = pd.read_sql_query(f"SELECT * FROM {table_name[0]}", conn)
-        df_dict[table_name[0]].rename(columns=lambda x: x.lower(), inplace=True)
+        # lowercase everything, including string values and column labels
+        df = pd.read_sql_query(f"SELECT * FROM {table_name[0]}", conn)
+        df = df.applymap(lambda s: s.lower() if type(s) == str else s)
+        df_dict[table_name[0].lower()] = df
+        df_dict[table_name[0].lower()].rename(columns=lambda x: x.lower(), inplace=True)
     return df_dict
 
 def spider_execution_py(code: str, df_dict: Dict[str, pd.DataFrame], return_error_msg: bool = False) -> Any:
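To make the new normalization concrete, here is a minimal usage sketch, assuming db_to_df_dict above is in scope (the in-memory table is invented for illustration):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Singer (Name TEXT, Age INTEGER)")
conn.execute("INSERT INTO Singer VALUES ('Alice', 30), ('BOB', 25)")

df_dict = db_to_df_dict(conn)
print(df_dict["singer"])  # the table key is lowercased too
# prints roughly:
#     name  age
# 0  alice   30
# 1    bob   25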
@@ -39,8 +45,22 @@ def spider_execution_py(code: str, df_dict: Dict[str, pd.DataFrame], return_erro
     # use the tables as part of the code context
     table_vars_code = "import pandas as pd\n"
     for table_name in df_dict.keys():
-        table_vars_code += f"# {' '.join(list(df_dict[table_name].columns))}\n{table_name} = df_dict['{table_name}']\n"
-    code = table_vars_code + "\n" + code
+        # table names may be reserved words like "class"
+        if table_name in keyword.kwlist:
+            table_vars_code += f"_{table_name} = df_dict['{table_name}']\n"
+            # but we have to make sure that table columns are not changed
+            # code = code.replace(table_name, f"_{table_name}")
+            code = re.sub("((?<!_)class(?!_))", "_class", code)
+        else:
+            table_vars_code += f"{table_name} = df_dict['{table_name}']\n"
+
+    # lowercase everything inside single quotes
+    code = re.sub(r"'(.*?)'", lambda p: f"'{p.group(1).lower()}'", code)
+    # move column selection after sorting or drop_duplicates
+    # TODO further processing needed, case 784, 1721,
+    # and select, drop_duplicate, followed by sorting
+    code = re.sub(r"(.*(?<!\[))(\[\[?.*?\]?\])(\.sort_values.*)", r"\1\3\2", code)
+    code = table_vars_code + "\n" + f"answer = {code}"
 
     # execute the code
     try:
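The two re.sub rewrites above are easiest to see on a concrete string. A standalone sketch (the input snippet is invented for illustration):

import re

code = "df[['Name']].sort_values(by='Age')"
# lowercase quoted literals: 'Name' -> 'name', 'Age' -> 'age'
code = re.sub(r"'(.*?)'", lambda p: f"'{p.group(1).lower()}'", code)
# move the [[...]] column selection after .sort_values(...), so sorting
# still sees columns that the projection would otherwise drop
code = re.sub(r"(.*(?<!\[))(\[\[?.*?\]?\])(\.sort_values.*)", r"\1\3\2", code)
print(code)  # df.sort_values(by='age')[['name']]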
@@ -58,43 +78,77 @@ def spider_execution_py(code: str, df_dict: Dict[str, pd.DataFrame], return_erro
     else:
         return None
 
-def flatten_list_of_list(l: List[List[Any]]) -> List[Any]:
+def flatten_list_of_list(l: List[List[Any]], sort: bool = False) -> List[Any]:
     result = []
     for sublist in l:
         if isinstance(sublist, list) or isinstance(sublist, tuple):
             result.extend(sublist)
         else:
             result.append(sublist)
 
-    return result
+    if sort:
+        result.sort(key = str)
+        return result
+    else:
+        return result
 
-def spider_answer_eq(prediction: Union[pd.DataFrame, pd.Series, List[Tuple[Any]]],
-                     gold_answer: Union[List[Tuple[Any]], int]) -> bool:
+def list_to_lower_case(l: List[Any]):
+    result = []
+    for object in l:
+        if isinstance(object, str):
+            result.append(object.lower())
+        else:
+            result.append(object)
+    return result
 
-    if isinstance(prediction, int) or isinstance(prediction, float):
+def compare_lists(l1: List[Any], l2: List[Any]) -> bool:
+    if len(l1) != len(l2):
+        return False
+    else:
+        for i in range(len(l1)):
+            if type(l1[i]) == float:
+                if not math.isclose(l1[i], l2[i]):
+                    return False
+                else:
+                    continue
+            elif l1[i] != l2[i]:
+                return False
+        return True
+
+def spider_answer_eq(prediction: Union[pd.DataFrame, pd.Series, List[Tuple[Any]]],
+                     gold_answer: Union[List[Tuple[Any]], int],
+                     sort: bool = False) -> bool:
+
+    if isinstance(prediction, (int, float)) or (
+            not isinstance(prediction, (list, pd.DataFrame, np.ndarray, tuple))
+            and np.issubdtype(prediction, np.integer)):
         prediction = [prediction]
 
     if isinstance(prediction, list) or isinstance(prediction, np.ndarray):
         if isinstance(gold_answer, list):
-            gold_flattened = flatten_list_of_list(gold_answer)
-            pred_flattened = flatten_list_of_list(prediction)
-            result = pred_flattened == gold_flattened
+            gold_flattened = list_to_lower_case(
+                flatten_list_of_list(gold_answer, sort))
+            pred_flattened = flatten_list_of_list(prediction, sort)
+            result = compare_lists(pred_flattened, gold_flattened)
         else:
             result = False
     elif isinstance(prediction, pd.DataFrame):
         if isinstance(gold_answer, list):
-            # convert the dataframe to a list of tuples and check
-            pred_list = flatten_list_of_list(list(prediction.itertuples(index=False, name=None)))
-            gold_list = flatten_list_of_list(gold_answer)
-            result = pred_list == gold_list
+            # we include the index only when it exists
+            pred_list = flatten_list_of_list(list(prediction.itertuples(
+                index=bool(prediction.index.name), name=None)), sort)
+            gold_list = list_to_lower_case(flatten_list_of_list(gold_answer, sort))
+            result = compare_lists(pred_list, gold_list)
         else:
             result = False
     elif isinstance(prediction, pd.Series):
         if isinstance(gold_answer, list):
-            # convert the series to a list of tuples and check
-            pred_list = flatten_list_of_list(prediction.tolist())
-            gold_list = flatten_list_of_list(gold_answer)
-            result = pred_list == gold_list
+            # we include the index only when it exists
+            if prediction.index.name:
+                pred_list = flatten_list_of_list(list(prediction.items()), sort)
+            else:
+                pred_list = flatten_list_of_list(prediction.tolist(), sort)
+            gold_list = list_to_lower_case(flatten_list_of_list(gold_answer, sort))
+            result = compare_lists(pred_list, gold_list)
         else:
             result = False
     else:
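Taken together, the new helpers make the answer check case-insensitive, optionally order-insensitive (sort=True), and tolerant of float rounding. A minimal sketch of the intended semantics, assuming spider_answer_eq and its helpers above are in scope (the data is invented):

import pandas as pd

gold = [("Alice", 1.0)]
pred = pd.DataFrame({"name": ["alice"], "rank": [1.0 + 1e-12]})

# gold strings are lowercased, floats are compared with math.isclose,
# and both sides are flattened before comparison
print(spider_answer_eq(pred, gold))        # True
print(spider_answer_eq(pred, [("Bob",)]))  # False: lengths differ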
36 changes: 36 additions & 0 deletions parsing/clean_query.py
@@ -0,0 +1,36 @@
import re
from helpers import trim_front_and_back


# TODO:
# - more robust parenthesis handling
# - more robust extra space removal?


# sql2pandas requires single quotes in SQL queries
def replace_quotes(sql_query):
    return sql_query.replace("\"", "\'")


# Remove extra spaces
def remove_consecutive_spaces(sql_query):
    sql_query = sql_query.strip()
    sql_query = re.sub(r"\s+", " ", sql_query)
    sql_query = re.sub(r"\( ", "(", sql_query)
    return sql_query


# Add semicolon at end of SQL query for consistency
def add_semicolon(sql_query):
    return sql_query if sql_query[-1:] == ";" else sql_query + ";"


# Basic string preprocessing/cleanup for SQL queries
def basic_clean_query(sql_query):
    sql_query = replace_quotes(sql_query)
    sql_query = remove_consecutive_spaces(sql_query)
    # TODO: ensure balance for front/back parentheses
    sql_query = trim_front_and_back(sql_query, "(", ")")

    sql_query = add_semicolon(sql_query)
    return sql_query
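An illustrative run of the full pipeline (input invented): double quotes, extra spaces, outer parentheses, and the trailing semicolon are all normalized.

print(basic_clean_query('(SELECT  name FROM  "singer"  WHERE age > 20)'))
# SELECT name FROM 'singer' WHERE age > 20;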
Empty file added parsing/data.json
Empty file.
190 changes: 190 additions & 0 deletions parsing/helpers.py
@@ -0,0 +1,190 @@
from typing import Tuple


# Removes any characters in `chars_to_remove` from the front of `s`
def trim_front(s, chars_to_remove):
    while s[0] in chars_to_remove:
        s = s[1:]
    return s


# Removes any characters in `chars_to_remove` from the back of `s`
def trim_back(s, chars_to_remove):
    while s[-1:] in chars_to_remove:
        s = s[:-1]
    return s


# Removes characters like parentheses from front/end of `s`
def trim_front_and_back(s, char_front, char_back):
    while s[0] == char_front and s[-1:] == char_back:
        s = s[1:-1]
    return s


# Find corresponding balanced closing parenthesis for opening parenthesis at index `open_idx-1`
def find_closing_parenthesis(s, open_idx):
    if s[open_idx-1] != '(':
        print('[find_closing_parenthesis] input open_idx error')
        return -1

    idx = open_idx
    ct_open = 0
    while idx < len(s):
        if s[idx] == '(':
            ct_open += 1
        elif s[idx] == ')':
            if ct_open == 0:
                return idx
            ct_open -= 1

        idx += 1

    return -1


# Determines if first non-whitespace char in `partial_sql_query` is "SELECT"
def is_next_token_select(partial_sql_query):
    return partial_sql_query.strip().find("SELECT") == 0


def is_idx_at_token_start(sql_query: str, idx: int):
    if idx >= len(sql_query):
        return False

    if sql_query[idx] == " ":
        print("[is_idx_at_token_start] idx not in word")
        return False

    if idx > 0 and not sql_query[idx-1] == " ":
        print("[is_idx_at_token_start] idx not at start of token")
        return False

    return True


def get_next_token_idx(sql_query: str, idx: int):
    while idx < len(sql_query) and sql_query[idx] != " ":
        idx += 1

    while idx < len(sql_query) and sql_query[idx] == " ":
        idx += 1

    return idx


def get_prev_token(sql_query: str, idx: int):
    if idx == 0:
        print("[get_prev_token] no prev token")
        return None

    if not is_idx_at_token_start(sql_query, idx):
        return None

    finish_idx = idx - 1
    while finish_idx - 1 >= 0 and sql_query[finish_idx-1] == " ":
        finish_idx -= 1

    start_idx = finish_idx - 1
    while start_idx - 1 >= 0 and sql_query[start_idx - 1] != " ":
        start_idx -= 1

    return sql_query[start_idx:finish_idx]

def get_second_last_token(sql_query: str):
    length = len(sql_query)
    if length < 2:
        print("[get_second_last_token] no second last token")
        return None

    finish_idx = length - 1
    while finish_idx > 0 and sql_query[finish_idx] != " ":
        finish_idx -= 1

    start_idx = finish_idx - 1
    while start_idx > 0 and sql_query[start_idx] != " ":
        start_idx -= 1

    return sql_query[start_idx:finish_idx].strip()

def get_cur_token(sql_query: str, idx: int):
    if not is_idx_at_token_start(sql_query, idx):
        return None

    finish_idx = idx
    while finish_idx < len(sql_query) and sql_query[finish_idx] != " ":
        finish_idx += 1

    return sql_query[idx:finish_idx]


def get_next_token(sql_query: str, idx: int):
    if idx >= len(sql_query) - 1:
        print("[get_next_token] no next token")
        return None

    if not is_idx_at_token_start(sql_query, idx):
        return None

    start_idx = get_next_token_idx(sql_query, idx)
    return get_cur_token(sql_query, start_idx)


def remove_prev_token(s: str, idx: int) -> Tuple[str, int]:
    """Removes previous token from idx, where idx is at start of token.

    Args:
        s (str): String from which to remove previous token
        idx (int): Index of start of token, where previous token from idx is removed.

    Returns:
        Tuple[str, int]: Redacted string, and new position of idx
    """
    if idx == 0:
        print("[remove_prev_token] no prev token")
        return None

    if not is_idx_at_token_start(s, idx):
        return None

    finish_idx = idx - 1
    while finish_idx - 1 >= 0 and s[finish_idx-1] == " ":
        finish_idx -= 1

    start_idx = finish_idx - 1
    while start_idx - 1 >= 0 and s[start_idx - 1] != " ":
        start_idx -= 1

    return s[:start_idx] + s[finish_idx:], idx - (finish_idx - start_idx)


def extract_table_column(join_on_col: str) -> str:
    """For a table column of the form TABLE.COLUMN (as in JOIN), extract COLUMN.

    Args:
        join_on_col (str): Full name of column, potentially with table specified.

    Returns:
        str: Extracted column (without specified table, if specified).
    """
    dot_idx = join_on_col.find(".")
    return join_on_col if dot_idx < 0 else join_on_col[dot_idx+1:]


def get_first_token(s: str) -> str:
    idx = s.find(" ")
    if idx < 0:
        idx = len(s)
    return s[:idx]

def subtract_sql_to_pandas(sql: str, simple: bool) -> str:
    """For a subtraction query: if `simple`, strip SELECT, parentheses, and ';'
    from the SQL; otherwise, replace the subtraction with pandas code and
    return the rewritten SQL.

    TODO fill args
    """
    if simple:
        ret = sql.replace("SELECT ", "").replace(";", "").replace("(", "").replace(")", "")
    else:
        ret = None
    return ret
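A few illustrative calls to these token helpers (inputs invented; outputs follow from the implementations above):

q = "SELECT name FROM singer WHERE age > 20"
print(get_first_token(q))            # SELECT
print(get_cur_token(q, 7))           # name
print(get_next_token(q, 7))          # FROM
print(get_prev_token(q, 12))         # name
print(get_second_last_token(q))      # >
print(find_closing_parenthesis("SUM( (a+b) )", 4))  # 11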