From 97533ebc4fda7a53eba0d8c788d53b6b34d040b7 Mon Sep 17 00:00:00 2001 From: ChihYu Yeh Date: Mon, 15 Jul 2024 09:06:27 +0800 Subject: [PATCH] update --- wren-ai-service/eval/data_curation/app.py | 4 +- wren-ai-service/eval/data_curation/utils.py | 54 +++------------------ 2 files changed, 8 insertions(+), 50 deletions(-) diff --git a/wren-ai-service/eval/data_curation/app.py b/wren-ai-service/eval/data_curation/app.py index 9d723ec7b5..b9412ec27f 100644 --- a/wren-ai-service/eval/data_curation/app.py +++ b/wren-ai-service/eval/data_curation/app.py @@ -11,7 +11,7 @@ from streamlit_tags import st_tags from utils import ( DATA_SOURCES, - get_contexts_from_sqls_v2, + get_contexts_from_sqls, get_eval_dataset_in_toml_string, get_llm_client, get_question_sql_pairs, @@ -94,7 +94,7 @@ def on_change_sql(i: int, key: str): valid, error = asyncio.run(is_sql_valid(sql)) if valid: - new_context = asyncio.run(get_contexts_from_sqls_v2([sql])) + new_context = asyncio.run(get_contexts_from_sqls([sql])) if i != -1: st.session_state["llm_question_sql_pairs"][i]["sql"] = sql st.session_state["llm_question_sql_pairs"][i]["is_valid"] = valid diff --git a/wren-ai-service/eval/data_curation/utils.py b/wren-ai-service/eval/data_curation/utils.py index c30a45b2be..4ef2bfe2b5 100644 --- a/wren-ai-service/eval/data_curation/utils.py +++ b/wren-ai-service/eval/data_curation/utils.py @@ -229,7 +229,7 @@ async def get_sql_analysis( return await response.json() -async def get_contexts_from_sqls_v2( +async def get_contexts_from_sqls( sqls: list[str], ) -> list[str]: def _compose_contexts_of_select_type(select_items: list[dict]): @@ -281,6 +281,9 @@ def _get_contexts_from_sql_analysis_results(sql_analysis_results: list[dict]): if "sortings" in result: contexts += _compose_contexts_of_sorting_type(result["sortings"]) + print( + f'SQL ANALYSIS RESULTS: {orjson.dumps(sql_analysis_results, option=orjson.OPT_INDENT_2).decode("utf-8")}' + ) print(f"CONTEXTS: {sorted(set(contexts))}") return sorted(set(contexts)) @@ -335,7 +338,7 @@ async def get_question_sql_pairs( try: response = await llm_client.chat.completions.create( - model=os.getenv("OPENAI_GENERATION_MODEL", "gpt-3.5-turbo"), + model=os.getenv("GENERATION_MODEL", "gpt-3.5-turbo"), messages=messages, response_format={"type": "json_object"}, max_tokens=4096, @@ -345,7 +348,7 @@ async def get_question_sql_pairs( results = orjson.loads(response.choices[0].message.content)["results"] question_sql_pairs = await get_validated_question_sql_pairs(results) sqls = [question_sql_pair["sql"] for question_sql_pair in question_sql_pairs] - contexts = await get_contexts_from_sqls_v2(sqls) + contexts = await get_contexts_from_sqls(sqls) return [ {**quesiton_sql_pair, "context": context} for quesiton_sql_pair, context in zip(question_sql_pairs, contexts) @@ -355,51 +358,6 @@ async def get_question_sql_pairs( return [] -def show_er_diagram(models: List[dict], relationships: List[dict]): - # Start of the Graphviz syntax - graphviz = "digraph ERD {\n" - graphviz += ' graph [pad="0.5", nodesep="0.5", ranksep="2"];\n' - graphviz += " node [shape=plain]\n" - graphviz += " rankdir=LR;\n\n" - - # Function to format the label for Graphviz - def format_label(name, columns): - label = f'<' - for column in columns: - label += f'' - label += "
{name}
{column["name"]} : {column["type"]}
>" - return label - - # Add models (entities) to the Graphviz syntax - for model in models: - graphviz += f' {model["name"]} [label={format_label(model["name"], model["columns"])}];\n' - - graphviz += "\n" - - # Extract columns involved in each relationship - def extract_columns(condition): - # This regular expression should match the condition format and extract column names - matches = re.findall(r"(\w+)\.(\w+) = (\w+)\.(\w+)", condition) - if matches: - return matches[0][1], matches[0][3] # Returns (from_column, to_column) - return "", "" - - # Add relationships to the Graphviz syntax - for relationship in relationships: - from_model, to_model = relationship["models"] - from_column, to_column = extract_columns(relationship["condition"]) - label = ( - f'{relationship["name"]}\\n({from_column} to {to_column}) ({relationship['joinType']})' - if from_column and to_column - else relationship["name"] - ) - graphviz += f' {from_model} -> {to_model} [label="{label}"];\n' - - graphviz += "}" - - st.graphviz_chart(graphviz) - - def prettify_sql(sql: str) -> str: return sqlparse.format( sql,