
Commit cb8d775

Authored by rahul-tuli, gemini-code-assist[bot], and dsikka
[Tests] Add recovery-based validation to LM-Eval tests (#1750)
# Recovery-Based Testing for LM-Eval

This PR implements **recovery-based testing** as the default validation mechanism for all lm-eval tests. Tests now compare compressed model performance against base model performance, making them robust to upstream changes while still ensuring quantization quality.

**Current Problem:**
- Tests fail when base models regress due to external changes (e.g., transformers updates, lm-eval changes)
- False positives block CI even when quantization maintains the expected recovery
- Absolute thresholds become stale as models/libraries evolve

**Example:** Qwen2.5-VL tests fail with transformers ≥ 4.54.0 due to a ~10% base-model accuracy drop, despite quantization maintaining the same relative performance.

**Solution:** Recovery testing validates that compressed models retain ≥95% (configurable) of base model performance, regardless of absolute score changes.

---

## 🚀 New Behavior

### Default Behavior (Zero Config Required)

All lm-eval tests now **automatically**:
1. ✅ Evaluate the base (uncompressed) model
2. ✅ Quantize the model using the configured scheme
3. ✅ Evaluate the compressed model
4. ✅ Validate recovery ≥ 95% (default threshold)
5. ✅ Show optional warnings for absolute metrics

**Recovery Formula:**
```python
# For "higher is better" metrics (accuracy, F1, etc.)
recovery = compressed_score / base_score

# For "lower is better" metrics (perplexity, loss)
recovery = base_score / compressed_score  # Inverted!

# Validation
assert recovery >= threshold  # Default: 0.95
```

**Recovery Interpretation:**
- `1.00` = Perfect (0% degradation)
- `0.96` = 96% retained (4% degradation) ✅
- `0.93` = 93% retained (7% degradation) ❌ (with default threshold)

---

## 📝 Configuration Options

### Option 1: Use Default (Recommended)
No configuration needed - uses the 95% recovery threshold:
```yaml
cadence: "weekly"
model: meta-llama/Meta-Llama-3-8B-Instruct
scheme: FP8_DYNAMIC
lmeval: # That's it! Uses recovery_threshold: 0.95 by default
```

### Option 2: Override Global Threshold
Set a different threshold for all metrics:
```yaml
lmeval:
  recovery_threshold: 0.93  # All metrics need ≥93% recovery
```

### Option 3: Per-Metric Thresholds
Set different thresholds for different metrics:
```yaml
lmeval:
  recovery_threshold:
    exact_match,flexible-extract: 0.95  # Strict threshold
    exact_match,strict-match: 0.90  # Relaxed threshold
```

### Option 4: With Absolute Metric Warnings
Keep absolute metrics for informational warnings (not failures):
```yaml
lmeval:
  recovery_threshold: 0.95  # Required - TEST FAILS if not met
  metrics:  # Optional - warnings only, no failures
    exact_match,flexible-extract: 0.75
    exact_match,strict-match: 0.72
```

---

## Example Output

### ✅ Recovery Validation (Always Shown)
```
================================================================================
RECOVERY TESTING COMPARISON
================================================================================
✓ exact_match,flexible-extract | Base: 0.7890 | Compressed: 0.7601 | Recovery: 96.34% ↑ | Threshold: ≥95.00%
✓ exact_match,strict-match     | Base: 0.7564 | Compressed: 0.7262 | Recovery: 96.01% ↑ | Threshold: ≥95.00%
================================================================================
✓ ALL METRICS PASSED RECOVERY THRESHOLDS
================================================================================
```

### Absolute Metric Warnings (If Configured)
```
================================================================================
ABSOLUTE METRICS CHECK (warnings only, not failures)
================================================================================
✓ exact_match,flexible-extract | Expected: 0.7500 (±5%) | Got: 0.7601 | Within expected range
⚠ exact_match,strict-match     | Expected: 0.8000 (±5%) | Got: 0.7262 | Below expected range
================================================================================
```
**Note:** The warning above doesn't fail the test - recovery validation already passed!

---

## 🔄 Migration Guide

### Existing Configs with Absolute Metrics

**Before (absolute thresholds cause failures):**
```yaml
lmeval:
  metrics:
    exact_match: 0.75  # TEST FAILS if not met
```

**After (minimal - uses recovery testing):**
```yaml
lmeval:
  # Uses default recovery_threshold: 0.95
  # No other config needed!
```

**After (keep warnings):**
```yaml
lmeval:
  # recovery_threshold: 0.95 is implicit (default)
  metrics:  # Now just warnings, won't fail tests
    exact_match: 0.75
```

### No Breaking Changes
- ✅ All existing configs continue to work
- ✅ `metrics` dict now shows warnings instead of failing
- ✅ Recovery testing automatically enabled with a sensible default
- ✅ Backward compatible with all test infrastructure

---

## Implementation Details

### Files Changed
- **`tests/lmeval/test_lmeval.py`** (+151/-31 lines)
  - Added `recovery_threshold` config field (default: 0.95)
  - Made `metrics` field optional
  - Added `_eval_base_model()` method
  - Added `_validate_recovery()` method
  - Modified `_check_absolute_warnings()` to only warn, not fail
  - Updated test flow to always evaluate the base model first

### Key Features
1. **Direction-Aware Recovery**
   - Automatically detects "higher is better" vs "lower is better" metrics
   - Inverts the ratio for perplexity-style metrics
2. **Edge Case Handling**
   - Zero base values: `recovery = 1.0 if compressed == 0 else 0.0`
   - Missing metrics: skipped gracefully
   - Metadata filtering: skips stderr and alias keys
3. **Flexible Thresholds**
   - Global float: `recovery_threshold: 0.93`
   - Per-metric dict: `recovery_threshold: {metric1: 0.95, metric2: 0.90}`
   - Fallback to 0.95 for unlisted metrics when using a dict
4. **Comprehensive Logging**
   - Recovery threshold displayed at test start
   - Detailed comparison table with base/compressed/recovery values
   - Clear pass/fail indicators with direction arrows (↑/↓)
   - Separate section for optional absolute warnings

---

## Performance Impact

**Additional Runtime:**
- Base model evaluation: ~2-10 minutes
- Compressed model evaluation: ~2-10 minutes (unchanged)
- **Total: ~2x single evaluation time**

**Trade-off:** Doubled evaluation time for robust, meaningful metrics that don't break from upstream changes.

**Mitigation:** Tests run on a weekly cadence, making the additional time acceptable.

---

## ✅ Benefits

| Benefit | Description |
|---------|-------------|
| 🛡️ **Robustness** | Tests don't break when lm-eval or transformers update |
| 📊 **Meaningful** | Measures actual compression degradation, not arbitrary thresholds |
| 🎯 **Automatic** | Works out of the box, no config needed |
| 🔧 **Flexible** | Override the threshold globally or per-metric |
| ↔️ **Compatible** | Zero breaking changes, existing configs work |
| 🧹 **Simple** | ~150 lines in a single file, no new dependencies |

---

## Testing

To test recovery-based validation:
```bash
# Uses default recovery threshold (0.95)
CADENCE=weekly TEST_DATA_FILE=tests/lmeval/configs/fp8_dynamic_per_token.yaml \
pytest tests/lmeval/test_lmeval.py -v
```

---

---------

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Signed-off-by: rahul-tuli <rtuli@redhat.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
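As a sanity check on the recovery formula in the description above, here is a minimal, self-contained sketch; the `compute_recovery` helper and the sample scores are illustrative, not code from this PR:

```python
# Hypothetical helper mirroring the direction-aware recovery formula above.
def compute_recovery(base: float, compressed: float, higher_is_better: bool = True) -> float:
    """Fraction of base-model performance the compressed model retains."""
    if base == 0:
        # Edge case noted in the PR: perfect recovery only if compressed is also 0
        return 1.0 if compressed == 0 else 0.0
    if higher_is_better:
        return compressed / base  # accuracy, F1, ...
    return base / compressed      # perplexity, loss: lower is better, so invert


# 0.7601 / 0.7890 ≈ 96.3% -> passes the default 0.95 threshold
assert compute_recovery(0.7890, 0.7601) >= 0.95
# Perplexity rising from 8.0 to 8.6 -> 8.0 / 8.6 ≈ 93.0% -> fails 0.95
assert compute_recovery(8.0, 8.6, higher_is_better=False) < 0.95
```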
1 parent d180851 · commit cb8d775

File tree

2 files changed: +182 -34 lines changed


tests/lmeval/configs/w4a4_nvfp4.yaml

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@ dataset_id: HuggingFaceH4/ultrachat_200k
 dataset_split: train_sft
 num_calibration_samples: 20
 lmeval:
+  # NVFP4 (4-bit weights + 4-bit activations) has lower recovery than FP8/INT8
+  # Observed: strict-match ~92.81%, flexible-extract ~89.59%
+  recovery_threshold:
+    exact_match,strict-match: 0.92
+    exact_match,flexible-extract: 0.89
+  # Absolute metrics for warnings only
   metrics:
     exact_match,flexible-extract: 0.70
     exact_match,strict-match: 0.65

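For readers cross-checking the config above: a short sketch of how a per-metric `recovery_threshold` mapping would resolve for each metric, assuming the fallback-to-default behavior described in the PR (the `resolve_threshold` helper is hypothetical):

```python
from typing import Union

DEFAULT_THRESHOLD = 0.95  # schema default for recovery_threshold


def resolve_threshold(recovery_threshold: Union[float, dict], metric_key: str) -> float:
    """Per-metric threshold if a dict is configured, else the global float."""
    if isinstance(recovery_threshold, dict):
        # Metrics not listed in the dict fall back to the 0.95 default
        return recovery_threshold.get(metric_key, DEFAULT_THRESHOLD)
    return recovery_threshold


# Mirrors the w4a4_nvfp4.yaml config above
config = {"exact_match,strict-match": 0.92, "exact_match,flexible-extract": 0.89}
assert resolve_threshold(config, "exact_match,strict-match") == 0.92
assert resolve_threshold(config, "other_metric") == DEFAULT_THRESHOLD  # fallback
assert resolve_threshold(0.93, "exact_match,strict-match") == 0.93  # global float
```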
tests/lmeval/test_lmeval.py

Lines changed: 176 additions & 34 deletions
@@ -2,6 +2,7 @@
 import random
 import shutil
 from pathlib import Path
+from typing import Optional, Union

 import numpy
 import pandas as pd
@@ -23,8 +24,12 @@ class LmEvalConfig(BaseModel):
     task: str = "gsm8k"
     num_fewshot: int = 5
     limit: int = 1000
-    metrics: dict
     batch_size: int = 100
+    # Recovery testing (default): compare against base model performance
+    # Default threshold is 0.95 (retain ≥95% of base), can be overridden
+    recovery_threshold: Union[float, dict] = 0.95
+    # Optional absolute metrics for warnings (not failures)
+    metrics: Optional[dict] = None


 try:
@@ -62,6 +67,16 @@ class TestLMEval:
     or another identifier which can be used for the particular test case. If a recipe
     is not provided, it is assumed that the scheme provided is a preset scheme and will
     be used for quantization. Otherwise, the recipe will always be used if given.
+
+    Recovery Testing (DEFAULT):
+    Tests now use recovery-based validation by default, comparing compressed model
+    performance against the base model. Default threshold is 0.95 (≥95% recovery).
+
+    Config options:
+    - recovery_threshold: 0.95 (default if not specified)
+    - recovery_threshold: 0.93 (override default globally)
+    - recovery_threshold: {"metric1": 0.95, "metric2": 0.90} (per-metric)
+    - metrics: {...} (optional - used for warnings only, not failures)
     """  # noqa: E501

     def set_up(self, test_data_file: str):
@@ -89,6 +104,11 @@ def set_up(self, test_data_file: str):

         logger.info("========== RUNNING ==============")
         logger.info(self.scheme)
+        logger.info(
+            f"Recovery threshold: {self.lmeval.recovery_threshold} (default: 0.95)"
+        )
+        if self.lmeval.metrics:
+            logger.info("Absolute metrics provided - will show warnings if outside ±5%")

         self.num_calibration_samples = eval_config.get("num_calibration_samples", 512)
         self.max_seq_length = 2048
@@ -97,6 +117,10 @@ def test_lm_eval(self, test_data_file: str):
         # Run vLLM with saved model
         self.set_up(test_data_file)

+        # Always evaluate base model for recovery testing
+        logger.info("================= Evaluating BASE model ======================")
+        self.base_results = self._eval_base_model()
+
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
         oneshot_model, processor = run_oneshot_for_e2e_testing(
@@ -119,11 +143,28 @@ def test_lm_eval(self, test_data_file: str):
         # Reset session for next test case
         self._handle_recipe()

-        logger.info("================= Running LM Eval ======================")
+        logger.info("================= Running LM Eval on COMPRESSED model ==========")
         self._run_lm_eval()

         self.tear_down()

+    @log_time
+    def _eval_base_model(self):
+        """Evaluate the base (uncompressed) model."""
+        model_args = {**self.lmeval.model_args, "pretrained": self.model}
+
+        results = lm_eval.simple_evaluate(
+            model=self.lmeval.model,
+            model_args=model_args,
+            tasks=[self.lmeval.task],
+            num_fewshot=self.lmeval.num_fewshot,
+            limit=self.lmeval.limit,
+            device="cuda:0",
+            batch_size=self.lmeval.batch_size,
+        )
+
+        return results
+
     @log_time
     def _save_compressed_model(self, oneshot_model, processor):
         oneshot_model.save_pretrained(self.save_dir)
@@ -152,46 +193,147 @@ def _run_lm_eval(self):
             batch_size=self.lmeval.batch_size,
         )

+        # Always use recovery testing
+        self._validate_recovery(results)
+
+        # If absolute metrics provided, show warnings (not failures)
+        if self.lmeval.metrics:
+            self._check_absolute_warnings(results)
+
+    def _validate_recovery(self, compressed_results):
+        """Validate using recovery testing - compare against base model."""
+        base_metrics = self.base_results["results"][self.lmeval.task]
+        compressed_metrics = compressed_results["results"][self.lmeval.task]
+        higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
+            self.lmeval.task, {}
+        )
+
+        logger.info("=" * 80)
+        logger.info("RECOVERY TESTING COMPARISON")
+        logger.info("=" * 80)
+
+        # Get default threshold from config schema
+        default_threshold = self.lmeval.model_fields["recovery_threshold"].default
+
+        failures = []
+        # Iterate over compressed metrics (what we actually got)
+        for metric_key, compressed_val in compressed_metrics.items():
+            # Skip stderr and other metadata
+            if "stderr" in metric_key or metric_key.startswith("alias"):
+                continue
+
+            base_val = base_metrics.get(metric_key)
+            if base_val is None:
+                logger.warning(
+                    f"Metric {metric_key} in compressed results "
+                    f"not found in base results, skipping"
+                )
+                continue
+
+            # Get threshold for this metric
+            if isinstance(self.lmeval.recovery_threshold, dict):
+                threshold = self.lmeval.recovery_threshold.get(
+                    metric_key, default_threshold
+                )
+            else:
+                threshold = self.lmeval.recovery_threshold

+            # Get direction
+            base_metric_name = metric_key.split(",")[0]
+            higher_is_better = higher_is_better_map.get(base_metric_name, True)
+
+            # Compute recovery
+            if base_val == 0:
+                recovery = 1.0 if compressed_val == 0 else 0.0
+            elif higher_is_better:
+                recovery = compressed_val / base_val
+            else:
+                # For "lower is better", invert ratio
+                recovery = base_val / compressed_val
+
+            # Check threshold
+            passed = recovery >= threshold
+            direction = "↑" if higher_is_better else "↓"
+
+            msg = (
+                f"{metric_key:40} | Base: {base_val:.4f} | "
+                f"Compressed: {compressed_val:.4f} | "
+                f"Recovery: {recovery:6.2%} {direction} | Threshold: ≥{threshold:.2%}"
+            )
+
+            if passed:
+                logger.info(f"✓ {msg}")
+            else:
+                logger.error(f"✗ {msg}")
+                failures.append(
+                    f"{metric_key}: {recovery:.2%} < {threshold:.2%} "
+                    f"(base={base_val:.4f}, compressed={compressed_val:.4f})"
+                )
+
+        # Validate that config thresholds match actual results
+        if isinstance(self.lmeval.recovery_threshold, dict):
+            for config_metric_key in self.lmeval.recovery_threshold.keys():
+                if config_metric_key not in compressed_metrics:
+                    logger.warning(
+                        f"Metric {config_metric_key} in recovery_threshold config "
+                        f"not found in results"
+                    )
+
+        logger.info("=" * 80)
+
+        if failures:
+            failure_msg = "\n".join(failures)
+            raise AssertionError(f"Recovery testing failed:\n{failure_msg}")
+
+        logger.info("✓ ALL METRICS PASSED RECOVERY THRESHOLDS")
+        logger.info("=" * 80)
+
+    def _check_absolute_warnings(self, results):
+        """Check absolute metrics and warn if outside ±5% tolerance (not a failure)."""
+        logger.info("=" * 80)
+        logger.info("ABSOLUTE METRICS CHECK (warnings only, not failures)")
+        logger.info("=" * 80)
+
         metrics: dict = results["results"][self.lmeval.task]
         for metric_key, expected_val in self.lmeval.metrics.items():
-            # stderr metrics are only used as absolute tolerance
-            # checks for actual values
+            # Skip stderr metrics
             if "stderr" in metric_key:
                 continue
+
             actual_val = metrics.get(metric_key)
-            higher_is_better = results["higher_is_better"][self.lmeval.task].get(
-                metric_key.split(",")[0], True
-            )
-            stderr_key = metric_key.replace(",", "_stderr,")
-            std_err = self.lmeval.metrics.get(stderr_key)
-
-            # If stderr is provided, use it as absolute tolerance
-            # Otherwise, default to a 5% relative tolerance
-            if std_err is None:
-                logger.info(
-                    f"Comparing {metric_key}: Expecting {expected_val} "
-                    f"relative tolerance ±5%, Got {actual_val}. "
-                    f"Higher is better: {higher_is_better}"
+            if actual_val is None:
+                logger.warning(
+                    f"Metric {metric_key} in config not found in results, "
+                    f"skipping warning check"
                 )
-                # If higher is better, assert actual val >= expected val * (1 - stderr)
-                if higher_is_better:
-                    assert actual_val >= expected_val * (0.95)
-                # If higher is worse, assert actual val <= expected val * (1 + stderr)
-                else:
-                    assert actual_val <= expected_val * (1.05)
+                continue
+
+            higher_is_better = (
+                results.get("higher_is_better", {})
+                .get(self.lmeval.task, {})
+                .get(metric_key.split(",")[0], True)
+            )

+            # Check if within ±5% relative tolerance
+            lower_bound = expected_val * 0.95
+            upper_bound = expected_val * 1.05
+
+            if higher_is_better:
+                # For higher is better, we care about lower bound
+                if actual_val < lower_bound:
+                    logger.warning(
+                        f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
+                        f"Got: {actual_val:.4f} | Below expected range"
+                    )
             else:
-                logger.info(
-                    f"Comparing {metric_key}: Expecting {expected_val} "
-                    f"absolute tolerance ±{std_err*100}%, Got {actual_val}. "
-                    f"Higher is better: {higher_is_better}"
-                )
-                # If higher is better, assert actual val >= expected val - stderr
-                if higher_is_better:
-                    assert actual_val >= expected_val - std_err
-                # If higher is worse, assert actual val <= expected val + stderr
-                else:
-                    assert actual_val <= expected_val + std_err
+                # For lower is better, we care about upper bound
+                if actual_val > upper_bound:
+                    logger.warning(
+                        f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
+                        f"Got: {actual_val:.4f} | Above expected range"
+                    )
+
+        logger.info("=" * 80)

     def tear_down(self):
         timer = get_singleton_manager()

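To see the validation flow from the diff end to end, here is a self-contained sketch that runs the same direction-aware recovery loop over fabricated results in lm-eval's results layout; the task name, scores, and threshold are illustrative only:

```python
# Fabricated base/compressed results in the layout _validate_recovery consumes.
base = {"results": {"gsm8k": {
    "exact_match,strict-match": 0.7564,
    "exact_match,flexible-extract": 0.7890,
    "alias": "gsm8k",
}}}
compressed = {
    "results": {"gsm8k": {
        "exact_match,strict-match": 0.7262,
        "exact_match,flexible-extract": 0.7601,
        "alias": "gsm8k",
    }},
    "higher_is_better": {"gsm8k": {"exact_match": True}},
}

task, threshold, failures = "gsm8k", 0.95, []
hib_map = compressed.get("higher_is_better", {}).get(task, {})
for key, comp_val in compressed["results"][task].items():
    if "stderr" in key or key.startswith("alias"):
        continue  # skip metadata keys, as _validate_recovery does
    base_val = base["results"][task].get(key)
    if base_val is None:
        continue  # metric missing from base results: skipped gracefully
    hib = hib_map.get(key.split(",")[0], True)
    recovery = comp_val / base_val if hib else base_val / comp_val
    if recovery < threshold:
        failures.append(f"{key}: {recovery:.2%} < {threshold:.2%}")

assert not failures  # both metrics here recover ≥ 95% of the base scores
```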