Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and improvements to task based llm as judge #1366

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
34 changes: 18 additions & 16 deletions examples/evaluate_external_rag_results_with_binary_llm_as_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"Supported foundation models available with watsonx.ai. Watsonx.ai offers numerous foundation models."
],
"ground_truths": ["Many Large Language Models are supported by Watsonx.ai"],
"a123123": "",
"metadata": {"data_classification_policy": ["public"]},
},
{
"question": "What foundation models are available in watsonx.ai ?",
Expand All @@ -20,6 +20,7 @@
"Supported foundation models available with Meta. Meta AI offers numerous foundation models."
],
"ground_truths": ["Many Large Language Models are supported by Watsonx.ai"],
"metadata": {"data_classification_policy": ["public"]},
},
{
"question": "What foundation models are available in watsonx.ai ?",
Expand All @@ -28,6 +29,7 @@
"Supported foundation models available with Meta. Meta AI offers numerous foundation models."
],
"ground_truths": ["Many Large Language Models are supported by Watsonx.ai"],
"metadata": {"data_classification_policy": ["public"]},
},
{
"question": "What foundation models are available in watsonx.ai ?",
Expand All @@ -36,6 +38,7 @@
"Supported foundation models available with Meta. Meta AI offers numerous foundation models."
],
"ground_truths": ["Many Large Language Models are supported by Watsonx.ai"],
"metadata": {"data_classification_policy": ["public"]},
},
{
"question": "What foundation models are available in watsonx.ai ?",
Expand All @@ -44,28 +47,28 @@
"Supported foundation models available with Meta. Meta AI offers numerous foundation models."
],
"ground_truths": ["Many Large Language Models are supported by Watsonx.ai"],
"metadata": {"data_classification_policy": ["public"]},
},
]

# select the desired metrics.
# all available metrics are under "catalog.metrics.llm_as_judge.binary"
# Select the desired metric(s).
# Each metric measures a certain aspect of the generated answer (answer_correctness, faithfulness,
# answer_relevance, context_relevance and correctness_holistic).
# All available metrics are under "catalog.metrics.rag"
# Those with extension "logprobs" provide a real value prediction in [0,1], the others provide a binary prediction.
# By default, all judges use llama_3_1_70b_instruct_wml. We will soon see how to change this.
metric_names = [
"answer_correctness_q_a_gt_loose_logprobs",
"answer_correctness_q_a_gt_strict_logprobs",
"faithfulness_q_c_a_logprobs",
"faithfulness_c_a_logprobs",
"context_relevance_q_c_ares_logprobs",
"answer_relevance_q_a_logprobs",
"metrics.rag.answer_correctness.llama_3_1_70b_instruct_wml_q_a_gt_loose_logprobs",
"metrics.rag.faithfulness.llama_3_1_70b_instruct_wml_q_c_a_logprobs",
]
metrics_path = "metrics.llm_as_judge.binary.llama_3_1_70b_instruct_wml_"

# select the desired model.
# all available models are under "catalog.engines.classification"
model_names = [
"mixtral_8x7b_instruct_v01_wml",
# "gpt_4_turbo_openai",
"engines.classification.mixtral_8x7b_instruct_v01_wml",
"engines.classification.llama_3_1_70b_instruct_wml",
# "engines.classification.gpt_4_turbo_openai",
]
models_path = "engines.classification"

if __name__ == "__main__":
multi_stream = MultiStream.from_iterables({"test": test_examples}, copying=True)
Expand All @@ -78,15 +81,14 @@
for model_name in model_names:
# override the metric with the inference model. the default model is llama_3_1_70b_instruct_wml so
# no need to override when using it.
llmaj_metric_name = f"{metrics_path}{metric_name}[inference_model={models_path}.{model_name}]"
llmaj_metric_name = f"{metric_name}[inference_model={model_name}]"

# apply the metric over the input
metrics_operator = SequentialOperator(steps=[llmaj_metric_name])
instances = metrics_operator(multi_stream)["test"]
instances = list(instances)

