Skip to content

Commit 83dd5b4

Browse files
Bug fixes related to UI flow
2 parents 88973fc + 2513b40 commit 83dd5b4

File tree

10 files changed

+185
-132
lines changed

10 files changed

+185
-132
lines changed

Makefile

+4
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,9 @@ download_demo_data:
1919
run:
2020
./.sidekickvenv/bin/python3 start.py
2121

22+
clean:
23+
rm -rf ./db
24+
rm -rf ./var
25+
2226
cloud_bundle:
2327
h2o bundle -L debug 2>&1 | tee -a h2o-bundle.log

about.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
**Target Audience:** Data (Machine Learning) Scientists, Citizen Data Scientists, Data Engineers Managers and Business Analysts
44

5-
**Actively Being Maintained:** Yes (Demo release: _In active RnD_)
5+
**Actively Being Maintained:** Yes (Demo release)
66

77
**Last Updated:** January, 2024
88

app.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ name = "ai.h2o.wave.sql-sidekick"
33
title = "SQL-Sidekick"
44
description = "QnA with tabular data using NLQ"
55
LongDescription = "about.md"
6-
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
7-
Version = "0.2.0"
6+
InstanceLifecycle = "MANAGED"
7+
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP", "GENERATIVE_AI"]
8+
Version = "0.2.1"
89

