Skip to content

Commit

Permalink
feat(wren-ai-service): Add a few properties in Prediction Result for Evaluation (#1297)
Browse files Browse the repository at this point in the history
  • Loading branch information
paopa authored Feb 14, 2025
1 parent 33d8c19 commit 79d04f9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 10 deletions.
4 changes: 4 additions & 0 deletions wren-ai-service/eval/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from deepeval.evaluate import TestResult
from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase
from deprecated import deprecated

from eval.utils import get_data_from_wren_engine, get_openai_client

Expand Down Expand Up @@ -170,6 +171,9 @@ def __name__(self):
return "Accuracy(column-based)"


@deprecated(
reason="We don't generate multiple candidates for Text to SQL task, so don't need this metric"
)
class AccuracyMultiCandidateMetric(BaseMetric):
def __init__(self):
self.threshold = 0
Expand Down
45 changes: 35 additions & 10 deletions wren-ai-service/eval/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import sys
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Literal

Expand All @@ -15,7 +16,6 @@
from eval import EvalSettings
from eval.metrics import (
AccuracyMetric,
AccuracyMultiCandidateMetric,
AnswerRelevancyMetric,
ContextualPrecisionMetric,
ContextualRecallMetric,
Expand All @@ -40,7 +40,7 @@ async def wrapper():
asyncio.run(wrapper())


def extract_units(docs: list) -> list:
def extract_units(docs: list[dict]) -> list:
def parse_ddl(ddl: str) -> list:
"""
Parses a DDL statement and returns a list of column definitions in the format table_name.column_name, excluding foreign keys.
Expand Down Expand Up @@ -82,7 +82,7 @@ def parse_ddl(ddl: str) -> list:

columns = []
for doc in docs:
columns.extend(parse_ddl(doc))
columns.extend(parse_ddl(doc.get("table_ddl", "")))
return columns


Expand Down Expand Up @@ -115,12 +115,11 @@ async def wrapper(batch: list):
return [prediction for batch in batches for prediction in batch]

@abstractmethod
def _process(self, prediction: dict, **_) -> dict:
...
def _process(self, prediction: dict, **_) -> dict: ...

async def _flat(self, prediction: dict, **_) -> dict:
"""
No operation function to be overridden by subclasses,if needed.
No operation function to be overridden by subclasses if needed.
"""
return prediction

Expand All @@ -136,6 +135,8 @@ async def process(self, query: dict) -> dict:
"context": query["context"],
"samples": query.get("samples", []),
"type": "execution",
"reasoning": "",
"elapsed_time": 0,
}

langfuse_context.update_current_trace(
Expand All @@ -144,10 +145,20 @@ async def process(self, query: dict) -> dict:
metadata=trace_metadata(self._meta, type=prediction["type"]),
)

return await self._process(prediction, **query)
start_time = datetime.now()
returned = await self._process(prediction, **query)
returned["elapsed_time"] = (datetime.now() - start_time).total_seconds()

return returned

@observe(capture_input=False)
async def flat(self, prediction: dict, **kwargs) -> dict:
"""
This method changes the trace type to 'shallow' to handle cases where a trace has multiple actual outputs.
The flattening mechanism was historically used to get individual scores for evaluation when a single trace
produced multiple outputs. While currently maintained for backwards compatibility, this functionality may
be removed in the future if no longer needed.
"""
prediction["source_trace_id"] = prediction["trace_id"]
prediction["source_trace_url"] = prediction["trace_url"]
prediction["trace_id"] = langfuse_context.get_current_trace_id()
Expand Down Expand Up @@ -289,7 +300,7 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
ExactMatchAccuracy(),
ExecutionAccuracy(),
],
"post_metrics": [AccuracyMultiCandidateMetric()],
"post_metrics": [],
}


Expand Down Expand Up @@ -319,6 +330,9 @@ def __init__(
table_column_retrieval_size=settings.table_column_retrieval_size,
allow_using_db_schemas_without_pruning=settings.allow_using_db_schemas_without_pruning,
)
self._sql_reasoner = generation.SQLGenerationReasoning(
**pipe_components["sql_generation_reasoning"],
)
self._generation = generation.SQLGeneration(
**pipe_components["sql_generation"],
)
Expand All @@ -340,6 +354,16 @@ async def _process(self, prediction: dict, **_) -> dict:
documents = _retrieval_result.get("retrieval_results", [])
has_calculated_field = _retrieval_result.get("has_calculated_field", False)
has_metric = _retrieval_result.get("has_metric", False)

_reasoning = await self._sql_reasoner.run(
query=prediction["input"],
contexts=documents,
sql_samples=prediction.get("samples", [])
if self._allow_sql_samples
else [],
)
reasoning = _reasoning.get("post_process", {})

actual_output = await self._generation.run(
query=prediction["input"],
contexts=documents,
Expand All @@ -348,13 +372,14 @@ async def _process(self, prediction: dict, **_) -> dict:
else [],
has_calculated_field=has_calculated_field,
has_metric=has_metric,
sql_generation_reasoning=prediction.get("reasoning", ""),
sql_generation_reasoning=reasoning,
)

prediction["actual_output"] = actual_output
prediction["retrieval_context"] = extract_units(documents)
prediction["has_calculated_field"] = has_calculated_field
prediction["has_metric"] = has_metric
prediction["reasoning"] = reasoning

return prediction

Expand Down Expand Up @@ -388,7 +413,7 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
ExactMatchAccuracy(),
ExecutionAccuracy(),
],
"post_metrics": [AccuracyMultiCandidateMetric()],
"post_metrics": [],
}


Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/eval/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def parse_args() -> Tuple[str, str]:
dataset = parse_toml(path)

settings = EvalSettings()
# todo: refactor this
_mdl = base64.b64encode(orjson.dumps(dataset["mdl"])).decode("utf-8")
if "spider_" in path:
settings.datasource = "duckdb"
Expand Down

0 comments on commit 79d04f9

Please sign in to comment.