
Conversation

@vidhishanair (Collaborator):

from .transform import DFTransformBase

@dataclass
class SimpleQA_MetadataExplode(DFTransformBase):
Contributor:
nit: Would it be good to document the expected format of the metadata column? Just so it's easier for someone reading this file to know what it should look like without having to look at other parts of the code.
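
For instance, something like this (a sketch; the example keys are illustrative, assuming the column holds a dict, or its string repr, of metadata fields):

    class SimpleQA_MetadataExplode(DFTransformBase):
        """Explodes the metadata column into one column per metadata key.

        The column is expected to contain a dict or its string repr,
        e.g. "{'topic': 'Science', 'answer_type': 'Person'}"; string
        values are parsed with ast.literal_eval before exploding.
        """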

    metadata_column: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        df[self.metadata_column] = df[self.metadata_column].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
Contributor:
If x is an invalid Python literal, ast.literal_eval() will raise an exception. Do we want to handle this gracefully here or is it okay to let the program crash in this case?
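
If graceful handling is preferred, one option (a sketch, not necessarily the error policy this project wants) would be a small wrapper:

    def _safe_literal_eval(x):
        # ast.literal_eval raises ValueError or SyntaxError on malformed input;
        # fall back to None instead of crashing the whole transform on one bad row.
        if not isinstance(x, str):
            return x
        try:
            return ast.literal_eval(x)
        except (ValueError, SyntaxError):
            return None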

        return df

    def explode_metadata(self, df):
        # TODO: this would break if the first row does not have all the metrics, e.g. due to invalid inference results
Contributor:
If the first row has a None, there will be an exception since None does not have a keys() method. But it seems like this case should not happen. Flagging just in case.
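
If it ever does need handling, a minimal guard (hypothetical; assumes the exploded values are dicts) could look like:

    first = df[self.metadata_column].iloc[0]
    keys = first.keys() if isinstance(first, dict) else []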

    def __init__(self, is_correct_column_name, is_incorrect_column_name, is_not_attempted_column_name, output_dir, group_by=None, **kwargs):
        """
        args:
        - is_correct_column_name (str): The name of the column containing the correct values.
Contributor:
Is this column a boolean indicating whether the response was correct, or a count of correct responses? It may be good to clarify this in the docstring.
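
Judging from the per-row grading suggested later in this thread (is_correct = grade == "A"), it looks like a per-row boolean; if so, the docstring entry might read:

    - is_correct_column_name (str): Name of the boolean column indicating
      whether each graded response was judged correct.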


    def process_row(self, row):
        grading_response = row["model_output"]
        if grading_response is None or str(grading_response)=="nan":
Contributor:
nit: An alternative to

str(grading_response)=="nan"

could be

pd.isna(grading_response)

I think this would usually be safer, but up to you.
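
For reference, pd.isna covers both branches of the current check, with one behavioral difference worth noting:

    import pandas as pd

    pd.isna(None)          # True  (covers the `is None` branch)
    pd.isna(float("nan"))  # True  (covers the str(x) == "nan" branch)
    pd.isna("nan")         # False (the literal string "nan" is NOT treated as
                           #        missing, unlike with the str() comparison)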

@@ -0,0 +1,78 @@
Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
Contributor:
nit: This instruction seems slightly misaligned with the final one:

Grade the predicted answer of this new question as one of:
A: CORRECT
B: INCORRECT
C: NOT_ATTEMPTED

Just return the letters "A", "B", or "C", with no text around it.

- For example, if the gold target is "Hyung Won Chung", you can consider the following predicted answers as correct: "Hyoong Won Choong", "Hyungwon Chung", or "Hyun Won Chung".


Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Don't apologize or correct yourself if there was a mistake; we are just trying to grade the answer.
Contributor:
Same as above: the model should output A, B, or C, but this instruction tells the model to output CORRECT, INCORRECT, NOT ATTEMPTED.

"path": "lighteval/SimpleQA",
"split": "test",
"transform": SequenceTransform([
SamplerTransform(sample_count=100, random_seed=42),
Contributor:
Should we remove the sampling before submitting?

"SimpleQA_Metric_grade": "SimpleQA_Metric_grade_onerun",
}
),
AddColumn("SimpleQA_Metric_grade"),
Contributor:
Just for my understanding, why are these two transforms needed? Couldn't we just use SimpleQA_Metric_grade?

    ),
    AddColumn("SimpleQA_Metric_grade"),
    MajorityVoteTransform(model_output_col="SimpleQA_Metric_grade_onerun", model_label_column="SimpleQA_Metric_is_correct"),
    RunPythonTransform("df = df.rename_axis(index={'data_point_id': 'data_point_id_idx'})"),
Contributor:
Also, just so I understand: why is it necessary to rename the index?

self.data_processing_comp.data_reader_config.init_args["path"] = "google/simpleqa-verified"
self.data_processing_comp.data_reader_config.init_args["split"] = "eval"
num_transforms = len(self.data_processing_comp.data_reader_config.init_args["transform"].transforms)
# Remove the last transform, which is SimpleQA_MetadataExplode
self.data_processing_comp.data_reader_config.init_args["transform"].transforms = (
    self.data_processing_comp.data_reader_config.init_args["transform"].transforms[: num_transforms - 1]
)
Contributor:
Can we add a comment explaining why the explode is not necessary in this case? Maybe the verified dataset does not have the metadata column which is being exploded?

        if grading_response is None or str(grading_response)=="nan":
            grade_letter = "C"  # Default to "NOT_ATTEMPTED" if there is no grading response
        else:
            match = re.search(r"(A|B|C)", grading_response)
Contributor:
If the grading model outputs something other than only one of A, B, or C, then this regex may result in false positives.

For example:

import re
re.search(r"(A|B|C)", "A grading response that has some of the letters B, C")

would match the article 'A' at the beginning of the sentence.

Maybe a more strict regex would do?

match = re.search(r"^(A|B|C)$", grading_response)
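
re.fullmatch would be equally strict while also sidestepping the subtlety that $ still matches before a trailing newline (the .strip() is an optional tolerance for surrounding whitespace):

    match = re.fullmatch(r"[ABC]", grading_response.strip())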

        if self.group_by == 'data_repeat_id':
            self._aggregate(data)
        else:
            original_group_by = self.group_by
Contributor:
Do we need this original_group_by variable? Could we just use self.group_by instead?

Contributor:
Maybe we can pull out some of the processing functions into helpers just to make each function smaller?

Example:

import re
import pandas as pd
from eureka_ml_insights.metrics.metrics_base import CompositeMetric
from eureka_ml_insights.metrics.reports import NumericalAggregator

GRADE_REGEX = re.compile(r"^(A|B|C)$")
GRADE_DEFAULT = "C"


def _parse_grade(response):
    """Extract grade A/B/C from the model response, defaulting to C."""
    if response is None or pd.isna(response):
        return GRADE_DEFAULT

    match = GRADE_REGEX.search(str(response))
    return match.group(0) if match else GRADE_DEFAULT


class SimpleQA_Metric(CompositeMetric):
    """
    Composite metric for evaluating SimpleQA responses.
    """

    def _process_row(self, row):
        # Map the parsed grade letter onto the per-row boolean columns.
        grade = _parse_grade(row["model_output"])

        return {
            "grade": grade,
            "is_correct": grade == "A",
            "is_incorrect": grade == "B",
            "is_not_attempted": grade == "C",
        }

    def __evaluate__(self, row):
        return self._process_row(row)


class SQA_CGAAggregator(NumericalAggregator):
    """
    Computes accuracy = correct / attempted,
    where attempted = correct + incorrect.
    """

    def __init__(self, correct_col, incorrect_col, not_attempted_col, output_dir, group_by=None, **kwargs):
        super().__init__(
            [correct_col, incorrect_col, not_attempted_col],
            output_dir,
            group_by=group_by,
            **kwargs,
        )
        self.correct_col = correct_col
        self.incorrect_col = incorrect_col
        self.not_attempted_col = not_attempted_col

    def _compute_accuracy(self, df):
        attempted = df[self.correct_col].sum() + df[self.incorrect_col].sum()
        if attempted == 0:
            return 0.0
        return df[self.correct_col].sum() / attempted

    def _aggregate(self, data):
        self.aggregated_result = {
            "accuracy_given_attempted": self._compute_accuracy(data)
        }

    def _aggregate_grouped(self, data):
        grouped = data.groupby(self.group_by)
        results = {}

        for name, group in grouped:
            results[name] = self._compute_accuracy(group)

        self.aggregated_result = {"accuracy_given_attempted": results}


class SQA_CGAAvgPass1Aggregator(SQA_CGAAggregator):
    """
    Computes accuracy per repeat (if present), then averages across repeats.
    """

    def _aggregate(self, data):
        # Handle pass-1 repeat grouping
        if "data_repeat_id" not in data.columns:
            return super()._aggregate(data)

        # Temporarily override grouping, restoring the original afterwards
        original_group_by = self.group_by
        self.group_by = "data_repeat_id"
        super()._aggregate_grouped(data)
        self.group_by = original_group_by

        group_results = self.aggregated_result["accuracy_given_attempted"].values()
        mean_acc = sum(group_results) / len(group_results) if group_results else 0.0
        self.aggregated_result = {"accuracy_given_attempted": mean_acc}

    def _aggregate_grouped(self, data):
        # If grouping by repeat, use the special logic above
        if self.group_by == "data_repeat_id":
            return self._aggregate(data)

        # Otherwise, compute accuracy within each external group
        grouped = data.groupby(self.group_by)
        results = {}

        for name, group in grouped:
            super()._aggregate(group)
            results[name] = self.aggregated_result["accuracy_given_attempted"]

        self.aggregated_result = {"accuracy_given_attempted": results}
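
A quick sanity check of the _parse_grade helper above, with expected outputs as comments:

    _parse_grade("A")                              # "A"
    _parse_grade(None)                             # "C" (default)
    _parse_grade("A response that mentions B, C")  # "C" -- anchors prevent the false positive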
