Onboarding SimpleQA #184
Conversation
```python
from .transform import DFTransformBase


@dataclass
class SimpleQA_MetadataExplode(DFTransformBase):
```
nit: Would it be good to document the expected format of the metadata column? Just so it's easier for someone who looks at this file to know what it should look like, without having to look at other parts of the code.
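For example, a minimal docstring sketch (the example keys shown are an assumption, not the actual format):

```python
from dataclasses import dataclass

from .transform import DFTransformBase


@dataclass
class SimpleQA_MetadataExplode(DFTransformBase):
    """Explode the metadata column into one column per key.

    The metadata column is expected to hold one dict per row (or the dict's
    string repr, which transform() parses via ast.literal_eval), e.g.
    {"topic": "Science", "answer_type": "Date"}.  # example keys are an assumption
    """

    metadata_column: str
```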
```python
    metadata_column: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        df[self.metadata_column] = df[self.metadata_column].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
```
If `x` is an invalid Python literal, `ast.literal_eval()` will raise an exception. Do we want to handle this gracefully here, or is it okay to let the program crash in this case?
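If we do want graceful handling, one possible sketch (the helper name is hypothetical):

```python
import ast


def _safe_literal_eval(x):
    """Hypothetical helper: return None instead of raising on bad literals."""
    if not isinstance(x, str):
        return x
    try:
        return ast.literal_eval(x)
    except (ValueError, SyntaxError):
        # Invalid Python literal; treat as missing metadata rather than crash.
        return None
```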
```python
        return df

    def explode_metadata(self, df):
        # TODO this would break if the first row does not have all the metrics, e.g. due to invalid inference results
```
If the first row has a None, there will be an exception since None does not have a keys() attribute. But it seems like this case should not happen; flagging just in case.
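If we did want to guard against it, a small sketch (helper name hypothetical):

```python
def _first_metadata_keys(metadata_values):
    """Hypothetical helper: find the first row that is actually a dict,
    instead of assuming the first row has all the metric keys."""
    for value in metadata_values:
        if isinstance(value, dict):
            return list(value.keys())
    return []
```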
```python
def __init__(self, is_correct_column_name, is_incorrect_column_name, is_not_attempted_column_name, output_dir, group_by=None, **kwargs):
    """
    args:
        - is_correct_column_name (str): The name of the column containing the correct values.
```
Is this column a boolean indicating whether the response was correct, or a count of correct responses? It may be good to clarify this in the docstring.
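For instance, assuming it is a boolean flag, the docstring entry could read:

```python
"""
args:
    - is_correct_column_name (str): Name of the column holding a boolean
      flag that is True when the response was graded CORRECT.
      (Assumes a boolean; reword if it is actually a count.)
"""
```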
```python
def process_row(self, row):
    grading_response = row["model_output"]
    if grading_response is None or str(grading_response) == "nan":
```
nit: An alternative to `str(grading_response) == "nan"` could be `pd.isna(grading_response)`. I think this would usually be safer, but up to you.
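For reference, a quick comparison of the two checks:

```python
import pandas as pd

print(pd.isna(None))          # True
print(pd.isna(float("nan")))  # True
print(pd.isna("nan"))         # False: a literal "nan" string is real data,
                              # whereas str(x) == "nan" would flag it as missing
```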
```
@@ -0,0 +1,78 @@
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
```
nit: This instruction seems slightly misaligned with the final one:

```
Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
C: NOT_ATTEMPTED

Just return the letters "A", "B", or "C", with no text around it.
```
```
- For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".

Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
```
Same as above: the model should output A, B, or C, but this instruction tells the model to output CORRECT, INCORRECT, NOT ATTEMPTED.
| "path": "lighteval/SimpleQA", | ||
| "split": "test", | ||
| "transform": SequenceTransform([ | ||
| SamplerTransform(sample_count=100, random_seed=42), |
Should we remove the sampling before submitting?
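i.e., something along these lines (sketch; the surrounding transforms are elided):

```python
"transform": SequenceTransform([
    # SamplerTransform(sample_count=100, random_seed=42),  # debug-only sampling; drop before merging
    ...
]),
```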
| "SimpleQA_Metric_grade": "SimpleQA_Metric_grade_onerun", | ||
| } | ||
| ), | ||
| AddColumn("SimpleQA_Metric_grade"), |
Just for my understanding, why are these two transforms needed? Couldn't we just use `SimpleQA_Metric_grade`?
```python
),
AddColumn("SimpleQA_Metric_grade"),
MajorityVoteTransform(model_output_col="SimpleQA_Metric_grade_onerun", model_label_column="SimpleQA_Metric_is_correct"),
RunPythonTransform("df = df.rename_axis(index={'data_point_id': 'data_point_id_idx'})"),
```
Also just so I can understand: Why is it needed to rename the index?
```python
self.data_processing_comp.data_reader_config.init_args["path"] = "google/simpleqa-verified"
self.data_processing_comp.data_reader_config.init_args["split"] = "eval"
num_transforms = len(self.data_processing_comp.data_reader_config.init_args["transform"].transforms)
self.data_processing_comp.data_reader_config.init_args["transform"].transforms = self.data_processing_comp.data_reader_config.init_args["transform"].transforms[0:num_transforms-1]  # remove last transform which is SimpleQA_MetadataExplode
```
Can we add a comment explaining why the explode is not necessary in this case? Maybe the verified dataset does not have the metadata column which is being exploded?
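e.g., something like this (the stated reason is only a guess and would need confirming):

```python
# SimpleQA-Verified does not include the metadata column, so the final
# SimpleQA_MetadataExplode transform is not needed here.
# (Assumed reason; please confirm before adding the comment.)
```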
```python
if grading_response is None or str(grading_response) == "nan":
    grade_letter = "C"  # Default to "NOT_ATTEMPTED" if there is no grading response
else:
    match = re.search(r"(A|B|C)", grading_response)
```
If the grading model outputs something other than only one of A, B, or C, then this regex may result in false positives. For example:

```python
import re
re.search(r"(A|B|C)", "A grading response that has some of the letters B, C")
```

would match the article 'A' at the beginning of the sentence. Maybe a more strict regex would do?

```python
match = re.search(r"^(A|B|C)$", grading_response)
```
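A quick sanity check of the stricter pattern:

```python
import re

strict = re.compile(r"^(A|B|C)$")
print(bool(strict.search("A grading response that has some of the letters B, C")))  # False
print(bool(strict.search("B")))  # True
# re.fullmatch(r"A|B|C", text) would be an equivalent alternative.
```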
```python
if self.group_by == 'data_repeat_id':
    self._aggregate(data)
else:
    original_group_by = self.group_by
```
Do we need this `original_group_by` variable? Could we just use `self.group_by` instead?
Maybe we can pull out some of the processing functions into helpers just to make each function smaller?
Example:
```python
import re

import pandas as pd

from eureka_ml_insights.metrics.metrics_base import CompositeMetric
from eureka_ml_insights.metrics.reports import NumericalAggregator

GRADE_REGEX = re.compile(r"^(A|B|C)$")
GRADE_DEFAULT = "C"


def _parse_grade(response):
    """Extract grade A/B/C from the model response."""
    if response is None or pd.isna(response):
        return GRADE_DEFAULT
    match = GRADE_REGEX.search(str(response))
    return match.group(0) if match else GRADE_DEFAULT


class SimpleQA_Metric(CompositeMetric):
    """
    Composite metric for evaluating SimpleQA responses.
    """

    def __evaluate__(self, row):
        return self._process_row(row)

    def _process_row(self, row):
        # Kept as a method so that self._process_row resolves correctly.
        grade = _parse_grade(row["model_output"])
        return {
            "grade": grade,
            "is_correct": grade == "A",
            "is_incorrect": grade == "B",
            "is_not_attempted": grade == "C",
        }


class SQA_CGAAggregator(NumericalAggregator):
    """
    Computes accuracy = correct / attempted,
    where attempted = correct + incorrect.
    """

    def __init__(self, correct_col, incorrect_col, not_attempted_col, output_dir, group_by=None, **kwargs):
        super().__init__(
            [correct_col, incorrect_col, not_attempted_col],
            output_dir,
            group_by=group_by,
            **kwargs,
        )
        self.correct_col = correct_col
        self.incorrect_col = incorrect_col
        self.not_attempted_col = not_attempted_col

    def _compute_accuracy(self, df):
        attempted = df[self.correct_col].sum() + df[self.incorrect_col].sum()
        if attempted == 0:
            return 0.0
        return df[self.correct_col].sum() / attempted

    def _aggregate(self, data):
        self.aggregated_result = {"accuracy_given_attempted": self._compute_accuracy(data)}

    def _aggregate_grouped(self, data):
        results = {name: self._compute_accuracy(group) for name, group in data.groupby(self.group_by)}
        self.aggregated_result = {"accuracy_given_attempted": results}


class SQA_CGAAvgPass1Aggregator(SQA_CGAAggregator):
    """
    Computes accuracy per repeat (if present), then averages across repeats.
    """

    def _aggregate(self, data):
        # Without repeats, fall back to the plain aggregation.
        if "data_repeat_id" not in data.columns:
            return super()._aggregate(data)
        # Temporarily override grouping, restoring it afterwards.
        original_group_by = self.group_by
        self.group_by = "data_repeat_id"
        try:
            super()._aggregate_grouped(data)
        finally:
            self.group_by = original_group_by
        group_results = list(self.aggregated_result["accuracy_given_attempted"].values())
        mean_acc = sum(group_results) / len(group_results) if group_results else 0.0
        self.aggregated_result = {"accuracy_given_attempted": mean_acc}

    def _aggregate_grouped(self, data):
        # If grouping by repeat, use the special logic above.
        if self.group_by == "data_repeat_id":
            return self._aggregate(data)
        # Otherwise, compute accuracy within each external group.
        results = {}
        for name, group in data.groupby(self.group_by):
            super()._aggregate(group)
            results[name] = self.aggregated_result["accuracy_given_attempted"]
        self.aggregated_result = {"accuracy_given_attempted": results}
```
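A quick usage check of the parsing helper above:

```python
print(_parse_grade("B"))    # "B"
print(_parse_grade(None))   # "C" (default when there is no grading response)
print(_parse_grade("A grading response mentioning B"))  # "C": no strict match
```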
This PR onboards SimpleQA based on https://github.com/openai/simple-evals/blob/main/simpleqa_eval.py