910
[Runtime]
1011
MemoryLimit = "64Gi"

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sql-sidekick"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
license = "Apache-2.0 license"
55
description = "An AI assistant for SQL generation"
66
authors = [

requirements.txt

+33-19
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.
44
ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
55
anyio==4.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
66
async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7-
attrs==23.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7+
attrs==23.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
88
beautifulsoup4==4.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
99
bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1010
blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
@@ -13,73 +13,86 @@ certifi==2023.11.17 ; python_full_version >= "3.8.1" and python_full_version <=
1313
charset-normalizer==3.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1414
click==8.1.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1515
colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
16+
databricks-sql-connector==3.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1617
dataclasses-json==0.6.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1718
deprecated==1.2.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
18-
distro==1.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
19+
distro==1.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
20+
editor==1.6.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
21+
et-xmlfile==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
1922
exceptiongroup==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2023
filelock==3.13.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2124
frozenlist==1.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2225
fsspec==2023.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2326
greenlet==3.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2427
h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2528
h2o-wave==0.26.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
26-
h2ogpte==1.2.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
29+
h2ogpte==1.2.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2730
httpcore==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
2831
httpx==0.24.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
29-
huggingface-hub==0.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
32+
huggingface-hub==0.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3033
idna==3.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
31-
inquirer==3.1.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
34+
inquirer==3.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3235
instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
33-
jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
36+
jinja2==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3437
jinxed==1.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
3538
joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
36-
llama-index==0.9.20 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
39+
jsonpatch==1.33 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
40+
jsonpointer==2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
41+
langchain-community==0.0.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
42+
langchain-core==0.1.11 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
43+
langsmith==0.0.81 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
44+
llama-index==0.9.32 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3745
loguru==0.7.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3846
lxml==4.9.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
47+
lz4==4.3.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
3948
markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
40-
marshmallow==3.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
49+
marshmallow==3.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4150
mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4251
multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4352
mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
44-
nest-asyncio==1.5.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
53+
nest-asyncio==1.5.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4554
networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4655
nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4756
numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
48-
openai==1.6.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
57+
oauthlib==3.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
58+
openai==1.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
59+
openpyxl==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
4960
packaging==23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5061
pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5162
pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
52-
pillow==10.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
63+
pillow==10.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5364
psutil==5.9.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5465
psycopg2-binary==2.9.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
66+
pyarrow==14.0.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5567
pydantic==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5668
pydantic[dotenv]==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5769
python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
5870
python-dotenv==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
59-
python-editor==1.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6071
pytz==2023.3.post1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6172
pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6273
readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
63-
regex==2023.10.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
74+
regex==2023.12.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6475
requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
76+
runs==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6577
safetensors==0.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6678
scikit-learn==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6779
scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6880
sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
6981
sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
70-
setuptools==69.0.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
82+
setuptools==69.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7183
six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7284
sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7385
soupsieve==2.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7486
sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
75-
sqlalchemy==1.4.50 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
76-
sqlalchemy[asyncio]==1.4.50 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
87+
sqlalchemy==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
88+
sqlalchemy[asyncio]==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7789
sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
7890
sqlparse==0.4.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
79-
starlette==0.34.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
91+
starlette==0.35.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8092
sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8193
tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8294
threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
95+
thrift==0.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8396
tiktoken==0.5.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8497
tokenizers==0.15.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
8598
toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
@@ -90,9 +103,10 @@ transformers==4.36.2 ; python_full_version >= "3.8.1" and python_full_version <=
90103
typing-extensions==4.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
91104
typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
92105
urllib3==2.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
93-
uvicorn==0.25.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
94-
wcwidth==0.2.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
106+
uvicorn==0.26.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
107+
wcwidth==0.2.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
95108
websockets==11.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
96109
win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and sys_platform == "win32"
97110
wrapt==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
111+
xmod==1.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
98112
yarl==1.9.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"

sidekick/configs/prompt_template.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
### *History*:\n{_sample_queries}
3636
### *Question*: For table {_table_name}, {_question}
3737
# SELECT 1
38-
### *Tasks for table {_table_name}*:\n{_tasks}
38+
### *Plan for table {_table_name}*:\n{_tasks}
3939
### *Policies for SQL generation*:
4040
# Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
4141
# Avoid patterns that might be vulnerable to SQL injection
@@ -118,7 +118,7 @@
118118
- Only use supplied table names: **{table_name}** for generation
119119
- Only use column names from the CREATE TABLE statement: **{column_info}** for generation. DO NOT USE any other column names outside of this.
120120
- Avoid overly complex SQL queries, favor concise human readable SQL queries which are easy to understand and debug
121-
- Avoid patterns that might be vulnerable to SQL injection, e.g. sanitize inputs
121+
- Avoid patterns that might be vulnerable to SQL injection, e.g. use proper sanitization and escaping for raw user input
122122
- Always cast the numerator as float when computing ratios
123123
- Always use COUNT(1) instead of COUNT(*)
124124
- If the question is asking for a rate, use COUNT to compute percentage

sidekick/prompter.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
execute_query_pd, extract_table_names,
2424
generate_suggestions, save_query, setup_dir)
2525

26-
__version__ = "0.2.0"
26+
__version__ = "0.2.1"
2727

2828
# Load the config file and initialize required paths
2929
app_base_path = (Path(__file__).parent / "../").resolve()
@@ -41,6 +41,8 @@
4141
os.environ["TOKENIZERS_PARALLELISM"] = "False"
4242
os.environ["H2O_BASE_MODEL_URL"] = h2ogpt_base_model_url
4343
os.environ["H2O_BASE_MODEL_API_KEY"] = h2ogpt_base_model_key
44+
os.environ["RECOMMENDATION_MODEL_REMOTE_URL"] = h2o_remote_url
45+
os.environ["RECOMMENDATION_MODEL_API_KEY"] = h2o_key
4446

