From a847cc7c5af9f9b091fa20bc8cbf9984260fbaa9 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:50:39 +0000 Subject: [PATCH 1/9] Add recovery testing config schema Introduce recovery_threshold field with default value of 0.95, making recovery-based testing the default behavior. Support both global float thresholds and per-metric dict thresholds. Make metrics field optional since it's now used only for warnings rather than test validation. Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 86020865f..0ca9452d7 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -2,6 +2,7 @@ import random import shutil from pathlib import Path +from typing import Optional, Union import numpy import pandas as pd @@ -23,8 +24,12 @@ class LmEvalConfig(BaseModel): task: str = "gsm8k" num_fewshot: int = 5 limit: int = 1000 - metrics: dict batch_size: int = 100 + # Recovery testing (default): compare against base model performance + # Default threshold is 0.95 (retain ≥95% of base), can be overridden + recovery_threshold: Union[float, dict] = 0.95 + # Optional absolute metrics for warnings (not failures) + metrics: Optional[dict] = None try: From c2f06f60673c17a89df77da0511b58c7200594dd Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:51:00 +0000 Subject: [PATCH 2/9] Add documentation for recovery testing Update class docstring to document recovery testing behavior and configuration options. Add logging to display recovery threshold and metrics configuration at test startup. Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 0ca9452d7..3be4af2a7 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -67,6 +67,16 @@ class TestLMEval: or another identifier which can be used for the particular test case. If a recipe is not provided, it is assumed that the scheme provided is a preset scheme and will be used for quantization. Otherwise, the recipe will always be used if given. + + Recovery Testing (DEFAULT): + Tests now use recovery-based validation by default, comparing compressed model + performance against the base model. Default threshold is 0.95 (≥95% recovery). + + Config options: + - recovery_threshold: 0.95 (default if not specified) + - recovery_threshold: 0.93 (override default globally) + - recovery_threshold: {"metric1": 0.95, "metric2": 0.90} (per-metric) + - metrics: {...} (optional - used for warnings only, not failures) """ # noqa: E501 def set_up(self, test_data_file: str): @@ -94,6 +104,11 @@ def set_up(self, test_data_file: str): logger.info("========== RUNNING ==============") logger.info(self.scheme) + logger.info( + f"Recovery threshold: {self.lmeval.recovery_threshold} (default: 0.95)" + ) + if self.lmeval.metrics: + logger.info("Absolute metrics provided - will show warnings if outside ±5%") self.num_calibration_samples = eval_config.get("num_calibration_samples", 512) self.max_seq_length = 2048 From b945dba15dba7294d08ef1138d2dba8e708ffed5 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:51:14 +0000 Subject: [PATCH 3/9] Add base model evaluation method Implement _eval_base_model() to evaluate the uncompressed model using lm_eval. 
This provides baseline metrics for recovery testing. Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 3be4af2a7..fb76e0ee3 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -144,6 +144,24 @@ def test_lm_eval(self, test_data_file: str): self.tear_down() + @log_time + def _eval_base_model(self): + """Evaluate the base (uncompressed) model.""" + model_args = {"pretrained": self.model} + model_args.update(self.lmeval.model_args) + + results = lm_eval.simple_evaluate( + model=self.lmeval.model, + model_args=model_args, + tasks=[self.lmeval.task], + num_fewshot=self.lmeval.num_fewshot, + limit=self.lmeval.limit, + device="cuda:0", + batch_size=self.lmeval.batch_size, + ) + + return results + @log_time def _save_compressed_model(self, oneshot_model, processor): oneshot_model.save_pretrained(self.save_dir) From 2371362fbe07728b3f271e1fd2aab3baf20581d5 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:51:29 +0000 Subject: [PATCH 4/9] Update test flow to always evaluate base model Modify test_lm_eval() to always call _eval_base_model() and store results in self.base_results for recovery testing. Update log messages to distinguish between base and compressed model evaluation phases. Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index fb76e0ee3..58d2aeca1 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -117,6 +117,10 @@ def test_lm_eval(self, test_data_file: str): # Run vLLM with saved model self.set_up(test_data_file) + # Always evaluate base model for recovery testing + logger.info("================= Evaluating BASE model ======================") + self.base_results = self._eval_base_model() + if not self.save_dir: self.save_dir = self.model.split("/")[1] + f"-{self.scheme}" oneshot_model, processor = run_oneshot_for_e2e_testing( @@ -139,7 +143,7 @@ def test_lm_eval(self, test_data_file: str): # Reset session for next test case self._handle_recipe() - logger.info("================= Running LM Eval ======================") + logger.info("================= Running LM Eval on COMPRESSED model ==========") self._run_lm_eval() self.tear_down() From 1cdceb376d9eecb9b757da62141a0d06c191d3a9 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:51:57 +0000 Subject: [PATCH 5/9] Add recovery validation logic Implement _validate_recovery() to compare compressed model metrics against base model metrics. Features: - Computes recovery ratio (compressed/base or inverted for lower-is-better) - Supports both global and per-metric thresholds - Direction-aware handling (higher/lower is better) - Handles edge cases (zero values, missing metrics) - Fails test if any metric doesn't meet threshold Call recovery validation from _run_lm_eval() as the primary test validation mechanism. 
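For illustration, the recovery ratio described above boils down to a small standalone helper (a sketch only, not part of this diff; compute_recovery is a hypothetical name):

    def compute_recovery(base: float, compressed: float, higher_is_better: bool = True) -> float:
        """Fraction of base-model quality retained by the compressed model."""
        if base == 0:
            # Degenerate baseline: identical zeros count as full recovery
            return 1.0 if compressed == 0 else 0.0
        if higher_is_better:
            return compressed / base
        # Lower-is-better metrics are inverted so "recovery >= threshold" still means good;
        # a compressed score of 0 here is a perfect result
        return base / compressed if compressed else float("inf")

    # e.g. base accuracy 0.65, compressed 0.62 -> ~95.4% recovery, passes the 0.95 default
    assert compute_recovery(0.65, 0.62) >= 0.95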
Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 80 +++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 58d2aeca1..0ecd31a52 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -194,6 +194,86 @@ def _run_lm_eval(self): batch_size=self.lmeval.batch_size, ) + # Always use recovery testing + self._validate_recovery(results) + + # If absolute metrics provided, show warnings (not failures) + if self.lmeval.metrics: + self._check_absolute_warnings(results) + + def _validate_recovery(self, compressed_results): + """Validate using recovery testing - compare against base model.""" + base_metrics = self.base_results["results"][self.lmeval.task] + compressed_metrics = compressed_results["results"][self.lmeval.task] + higher_is_better_map = compressed_results["higher_is_better"][self.lmeval.task] + + # Get recovery threshold(s) + recovery_threshold = self.lmeval.recovery_threshold + is_dict = isinstance(recovery_threshold, dict) + + logger.info("=" * 80) + logger.info("RECOVERY TESTING COMPARISON") + logger.info("=" * 80) + + failures = [] + for metric_key, base_val in base_metrics.items(): + # Skip stderr and other metadata + if "stderr" in metric_key or metric_key.startswith("alias"): + continue + + compressed_val = compressed_metrics.get(metric_key) + if compressed_val is None: + continue + + # Get threshold for this metric + if is_dict: + threshold = recovery_threshold.get(metric_key, 0.95) + else: + threshold = recovery_threshold + + # Get direction + base_metric_name = metric_key.split(",")[0] + higher_is_better = higher_is_better_map.get(base_metric_name, True) + + # Compute recovery + if base_val == 0: + recovery = 1.0 if compressed_val == 0 else 0.0 + elif higher_is_better: + recovery = compressed_val / base_val + else: + # For "lower is better", invert ratio + recovery = base_val / compressed_val + + # Check threshold + passed = recovery >= threshold + direction = "↑" if higher_is_better else "↓" + + msg = ( + f"{metric_key:40} | Base: {base_val:.4f} | " + f"Compressed: {compressed_val:.4f} | " + f"Recovery: {recovery:6.2%} {direction} | Threshold: ≥{threshold:.2%}" + ) + + if passed: + logger.info(f"✓ {msg}") + else: + logger.error(f"✗ {msg}") + failures.append( + f"{metric_key}: {recovery:.2%} < {threshold:.2%} " + f"(base={base_val:.4f}, compressed={compressed_val:.4f})" + ) + + logger.info("=" * 80) + + if failures: + failure_msg = "\n".join(failures) + raise AssertionError(f"Recovery testing failed:\n{failure_msg}") + + logger.info("✓ ALL METRICS PASSED RECOVERY THRESHOLDS") + logger.info("=" * 80) + + def _check_absolute_warnings(self, results): + """Check absolute metrics and warn if outside ±5% tolerance (not a failure).""" metrics: dict = results["results"][self.lmeval.task] for metric_key, expected_val in self.lmeval.metrics.items(): # stderr metrics are only used as absolute tolerance From 023ccf18c220ebc62372f8d64609911bf8c94efe Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 09:52:25 +0000 Subject: [PATCH 6/9] Convert metrics validation to warnings-only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace assertion-based metrics validation with warning-only checks. The metrics dict is now optional and used for informational purposes rather than test failures. This provides backward compatibility while making recovery testing the primary validation mechanism. 
Changes: - Remove stderr-based tolerance logic (unused) - Standardize on ±5% relative tolerance - Log warnings instead of raising assertions - Skip missing metrics gracefully - Direction-aware warnings (higher vs lower is better) Signed-off-by: Rahul Tuli Signed-off-by: rahul-tuli --- tests/lmeval/test_lmeval.py | 58 ++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 0ecd31a52..4e3ef962a 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -274,46 +274,44 @@ def _validate_recovery(self, compressed_results): def _check_absolute_warnings(self, results): """Check absolute metrics and warn if outside ±5% tolerance (not a failure).""" + logger.info("=" * 80) + logger.info("ABSOLUTE METRICS CHECK (warnings only, not failures)") + logger.info("=" * 80) + metrics: dict = results["results"][self.lmeval.task] for metric_key, expected_val in self.lmeval.metrics.items(): - # stderr metrics are only used as absolute tolerance - # checks for actual values + # Skip stderr metrics if "stderr" in metric_key: continue + actual_val = metrics.get(metric_key) + if actual_val is None: + continue + higher_is_better = results["higher_is_better"][self.lmeval.task].get( metric_key.split(",")[0], True ) - stderr_key = metric_key.replace(",", "_stderr,") - std_err = self.lmeval.metrics.get(stderr_key) - - # If stderr is provided, use it as absolute tolerance - # Otherwise, default to a 5% relative tolerance - if std_err is None: - logger.info( - f"Comparing {metric_key}: Expecting {expected_val} " - f"relative tolerance ±5%, Got {actual_val}. " - f"Higher is better: {higher_is_better}" - ) - # If higher is better, assert actual val >= expected val * (1 - stderr) - if higher_is_better: - assert actual_val >= expected_val * (0.95) - # If higher is worse, assert actual val <= expected val * (1 + stderr) - else: - assert actual_val <= expected_val * (1.05) + # Check if within ±5% relative tolerance + lower_bound = expected_val * 0.95 + upper_bound = expected_val * 1.05 + + if higher_is_better: + # For higher is better, we care about lower bound + if actual_val < lower_bound: + logger.warning( + f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | " + f"Got: {actual_val:.4f} | Below expected range" + ) else: - logger.info( - f"Comparing {metric_key}: Expecting {expected_val} " - f"absolute tolerance ±{std_err*100}%, Got {actual_val}. " - f"Higher is better: {higher_is_better}" - ) - # If higher is better, assert actual val >= expected val - stderr - if higher_is_better: - assert actual_val >= expected_val - std_err - # If higher is worse, assert actual val <= expected val + stderr - else: - assert actual_val <= expected_val + std_err + # For lower is better, we care about upper bound + if actual_val > upper_bound: + logger.warning( + f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | " + f"Got: {actual_val:.4f} | Above expected range" + ) + + logger.info("=" * 80) def tear_down(self): timer = get_singleton_manager() From b5efda0bd6e3601c7a50b7c20637fc7143ac982c Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 11:14:43 +0000 Subject: [PATCH 7/9] Set recovery thresholds for NVFP4 config NVFP4 scheme uses 4-bit weights and activations, resulting in lower recovery than FP8/INT8. Set per-metric thresholds based on observed values to allow ~89-92% recovery. 
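To make the per-metric fallback concrete, threshold lookup against the NVFP4 config in this patch behaves roughly as in this sketch (illustrative helper, not part of the diff; resolve_threshold is a hypothetical name):

    from typing import Union

    DEFAULT_RECOVERY_THRESHOLD = 0.95  # schema default introduced in PATCH 1

    def resolve_threshold(recovery_threshold: Union[float, dict], metric_key: str) -> float:
        """Use the per-metric entry if configured, otherwise the global/default value."""
        if isinstance(recovery_threshold, dict):
            return recovery_threshold.get(metric_key, DEFAULT_RECOVERY_THRESHOLD)
        return float(recovery_threshold)

    nvfp4 = {"exact_match,strict-match": 0.92, "exact_match,flexible-extract": 0.89}
    assert resolve_threshold(nvfp4, "exact_match,strict-match") == 0.92
    assert resolve_threshold(nvfp4, "some_other_metric") == 0.95  # falls back to the default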
Signed-off-by: rahul-tuli --- tests/lmeval/configs/w4a4_nvfp4.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/lmeval/configs/w4a4_nvfp4.yaml b/tests/lmeval/configs/w4a4_nvfp4.yaml index a44c99894..8862b5457 100644 --- a/tests/lmeval/configs/w4a4_nvfp4.yaml +++ b/tests/lmeval/configs/w4a4_nvfp4.yaml @@ -5,6 +5,12 @@ dataset_id: HuggingFaceH4/ultrachat_200k dataset_split: train_sft num_calibration_samples: 20 lmeval: + # NVFP4 (4-bit weights + 4-bit activations) has lower recovery than FP8/INT8 + # Observed: strict-match ~92.81%, flexible-extract ~89.59% + recovery_threshold: + exact_match,strict-match: 0.92 + exact_match,flexible-extract: 0.89 + # Absolute metrics for warnings only metrics: exact_match,flexible-extract: 0.70 exact_match,strict-match: 0.65 From 87a2f1548b2641583ef8d9809c08bc029e8896dc Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 1 Oct 2025 19:31:23 +0530 Subject: [PATCH 8/9] Update tests/lmeval/test_lmeval.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- tests/lmeval/test_lmeval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index 4e3ef962a..a47aa1ea5 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -151,8 +151,7 @@ def test_lm_eval(self, test_data_file: str): @log_time def _eval_base_model(self): """Evaluate the base (uncompressed) model.""" - model_args = {"pretrained": self.model} - model_args.update(self.lmeval.model_args) + model_args = {**self.lmeval.model_args, "pretrained": self.model} results = lm_eval.simple_evaluate( model=self.lmeval.model, From 0922265c8ac566a3ea727b5736c89ba5df742979 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 1 Oct 2025 15:57:28 +0000 Subject: [PATCH 9/9] Address code review feedback Changes based on kylesayrs review: - Remove is_dict variable, use isinstance() inline - Replace hardcoded 0.95 with model_fields default value - Add safety checks for higher_is_better_map existence - Refactor to iterate over compressed_metrics instead of base_metrics - Add warnings when config has keys not found in results - Validate recovery_threshold dict keys against actual results This eliminates code duplication, improves robustness, and ensures configs are validated against actual test results. 
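A trimmed-down sketch of the model_fields lookup mentioned above, assuming pydantic v2 (only the two fields relevant here are shown; the real LmEvalConfig has more):

    from typing import Optional, Union

    from pydantic import BaseModel

    class LmEvalConfig(BaseModel):
        recovery_threshold: Union[float, dict] = 0.95
        metrics: Optional[dict] = None

    # Pydantic v2 exposes field defaults on the class, so the test can read the
    # schema default instead of hardcoding 0.95 a second time
    default_threshold = LmEvalConfig.model_fields["recovery_threshold"].default
    assert default_threshold == 0.95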
Signed-off-by: Rahul Tuli --- tests/lmeval/test_lmeval.py | 49 +++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py index a47aa1ea5..a44cd042f 100644 --- a/tests/lmeval/test_lmeval.py +++ b/tests/lmeval/test_lmeval.py @@ -204,31 +204,39 @@ def _validate_recovery(self, compressed_results): """Validate using recovery testing - compare against base model.""" base_metrics = self.base_results["results"][self.lmeval.task] compressed_metrics = compressed_results["results"][self.lmeval.task] - higher_is_better_map = compressed_results["higher_is_better"][self.lmeval.task] - - # Get recovery threshold(s) - recovery_threshold = self.lmeval.recovery_threshold - is_dict = isinstance(recovery_threshold, dict) + higher_is_better_map = compressed_results.get("higher_is_better", {}).get( + self.lmeval.task, {} + ) logger.info("=" * 80) logger.info("RECOVERY TESTING COMPARISON") logger.info("=" * 80) + # Get default threshold from config schema + default_threshold = self.lmeval.model_fields["recovery_threshold"].default + failures = [] - for metric_key, base_val in base_metrics.items(): + # Iterate over compressed metrics (what we actually got) + for metric_key, compressed_val in compressed_metrics.items(): # Skip stderr and other metadata if "stderr" in metric_key or metric_key.startswith("alias"): continue - compressed_val = compressed_metrics.get(metric_key) - if compressed_val is None: + base_val = base_metrics.get(metric_key) + if base_val is None: + logger.warning( + f"Metric {metric_key} in compressed results " + f"not found in base results, skipping" + ) continue # Get threshold for this metric - if is_dict: - threshold = recovery_threshold.get(metric_key, 0.95) + if isinstance(self.lmeval.recovery_threshold, dict): + threshold = self.lmeval.recovery_threshold.get( + metric_key, default_threshold + ) else: - threshold = recovery_threshold + threshold = self.lmeval.recovery_threshold # Get direction base_metric_name = metric_key.split(",")[0] @@ -262,6 +270,15 @@ def _validate_recovery(self, compressed_results): f"(base={base_val:.4f}, compressed={compressed_val:.4f})" ) + # Validate that config thresholds match actual results + if isinstance(self.lmeval.recovery_threshold, dict): + for config_metric_key in self.lmeval.recovery_threshold.keys(): + if config_metric_key not in compressed_metrics: + logger.warning( + f"Metric {config_metric_key} in recovery_threshold config " + f"not found in results" + ) + logger.info("=" * 80) if failures: @@ -285,10 +302,16 @@ def _check_absolute_warnings(self, results): actual_val = metrics.get(metric_key) if actual_val is None: + logger.warning( + f"Metric {metric_key} in config not found in results, " + f"skipping warning check" + ) continue - higher_is_better = results["higher_is_better"][self.lmeval.task].get( - metric_key.split(",")[0], True + higher_is_better = ( + results.get("higher_is_better", {}) + .get(self.lmeval.task, {}) + .get(metric_key.split(",")[0], True) ) # Check if within ±5% relative tolerance
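To illustrate the config-key validation added in this last patch, stale per-metric threshold keys are surfaced as warnings rather than silently ignored (hypothetical values):

    # Keys configured in recovery_threshold vs. metrics actually returned by lm_eval
    configured_thresholds = {"exact_match,strict-match": 0.92, "exact_match,old-name": 0.90}
    compressed_metrics = {"exact_match,strict-match": 0.61, "exact_match,flexible-extract": 0.66}

    stale_keys = [key for key in configured_thresholds if key not in compressed_metrics]
    print(stale_keys)  # ['exact_match,old-name'] -> would be logged as a warning, not a failure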