# all scores will have this prefix
score_name = f"{model_name}_{metric_name}"
score_name = instances[0]["score"]["instance"]["score_name"]
for i in range(len(instances)):
results[i][score_name] = instances[i]["score"]["instance"][score_name]
results[i][f"{score_name}_source"] = instances[i]["score"]["instance"][
Expand Down
2 changes: 1 addition & 1 deletion prepare/metrics/llm_as_judge/binary_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_prediction_field(metric_type):
inference_model=inference_model,
template=f"templates.rag_eval.{metric_type}.{template_name}{logprobs_label}",
task=task_name,
format="formats.empty",
format=None,
main_score=metric_label,
prediction_field=get_prediction_field(metric_type),
infer_log_probs=use_logprobs,
Expand Down
7 changes: 6 additions & 1 deletion prepare/tasks/rag_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def convert_to_dict_of_type(field_list):
outputs=convert_to_dict_of_type(["is_correct", "number_val"]),
metrics=rag_classification_metrics[binary_val],
prediction_type="float",
defaults={"choices": ["yes", "no"], "is_correct": ["-"], "number_val": -1},
defaults={
"choices": ["yes", "no"],
"is_correct": ["-"],
"number_val": -1,
"contexts": ["-"],
},
),
f"tasks.rag_eval.answer_correctness.{binary_val}",
overwrite=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_correctness.judge_simplified_format",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.context_relevance.judge_context_relevance_ares",
"task": "tasks.rag_eval.context_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "context_relevance_q_c_ares",
"prediction_field": null,
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.correctness_holistic.judge_correctness_simple",
"task": "tasks.rag_eval.correctness_holistic.binary",
"format": "formats.empty",
"format": null,
"main_score": "correctness_holistic_q_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.faithfulness.judge_no_question_simplified",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.faithfulness.judge_with_question_simplified",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_q_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context_logprobs",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_simplified_format",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_simplified_format_logprobs",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance_logprobs",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.context_relevance.judge_context_relevance_ares",
"task": "tasks.rag_eval.context_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "context_relevance_q_c_ares",
"prediction_field": null,
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.context_relevance.judge_context_relevance_ares_logprobs",
"task": "tasks.rag_eval.context_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "context_relevance_q_c_ares_logprobs",
"prediction_field": null,
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.correctness_holistic.judge_correctness_simple",
"task": "tasks.rag_eval.correctness_holistic.binary",
"format": "formats.empty",
"format": null,
"main_score": "correctness_holistic_q_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.correctness_holistic.judge_correctness_simple_logprobs",
"task": "tasks.rag_eval.correctness_holistic.binary",
"format": "formats.empty",
"format": null,
"main_score": "correctness_holistic_q_c_a_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.faithfulness.judge_no_question_simplified",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.faithfulness.judge_no_question_simplified_logprobs",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_c_a_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.faithfulness.judge_with_question_simplified",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_q_c_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.faithfulness.judge_with_question_simplified_logprobs",
"task": "tasks.rag_eval.faithfulness.binary",
"format": "formats.empty",
"format": null,
"main_score": "faithfulness_q_c_a_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_correctness.judge_simplified_format",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_loose_match_no_context_logprobs",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_loose_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_simplified_format",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_correctness.judge_simplified_format_logprobs",
"task": "tasks.rag_eval.answer_correctness.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_correctness_q_a_gt_strict_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a",
"prediction_field": "answer",
"infer_log_probs": false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"inference_model": "engines.classification.llama_3_1_70b_instruct_wml",
"template": "templates.rag_eval.answer_relevance.judge_answer_relevance_logprobs",
"task": "tasks.rag_eval.answer_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "answer_relevance_q_a_logprobs",
"prediction_field": "answer",
"infer_log_probs": true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
},
"template": "templates.rag_eval.context_relevance.judge_context_relevance_ares",
"task": "tasks.rag_eval.context_relevance.binary",
"format": "formats.empty",
"format": null,
"main_score": "context_relevance_q_c_ares",
"prediction_field": null,
"infer_log_probs": false
Expand Down
Loading
Loading