4547
def color(fore="", back="", text=None):
4648
return f"{fore}{back}{text}{Style.RESET_ALL}"
@@ -103,7 +105,7 @@ def _get_table_info(cache_path: str, table_name: str = None):
103105
if table_info_path is None:
104106
# if table_info_path is None, generate default schema n set path
105107
data_path = current_meta["samples_path"]
106-
_, table_info_path = generate_schema(data_path, f"{cache_path}/{table_name}_table_info.jsonl")
108+
_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl")
107109
table_metadata = {"schema_info_path": table_info_path}
108110
with open(f"{cache_path}/table_context.json", "w") as outfile:
109111
json.dump(table_metadata, outfile, indent=4, sort_keys=False)
@@ -178,7 +180,7 @@ def recommend_suggestions(cache_path: str, table_name: str, n_qs: int=10):
178180
@click.option("--data_path", default="data.csv", help="Enter the path of csv", type=str)
179181
@click.option("--output_path", default="table_info.jsonl", help="Enter the path of generated schema in jsonl", type=str)
180182
def generate_input_schema(data_path, output_path):
181-
_, o_path = generate_schema(data_path, output_path)
183+
_, o_path = generate_schema(data_path=data_path, output_path=output_path)
182184
click.echo(f"Schema generated for the input data at {o_path}")
183185

184186

@@ -463,7 +465,7 @@ def ask(
463465
"""
464466

465467
results = []
466-
err = None # TODO - Need to handle errors if occurred
468+
res = err = alt_res = None # TODO - Need to handle errors if occurred
467469
# Book-keeping
468470
base_path = local_base_path if local_base_path else default_base_path
469471
setup_dir(base_path)
@@ -575,7 +577,7 @@ def ask(
575577
click.echo("Skipping edit...")
576578
if updated_tasks is not None:
577579
sql_g._tasks = updated_tasks
578-
alt_res = None
580+
579581
# The interface could also be used to simply execute user provided SQL
580582
# Keyword: "Execute SQL: <SQL query>"
581583
if (
@@ -650,12 +652,12 @@ def ask(
650652
attempt = 0
651653
error_condition = lambda e: ('OperationalError'.lower() in e.lower() or 'OperationError'.lower() in e.lower() or 'Syntax error'.lower() in e.lower()) if e else False
652654
if self_correction and error_condition(err):
653-
logger.info("Attempting to auto-correct the query...")
655+
logger.info("Attempting to auto-correct the query during runtime...")
654656
while attempt !=3 and error_condition(err):
655657
try:
656658
logger.debug(f"Attempt: {attempt+1}")
657659
_tmp = err.split("\n")
658-
_err = _tmp[0].split("Error occurred :")[1] if len(_tmp) > 0 else None
660+
_err = _tmp[0].split("Error occurred:")[1] if len(_tmp) > 0 else None
659661
env_url = os.environ["RECOMMENDATION_MODEL_REMOTE_URL"]
660662
env_key = os.environ["RECOMMENDATION_MODEL_API_KEY"]
661663
corr_sql = sql_g.self_correction(input_prompt=_val, error_msg=_err, remote_url=env_url, client_key=env_key)
@@ -667,7 +669,7 @@ def ask(
667669
logger.error(f"Something went wrong:\n{e}")
668670
attempt += 1
669671
if m:
670-
_t = "\nWarning:\n".join([str(q_res), m])
672+
_t = "\n\n**Warning:**\n".join([str(q_res), m])
671673
q_res = _t
672674
elif option == "pandas":
673675
tables = extract_table_names(_val)
@@ -697,7 +699,7 @@ def ask(
697699
click.echo("Error in executing the query. Validate generated SQL and try again.")
698700
click.echo("No result to display.")
699701

700-
results.append("**Result:** \n")
702+
results.append("**Result:**\n")
701703
if q_res:
702704
# Check shape of the final result to avoid blowing up memory
703705
# Logging a quick preview of the result
@@ -718,7 +720,7 @@ def ask(
718720
else:
719721
click.echo("Exiting...")
720722
else:
721-
results = ["I was not able to generate a response for the question. Please try re-phrasing."]
723+
results = ["I was not able to generate a response for the question. Please try re-phrasing or try again."]
722724
alt_res, err = None, None
723725
except (MemoryError, RuntimeError, AttributeError) as e:
724726
logger.error(f"Something went wrong while generating response: {e}")

0 commit comments

Comments
 (0)