Skip to content

Commit

Permalink
Merge branch 'main' into feat/llm-as-judge
Browse files (browse the repository at this point in the history)
  • Loading branch information
cyyeh authored Feb 21, 2025
2 parents f5b842e + ece1e6f commit b4712cd
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 121 deletions.
4 changes: 2 additions & 2 deletions wren-ai-service/Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ start: use-wren-ui-as-engine
curate_eval_data:
poetry run streamlit run eval/data_curation/app.py

prep:
poetry run python -m eval.preparation
prep dataset='spider1.0':
poetry run python -m eval.preparation --dataset {{dataset}}

predict dataset pipeline='ask':
@poetry run python -u eval/prediction.py --file {{dataset}} --pipeline {{pipeline}}
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/demo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def add_quotes(sql: str) -> Tuple[str, bool]:
try:
quoted_sql = sqlglot.transpile(sql, read="trino", identify=True)[0]
return quoted_sql, True
except Exception:
except Exception as e:
print(f"Error in adding quotes to SQL: {sql}")
print(f"Error: {e}")
return sql, False


Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class EvalSettings(Settings):
config_path: str = "eval/config.yaml"
openai_api_key: SecretStr = Field(alias="LLM_OPENAI_API_KEY")
allow_sql_samples: bool = True
db_path_for_duckdb: str = ""

# BigQuery
bigquery_project_id: str = Field(default="")
Expand Down
27 changes: 15 additions & 12 deletions wren-ai-service/eval/data_curation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,6 @@
from streamlit_tags import st_tags

sys.path.append(f"{Path().parent.resolve()}")
from eval import EvalSettings
from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)
from utils import (
DATA_SOURCES,
WREN_ENGINE_ENDPOINT,
Expand All @@ -32,6 +24,15 @@
prettify_sql,
)

from eval import EvalSettings
from eval.utils import (
get_documents_given_contexts,
get_eval_dataset_in_toml_string,
get_openai_client,
prepare_duckdb_init_sql,
prepare_duckdb_session_sql,
)

st.set_page_config(layout="wide")
st.title("WrenAI Data Curation App")

Expand Down Expand Up @@ -66,9 +67,9 @@
def on_change_upload_eval_dataset():
doc = tomlkit.parse(st.session_state.uploaded_eval_file.getvalue().decode("utf-8"))

assert doc["mdl"] == st.session_state["mdl_json"], (
"The model in the uploaded dataset is different from the deployed model"
)
assert (
doc["mdl"] == st.session_state["mdl_json"]
), "The model in the uploaded dataset is different from the deployed model"
st.session_state["candidate_dataset"] = doc["eval_dataset"]


Expand Down Expand Up @@ -116,7 +117,9 @@ def on_click_setup_uploaded_file():
elif data_source == "duckdb":
prepare_duckdb_session_sql(WREN_ENGINE_ENDPOINT)
prepare_duckdb_init_sql(
WREN_ENGINE_ENDPOINT, st.session_state["mdl_json"]["catalog"]
WREN_ENGINE_ENDPOINT,
st.session_state["mdl_json"]["catalog"],
"etc/spider1.0/database",
)
else:
st.session_state["data_source"] = None
Expand Down
11 changes: 8 additions & 3 deletions wren-ai-service/eval/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ async def wrapper(batch: list):
return [prediction for batch in batches for prediction in batch]

@abstractmethod
def _process(self, prediction: dict, **_) -> dict: ...
def _process(self, prediction: dict, **_) -> dict:
...

async def _flat(self, prediction: dict, **_) -> dict:
"""
Expand Down Expand Up @@ -252,7 +253,9 @@ def __init__(
)

self._allow_sql_samples = settings.allow_sql_samples
self._engine_info = engine_config(mdl, pipe_components)
self._engine_info = engine_config(
mdl, pipe_components, settings.db_path_for_duckdb
)

async def _flat(self, prediction: dict, actual: str) -> dict:
prediction["actual_output"] = actual
Expand Down Expand Up @@ -350,7 +353,9 @@ def __init__(
)
self._allow_sql_samples = settings.allow_sql_samples

self._engine_info = engine_config(mdl, pipe_components)
self._engine_info = engine_config(
mdl, pipe_components, settings.db_path_for_duckdb
)

async def _flat(self, prediction: dict, actual: str) -> dict:
prediction["actual_output"] = actual
Expand Down
7 changes: 7 additions & 0 deletions wren-ai-service/eval/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ def parse_args() -> Tuple[str, str]:
_mdl = base64.b64encode(orjson.dumps(dataset["mdl"])).decode("utf-8")
if "spider_" in path:
settings.datasource = "duckdb"
settings.db_path_for_duckdb = "etc/spider1.0/database"
replace_wren_engine_env_variables(
"wren_engine", {"manifest": _mdl}, settings.config_path
)
elif "bird_" in path:
settings.datasource = "duckdb"
settings.db_path_for_duckdb = "etc/bird/minidev/MINIDEV/dev_databases"
replace_wren_engine_env_variables(
"wren_engine", {"manifest": _mdl}, settings.config_path
)
Expand Down
Loading

0 comments on commit b4712cd

Please sign in to comment.