6 changes: 6 additions & 0 deletions tests/lmeval/configs/w4a4_nvfp4.yaml
@@ -5,6 +5,12 @@ dataset_id: HuggingFaceH4/ultrachat_200k
dataset_split: train_sft
num_calibration_samples: 20
lmeval:
# NVFP4 (4-bit weights + 4-bit activations) has lower recovery than FP8/INT8
# Observed: strict-match ~92.81%, flexible-extract ~89.59%
recovery_threshold:
exact_match,strict-match: 0.92
exact_match,flexible-extract: 0.89
# Absolute metrics for warnings only
metrics:
exact_match,flexible-extract: 0.70
exact_match,strict-match: 0.65
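
As a rough illustration (not part of this PR), the lmeval block above maps onto the LmEvalConfig model added in tests/lmeval/test_lmeval.py below; the harness's actual loading code may differ, so treat this as a sketch under that assumption:

import yaml
from typing import Optional, Union
from pydantic import BaseModel

class LmEvalConfig(BaseModel):
    # Trimmed copy of the fields exercised by w4a4_nvfp4.yaml
    recovery_threshold: Union[float, dict] = 0.95
    metrics: Optional[dict] = None

with open("tests/lmeval/configs/w4a4_nvfp4.yaml") as f:
    cfg = yaml.safe_load(f)

lmeval = LmEvalConfig(**cfg["lmeval"])
# Per-metric thresholds come straight from the YAML; any metric not listed
# falls back to the 0.95 default declared on the model.
assert lmeval.recovery_threshold["exact_match,strict-match"] == 0.92
assert lmeval.metrics["exact_match,flexible-extract"] == 0.70
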
210 changes: 176 additions & 34 deletions tests/lmeval/test_lmeval.py
@@ -2,6 +2,7 @@
import random
import shutil
from pathlib import Path
from typing import Optional, Union

import numpy
import pandas as pd
@@ -23,8 +24,12 @@ class LmEvalConfig(BaseModel):
task: str = "gsm8k"
num_fewshot: int = 5
limit: int = 1000
metrics: dict
batch_size: int = 100
# Recovery testing (default): compare against base model performance
# Default threshold is 0.95 (retain ≥95% of base), can be overridden
recovery_threshold: Union[float, dict] = 0.95
# Optional absolute metrics for warnings (not failures)
metrics: Optional[dict] = None


try:
@@ -62,6 +67,16 @@ class TestLMEval:
or another identifier which can be used for the particular test case. If a recipe
is not provided, it is assumed that the scheme provided is a preset scheme and will
be used for quantization. Otherwise, the provided recipe is always used.

Recovery Testing (DEFAULT):
Tests now use recovery-based validation by default, comparing compressed model
performance against the base model. Default threshold is 0.95 (≥95% recovery).

Config options:
- recovery_threshold: 0.95 (default if not specified)
- recovery_threshold: 0.93 (override default globally)
- recovery_threshold: {"metric1": 0.95, "metric2": 0.90} (per-metric)
- metrics: {...} (optional - used for warnings only, not failures)
""" # noqa: E501

def set_up(self, test_data_file: str):
@@ -89,6 +104,11 @@ def set_up(self, test_data_file: str):

logger.info("========== RUNNING ==============")
logger.info(self.scheme)
logger.info(
f"Recovery threshold: {self.lmeval.recovery_threshold} (default: 0.95)"
)
if self.lmeval.metrics:
logger.info("Absolute metrics provided - will show warnings if outside ±5%")

self.num_calibration_samples = eval_config.get("num_calibration_samples", 512)
self.max_seq_length = 2048
@@ -97,6 +117,10 @@ def test_lm_eval(self, test_data_file: str):
# Run vLLM with saved model
self.set_up(test_data_file)

# Always evaluate base model for recovery testing
logger.info("================= Evaluating BASE model ======================")
self.base_results = self._eval_base_model()

if not self.save_dir:
self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
oneshot_model, processor = run_oneshot_for_e2e_testing(
@@ -119,11 +143,28 @@ def test_lm_eval(self, test_data_file: str):
# Reset session for next test case
self._handle_recipe()

logger.info("================= Running LM Eval ======================")
logger.info("================= Running LM Eval on COMPRESSED model ==========")
self._run_lm_eval()

self.tear_down()

@log_time
def _eval_base_model(self):
"""Evaluate the base (uncompressed) model."""
model_args = {**self.lmeval.model_args, "pretrained": self.model}

results = lm_eval.simple_evaluate(
model=self.lmeval.model,
model_args=model_args,
tasks=[self.lmeval.task],
num_fewshot=self.lmeval.num_fewshot,
limit=self.lmeval.limit,
device="cuda:0",
batch_size=self.lmeval.batch_size,
)

return results

@log_time
def _save_compressed_model(self, oneshot_model, processor):
oneshot_model.save_pretrained(self.save_dir)
@@ -152,46 +193,147 @@ def _run_lm_eval(self):
batch_size=self.lmeval.batch_size,
)

# Always use recovery testing
self._validate_recovery(results)

# If absolute metrics provided, show warnings (not failures)
if self.lmeval.metrics:
self._check_absolute_warnings(results)

def _validate_recovery(self, compressed_results):
"""Validate using recovery testing - compare against base model."""
base_metrics = self.base_results["results"][self.lmeval.task]
compressed_metrics = compressed_results["results"][self.lmeval.task]
higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
self.lmeval.task, {}
)
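        # Note: lm_eval keys per-task results like "exact_match,strict-match",
        # while higher_is_better is keyed by the bare metric name ("exact_match"),
        # hence the split(",") lookup below.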

logger.info("=" * 80)
logger.info("RECOVERY TESTING COMPARISON")
logger.info("=" * 80)

# Get default threshold from config schema
default_threshold = self.lmeval.model_fields["recovery_threshold"].default

failures = []
# Iterate over compressed metrics (what we actually got)
for metric_key, compressed_val in compressed_metrics.items():
# Skip stderr and other metadata
if "stderr" in metric_key or metric_key.startswith("alias"):
continue

base_val = base_metrics.get(metric_key)
if base_val is None:
logger.warning(
f"Metric {metric_key} in compressed results "
f"not found in base results, skipping"
)
continue

# Get threshold for this metric
if isinstance(self.lmeval.recovery_threshold, dict):
threshold = self.lmeval.recovery_threshold.get(
metric_key, default_threshold
)
else:
threshold = self.lmeval.recovery_threshold

# Get direction
base_metric_name = metric_key.split(",")[0]
higher_is_better = higher_is_better_map.get(base_metric_name, True)

# Compute recovery
if base_val == 0:
recovery = 1.0 if compressed_val == 0 else 0.0
elif higher_is_better:
recovery = compressed_val / base_val
else:
# For "lower is better", invert ratio
recovery = base_val / compressed_val
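            # Worked example (hypothetical numbers): for a higher-is-better metric
            # with base_val = 0.80 and compressed_val = 0.76,
            # recovery = 0.76 / 0.80 = 0.95, exactly the default threshold.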

# Check threshold
passed = recovery >= threshold
direction = "↑" if higher_is_better else "↓"

msg = (
f"{metric_key:40} | Base: {base_val:.4f} | "
f"Compressed: {compressed_val:.4f} | "
f"Recovery: {recovery:6.2%} {direction} | Threshold: ≥{threshold:.2%}"
)

if passed:
logger.info(f"✓ {msg}")
else:
logger.error(f"✗ {msg}")
failures.append(
f"{metric_key}: {recovery:.2%} < {threshold:.2%} "
f"(base={base_val:.4f}, compressed={compressed_val:.4f})"
)

        # Warn if any metric listed in the recovery_threshold config is missing from results
if isinstance(self.lmeval.recovery_threshold, dict):
for config_metric_key in self.lmeval.recovery_threshold.keys():
if config_metric_key not in compressed_metrics:
logger.warning(
f"Metric {config_metric_key} in recovery_threshold config "
f"not found in results"
)

logger.info("=" * 80)

if failures:
failure_msg = "\n".join(failures)
raise AssertionError(f"Recovery testing failed:\n{failure_msg}")

logger.info("✓ ALL METRICS PASSED RECOVERY THRESHOLDS")
logger.info("=" * 80)

def _check_absolute_warnings(self, results):
"""Check absolute metrics and warn if outside ±5% tolerance (not a failure)."""
logger.info("=" * 80)
logger.info("ABSOLUTE METRICS CHECK (warnings only, not failures)")
logger.info("=" * 80)

metrics: dict = results["results"][self.lmeval.task]
for metric_key, expected_val in self.lmeval.metrics.items():
            # Skip stderr metrics
            if "stderr" in metric_key:
                continue

            actual_val = metrics.get(metric_key)
            if actual_val is None:
                logger.warning(
                    f"Metric {metric_key} in config not found in results, "
                    f"skipping warning check"
                )
                continue

higher_is_better = (
results.get("higher_is_better", {})
.get(self.lmeval.task, {})
.get(metric_key.split(",")[0], True)
)

# Check if within ±5% relative tolerance
lower_bound = expected_val * 0.95
upper_bound = expected_val * 1.05
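            # Example: expected_val = 0.70 gives bounds [0.665, 0.735]; a value
            # outside that range only logs a warning below and never fails the test.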

if higher_is_better:
# For higher is better, we care about lower bound
if actual_val < lower_bound:
logger.warning(
f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
f"Got: {actual_val:.4f} | Below expected range"
)
            else:
                # For lower is better, we care about upper bound
                if actual_val > upper_bound:
                    logger.warning(
                        f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
                        f"Got: {actual_val:.4f} | Above expected range"
                    )

logger.info("=" * 80)

def tear_down(self):
timer = get_singleton_manager()