feat(wren-ai-service): Add LLM-based evaluation metrics for SQL generation (#1303)
paopa authored Feb 21, 2025
1 parent ece1e6f commit 33de785
Showing 6 changed files with 239 additions and 9 deletions.
27 changes: 23 additions & 4 deletions wren-ai-service/eval/evaluation.py
@@ -15,11 +15,26 @@
import eval.pipelines as pipelines
import src.providers as provider
from eval import EvalSettings
from eval.utils import engine_config, parse_toml, trace_metadata
from eval.utils import parse_toml, trace_metadata
from src import utils


def formatter(prediction: dict, meta: dict) -> dict:
"""
Formats the prediction result to be used as evaluation input.
This function takes a prediction dictionary and a meta dictionary,
processes them to extract relevant information, and returns a formatted
dictionary that serves as input for evaluation. It includes details such
as input, actual and expected outputs, context, and additional metadata.
Args:
prediction (dict): A dictionary containing prediction details.
meta (dict): A dictionary containing metadata information.
Returns:
dict: A formatted dictionary containing evaluation input data.
"""
retrieval_context = [str(context) for context in prediction["retrieval_context"]]
context = [str(context) for context in prediction["context"]]
enable_spider_metrics = "spider" in meta.get("evaluation_dataset", "").lower()
@@ -33,6 +48,7 @@ def formatter(prediction: dict, meta: dict) -> dict:
"expected_output": prediction["expected_output"],
"retrieval_context": retrieval_context,
"context": context,
"reasoning": prediction.get("reasoning", ""),
"additional_metadata": {
"trace_id": prediction["trace_id"],
"trace_url": prediction["trace_url"],
@@ -82,7 +98,9 @@ def eval(self, meta: dict, predictions: list) -> None:

try:
test_case = LLMTestCase(**formatter(prediction, meta))
result = evaluate([test_case], self._metrics, ignore_errors=True).test_results[0]
result = evaluate(
[test_case], self._metrics, ignore_errors=True
).test_results[0]
self._score_metrics(test_case, result)
[metric.collect(test_case, result) for metric in self._post_metrics]
except Exception:
@@ -152,8 +170,9 @@ def _average_score(self, meta: dict) -> None:
predictions = predicted_file["predictions"]

dataset = parse_toml(meta["evaluation_dataset"])
engine_info = engine_config(dataset["mdl"], pipe_components)
metrics = pipelines.metrics_initiator(meta["pipeline"], engine_info, args.semantics)
metrics = pipelines.metrics_initiator(
meta["pipeline"], dataset, pipe_components, args.semantics
)

evaluator = Evaluator(**metrics)
evaluator.eval(meta, predictions)
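For context (not part of the commit itself), here is a minimal sketch of the record formatter consumes and the evaluation input it now emits. Only the keys visible in the hunk above (expected_output, retrieval_context, context, the new reasoning field, trace_id, trace_url) are confirmed by the diff; the keys in the elided top of the returned dict (input, actual_output) and all sample values are assumptions.

from deepeval.test_case import LLMTestCase

from eval.evaluation import formatter

prediction = {
    "input": "How many orders were placed in 2024?",                   # assumed key
    "actual_output": "SELECT COUNT(*) FROM orders WHERE year = 2024",  # assumed key
    "expected_output": "SELECT count(1) FROM orders WHERE year = 2024",
    "retrieval_context": ["orders(id, year, total)"],
    "context": ["orders(id, year, total)"],
    "reasoning": "Count the rows in orders restricted to 2024.",       # optional field added by this commit
    "trace_id": "trace-123",                                           # hypothetical trace id
    "trace_url": "https://example.langfuse.com/trace/trace-123",       # hypothetical URL
}
meta = {"evaluation_dataset": "spider_car_1_eval_dataset.toml"}        # hypothetical dataset name

eval_input = formatter(prediction, meta)
test_case = LLMTestCase(**eval_input)  # mirrors what Evaluator.eval does above

Because the new field is read with prediction.get("reasoning", ""), prediction files produced before this change still format cleanly.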
8 changes: 8 additions & 0 deletions wren-ai-service/eval/metrics/__init__.py
@@ -4,6 +4,11 @@
from .context_recall import ContextualRecallMetric
from .context_relevancy import ContextualRelevancyMetric
from .faithfulness import FaithfulnessMetric
from .llm import (
QuestionToReasoningJudge,
ReasoningToSqlJudge,
SqlSemanticsJudge,
)
from .spider.exact_match import ExactMatchAccuracy
from .spider.exec_match import ExecutionAccuracy

@@ -17,4 +22,7 @@
"FaithfulnessMetric",
"ExactMatchAccuracy",
"ExecutionAccuracy",
"QuestionToReasoningJudge",
"ReasoningToSqlJudge",
"SqlSemanticsJudge",
]
173 changes: 173 additions & 0 deletions wren-ai-service/eval/metrics/llm/__init__.py
@@ -0,0 +1,173 @@
import asyncio

from deepeval.metrics import BaseMetric
from deepeval.test_case import LLMTestCase
from haystack.components.builders.prompt_builder import PromptBuilder
from pydantic import BaseModel

from src.providers import LLMProvider


class EvalResult(BaseModel):
score: float
reason: str


_MODEL_KWARGS = {
"response_format": {
"type": "json_schema",
"json_schema": {
"name": "eval_result",
"schema": EvalResult.model_json_schema(),
},
}
}


def format(response: dict) -> EvalResult:
reply = response.get("replies", [])[0]
return EvalResult.model_validate_json(reply)


class QuestionToReasoningJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the reasoning provided for a given question and determine if it makes sense.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Question:
{{ question }}
Reasoning:
{{ reasoning }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
question=test_case.input,
reasoning=test_case.reasoning,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "QuestionToReasoningJudge"


class ReasoningToSqlJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the reasoning provided for a given SQL query and determine if it makes sense.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Actual Output:
{{ actual_output }}
Reasoning:
{{ reasoning }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
actual_output=test_case.actual_output,
reasoning=test_case.reasoning,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "ReasoningToSqlJudge"


class SqlSemanticsJudge(BaseMetric):
_system_prompt = """
You are an expert evaluator. Your task is to analyze the actual SQL query and the expected SQL query and determine if they are semantically equivalent.
Provide a score in the range 0.0~1.0 and a detailed explanation for your evaluation.
"""
_test_case_prompt = """
Actual SQL:
{{ actual_sql }}
Expected SQL:
{{ expected_sql }}
"""

def __init__(self, llm_provider: LLMProvider, **_):
self.threshold = 0
self.score = 0
self.llm_provider = llm_provider
self.llm = llm_provider.get_generator(
system_prompt=self._system_prompt,
generation_kwargs=_MODEL_KWARGS,
)
self.prompt_builder = PromptBuilder(template=self._test_case_prompt)

def measure(self, test_case: LLMTestCase):
return asyncio.run(self.a_measure(test_case))

async def a_measure(self, test_case: LLMTestCase, *args, **kwargs):
prompt = self.prompt_builder.run(
actual_sql=test_case.actual_output,
expected_sql=test_case.expected_output,
)
response = await self.llm(prompt.get("prompt"))
result = format(response)

self.score = result.score
self.reason = result.reason

self.success = self.score >= self.threshold
return self.score

def is_successful(self):
return self.success

@property
def __name__(self):
return "SqlSemanticsJudge"
36 changes: 31 additions & 5 deletions wren-ai-service/eval/pipelines.py
@@ -11,6 +11,8 @@
from langfuse.decorators import langfuse_context, observe
from tqdm.asyncio import tqdm_asyncio

from src.core.pipeline import PipelineComponent

sys.path.append(f"{Path().parent.resolve()}")

from eval import EvalSettings
@@ -23,6 +25,9 @@
ExactMatchAccuracy,
ExecutionAccuracy,
FaithfulnessMetric,
QuestionToReasoningJudge,
ReasoningToSqlJudge,
SqlSemanticsJudge,
)
from eval.utils import (
engine_config,
@@ -290,7 +295,11 @@ async def __call__(self, query: str, **_):
]

@staticmethod
def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
def metrics(
engine_info: dict,
enable_semantics_comparison: bool,
component: PipelineComponent,
) -> dict:
return {
"metrics": [
AccuracyMetric(
@@ -302,6 +311,9 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
# this is for spider dataset, rn we temporarily disable it
ExactMatchAccuracy(),
ExecutionAccuracy(),
QuestionToReasoningJudge(**component),
ReasoningToSqlJudge(**component),
SqlSemanticsJudge(**component),
],
"post_metrics": [],
}
@@ -402,7 +414,11 @@ async def __call__(self, query: str, **_):
]

@staticmethod
def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
def metrics(
engine_info: dict,
enable_semantics_comparison: bool,
component: PipelineComponent,
) -> dict:
return {
"metrics": [
AccuracyMetric(
@@ -417,6 +433,9 @@ def metrics(engine_info: dict, enable_semantics_comparison: bool) -> dict:
# this is for spider dataset, rn we temporarily disable it
ExactMatchAccuracy(),
ExecutionAccuracy(),
QuestionToReasoningJudge(**component),
ReasoningToSqlJudge(**component),
SqlSemanticsJudge(**component),
],
"post_metrics": [],
}
@@ -449,13 +468,20 @@ def init(

def metrics_initiator(
pipeline: str,
engine_info: dict,
dataset: dict,
pipe_components: dict[str, PipelineComponent],
enable_semantics_comparison: bool = True,
) -> dict:
engine_info = engine_config(dataset["mdl"], pipe_components)
component = pipe_components["evaluation"]
match pipeline:
case "retrieval":
return RetrievalPipeline.metrics(engine_info)
case "generation":
return GenerationPipeline.metrics(engine_info, enable_semantics_comparison)
return GenerationPipeline.metrics(
engine_info, enable_semantics_comparison, component
)
case "ask":
return AskPipeline.metrics(engine_info, enable_semantics_comparison)
return AskPipeline.metrics(
engine_info, enable_semantics_comparison, component
)
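On the caller side (see evaluation.py above), metrics_initiator now receives the parsed dataset and the full component map and derives both the engine config and the "evaluation" component itself. A hedged sketch of that call, assuming dataset and pipe_components are built the same way the existing harness builds them:

metrics = pipelines.metrics_initiator(
    pipeline=meta["pipeline"],                        # "retrieval" | "generation" | "ask"
    dataset=parse_toml(meta["evaluation_dataset"]),   # must expose an "mdl" entry for engine_config
    pipe_components=pipe_components,                  # must include an "evaluation" component
    enable_semantics_comparison=args.semantics,
)
evaluator = Evaluator(**metrics)

The "retrieval" branch still calls RetrievalPipeline.metrics(engine_info) without the component, so the new judges are attached only to the generation and ask pipelines.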
2 changes: 2 additions & 0 deletions wren-ai-service/tools/config/config.example.yaml
@@ -145,6 +145,8 @@ pipes:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: evaluation
llm: litellm_llm.gpt-4o-mini-2024-07-18

---
settings:
2 changes: 2 additions & 0 deletions wren-ai-service/tools/config/config.full.yaml
@@ -145,6 +145,8 @@ pipes:
- name: sql_regeneration
llm: litellm_llm.gpt-4o-mini-2024-07-18
engine: wren_ui
- name: evaluation
llm: litellm_llm.gpt-4o-mini-2024-07-18

---
settings:
