@@ -2,6 +2,7 @@
 import random
 import shutil
 from pathlib import Path
+from typing import Optional, Union
 
 import numpy
 import pandas as pd
@@ -23,8 +24,12 @@ class LmEvalConfig(BaseModel):
     task: str = "gsm8k"
     num_fewshot: int = 5
     limit: int = 1000
-    metrics: dict
     batch_size: int = 100
+    # Recovery testing (default): compare against base model performance
+    # Default threshold is 0.95 (retain ≥95% of base), can be overridden
+    recovery_threshold: Union[float, dict] = 0.95
+    # Optional absolute metrics for warnings (not failures)
+    metrics: Optional[dict] = None
 
 
 try:
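Since `recovery_threshold` is a `Union[float, dict]`, a bare float applies to every metric, while a dict pins per-metric thresholds and falls back to the schema default of 0.95 for anything unlisted, exactly as `_validate_recovery` does further down. A minimal standalone sketch of that lookup rule (the helper name is ours, not part of the diff):

```python
from typing import Union


def resolve_threshold(
    recovery_threshold: Union[float, dict],
    metric_key: str,
    default: float = 0.95,
) -> float:
    # Hypothetical helper mirroring the threshold lookup in _validate_recovery
    if isinstance(recovery_threshold, dict):
        # Per-metric dict; unlisted metrics fall back to the schema default
        return recovery_threshold.get(metric_key, default)
    return recovery_threshold


assert resolve_threshold(0.93, "metric1") == 0.93
assert resolve_threshold({"metric1": 0.90}, "metric2") == 0.95
```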
@@ -62,6 +67,16 @@ class TestLMEval:
     or another identifier which can be used for the particular test case. If a recipe
     is not provided, it is assumed that the scheme provided is a preset scheme and will
     be used for quantization. Otherwise, the recipe will always be used if given.
+
+    Recovery Testing (DEFAULT):
+    Tests now use recovery-based validation by default, comparing compressed model
+    performance against the base model. Default threshold is 0.95 (≥95% recovery).
+
+    Config options:
+    - recovery_threshold: 0.95 (default if not specified)
+    - recovery_threshold: 0.93 (override default globally)
+    - recovery_threshold: {"metric1": 0.95, "metric2": 0.90} (per-metric)
+    - metrics: {...} (optional - used for warnings only, not failures)
     """  # noqa: E501
 
     def set_up(self, test_data_file: str):
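To make the three documented threshold forms concrete, they map onto the model like this (an illustrative sketch; it assumes the remaining `LmEvalConfig` fields keep their defaults, and the metric names are the docstring's placeholders):

```python
# Illustrative instantiations of the documented threshold forms (placeholder names).
default_cfg = LmEvalConfig()                        # recovery_threshold defaults to 0.95
global_cfg = LmEvalConfig(recovery_threshold=0.93)  # looser threshold for all metrics
per_metric_cfg = LmEvalConfig(
    recovery_threshold={"metric1": 0.95, "metric2": 0.90},
    metrics={"metric1": 0.65},  # optional absolute values, checked as warnings only
)
```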
@@ -89,6 +104,11 @@ def set_up(self, test_data_file: str):
 
         logger.info("========== RUNNING ==============")
         logger.info(self.scheme)
+        logger.info(
+            f"Recovery threshold: {self.lmeval.recovery_threshold} (default: 0.95)"
+        )
+        if self.lmeval.metrics:
+            logger.info("Absolute metrics provided - will show warnings if outside ±5%")
 
         self.num_calibration_samples = eval_config.get("num_calibration_samples", 512)
         self.max_seq_length = 2048
@@ -97,6 +117,10 @@ def test_lm_eval(self, test_data_file: str):
         # Run vLLM with saved model
         self.set_up(test_data_file)
 
+        # Always evaluate base model for recovery testing
+        logger.info("================= Evaluating BASE model ======================")
+        self.base_results = self._eval_base_model()
+
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
         oneshot_model, processor = run_oneshot_for_e2e_testing(
@@ -119,11 +143,28 @@ def test_lm_eval(self, test_data_file: str):
         # Reset session for next test case
         self._handle_recipe()
 
-        logger.info("================= Running LM Eval ============ ==========")
+        logger.info("================= Running LM Eval on COMPRESSED model ==========")
         self._run_lm_eval()
 
         self.tear_down()
 
+    @log_time
+    def _eval_base_model(self):
+        """Evaluate the base (uncompressed) model."""
+        model_args = {**self.lmeval.model_args, "pretrained": self.model}
+
+        results = lm_eval.simple_evaluate(
+            model=self.lmeval.model,
+            model_args=model_args,
+            tasks=[self.lmeval.task],
+            num_fewshot=self.lmeval.num_fewshot,
+            limit=self.lmeval.limit,
+            device="cuda:0",
+            batch_size=self.lmeval.batch_size,
+        )
+
+        return results
+
     @log_time
     def _save_compressed_model(self, oneshot_model, processor):
         oneshot_model.save_pretrained(self.save_dir)
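Both `_eval_base_model` above and `_run_lm_eval` below return lm-eval's standard results dict, which the recovery comparison indexes as `results["results"][task]` alongside a `higher_is_better` map. For orientation, the consumed shape is roughly the following (metric names and numbers are illustrative, not real results):

```python
# Illustrative skeleton of the lm_eval results dict consumed by _validate_recovery.
example_results = {
    "results": {
        "gsm8k": {
            "alias": "gsm8k",                       # metadata; skipped by the loop
            "metric1,strict-match": 0.6512,         # scored metric (made-up value)
            "metric1_stderr,strict-match": 0.0131,  # stderr entries are skipped too
        }
    },
    "higher_is_better": {"gsm8k": {"metric1": True}},
}
```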
@@ -152,46 +193,147 @@ def _run_lm_eval(self):
             batch_size=self.lmeval.batch_size,
         )
 
+        # Always use recovery testing
+        self._validate_recovery(results)
+
+        # If absolute metrics are provided, show warnings (not failures)
+        if self.lmeval.metrics:
+            self._check_absolute_warnings(results)
+
+    def _validate_recovery(self, compressed_results):
+        """Validate using recovery testing - compare against the base model."""
+        base_metrics = self.base_results["results"][self.lmeval.task]
+        compressed_metrics = compressed_results["results"][self.lmeval.task]
+        higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
+            self.lmeval.task, {}
+        )
+
+        logger.info("=" * 80)
+        logger.info("RECOVERY TESTING COMPARISON")
+        logger.info("=" * 80)
+
+        # Get the default threshold from the config schema
+        default_threshold = self.lmeval.model_fields["recovery_threshold"].default
+
+        failures = []
+        # Iterate over compressed metrics (what we actually got)
+        for metric_key, compressed_val in compressed_metrics.items():
+            # Skip stderr and other metadata
+            if "stderr" in metric_key or metric_key.startswith("alias"):
+                continue
+
+            base_val = base_metrics.get(metric_key)
+            if base_val is None:
+                logger.warning(
+                    f"Metric {metric_key} in compressed results "
+                    f"not found in base results, skipping"
+                )
+                continue
+
+            # Get the threshold for this metric
+            if isinstance(self.lmeval.recovery_threshold, dict):
+                threshold = self.lmeval.recovery_threshold.get(
+                    metric_key, default_threshold
+                )
+            else:
+                threshold = self.lmeval.recovery_threshold
+
+            # Get the comparison direction
+            base_metric_name = metric_key.split(",")[0]
+            higher_is_better = higher_is_better_map.get(base_metric_name, True)
+
+            # Compute recovery
+            if base_val == 0:
+                recovery = 1.0 if compressed_val == 0 else 0.0
+            elif higher_is_better:
+                recovery = compressed_val / base_val
+            else:
+                # For "lower is better", invert the ratio
+                recovery = base_val / compressed_val
+
+            # Check against the threshold
+            passed = recovery >= threshold
+            direction = "↑" if higher_is_better else "↓"
+
+            msg = (
+                f"{metric_key:40} | Base: {base_val:.4f} | "
+                f"Compressed: {compressed_val:.4f} | "
+                f"Recovery: {recovery:6.2%} {direction} | Threshold: ≥{threshold:.2%}"
+            )
+
+            if passed:
+                logger.info(f"✓ {msg}")
+            else:
+                logger.error(f"✗ {msg}")
+                failures.append(
+                    f"{metric_key}: {recovery:.2%} < {threshold:.2%} "
+                    f"(base={base_val:.4f}, compressed={compressed_val:.4f})"
+                )
+
+        # Warn if configured per-metric thresholds reference metrics absent from results
+        if isinstance(self.lmeval.recovery_threshold, dict):
+            for config_metric_key in self.lmeval.recovery_threshold.keys():
+                if config_metric_key not in compressed_metrics:
+                    logger.warning(
+                        f"Metric {config_metric_key} in recovery_threshold config "
+                        f"not found in results"
+                    )
+
+        logger.info("=" * 80)
+
+        if failures:
+            failure_msg = "\n".join(failures)
+            raise AssertionError(f"Recovery testing failed:\n{failure_msg}")
+
+        logger.info("✓ ALL METRICS PASSED RECOVERY THRESHOLDS")
+        logger.info("=" * 80)
+
+    def _check_absolute_warnings(self, results):
+        """Check absolute metrics and warn if outside ±5% tolerance (not a failure)."""
+        logger.info("=" * 80)
+        logger.info("ABSOLUTE METRICS CHECK (warnings only, not failures)")
+        logger.info("=" * 80)
+
         metrics: dict = results["results"][self.lmeval.task]
         for metric_key, expected_val in self.lmeval.metrics.items():
-            # stderr metrics are only used as absolute tolerance
-            # checks for actual values
+            # Skip stderr metrics
            if "stderr" in metric_key:
                 continue
+
             actual_val = metrics.get(metric_key)
-            higher_is_better = results["higher_is_better"][self.lmeval.task].get(
-                metric_key.split(",")[0], True
-            )
-            stderr_key = metric_key.replace(",", "_stderr,")
-            std_err = self.lmeval.metrics.get(stderr_key)
-
-            # If stderr is provided, use it as absolute tolerance
-            # Otherwise, default to a 5% relative tolerance
-            if std_err is None:
-                logger.info(
-                    f"Comparing {metric_key}: Expecting {expected_val} "
-                    f"relative tolerance ±5%, Got {actual_val}. "
-                    f"Higher is better: {higher_is_better}"
+            if actual_val is None:
+                logger.warning(
+                    f"Metric {metric_key} in config not found in results, "
+                    f"skipping warning check"
                 )
-                # If higher is better, assert actual val >= expected val * (1 - stderr)
-                if higher_is_better:
-                    assert actual_val >= expected_val * (0.95)
-                # If higher is worse, assert actual val <= expected val * (1 + stderr)
-                else:
-                    assert actual_val <= expected_val * (1.05)
+                continue
+
+            higher_is_better = (
+                results.get("higher_is_better", {})
+                .get(self.lmeval.task, {})
+                .get(metric_key.split(",")[0], True)
+            )
 
+            # Check whether the value is within ±5% relative tolerance
+            lower_bound = expected_val * 0.95
+            upper_bound = expected_val * 1.05
+
+            if higher_is_better:
+                # For higher-is-better metrics, only the lower bound matters
+                if actual_val < lower_bound:
+                    logger.warning(
+                        f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
+                        f"Got: {actual_val:.4f} | Below expected range"
+                    )
             else:
-                logger.info(
-                    f"Comparing {metric_key}: Expecting {expected_val} "
-                    f"absolute tolerance ±{std_err * 100}%, Got {actual_val}. "
-                    f"Higher is better: {higher_is_better}"
-                )
-                # If higher is better, assert actual val >= expected val - stderr
-                if higher_is_better:
-                    assert actual_val >= expected_val - std_err
-                # If higher is worse, assert actual val <= expected val + stderr
-                else:
-                    assert actual_val <= expected_val + std_err
+                # For lower-is-better metrics, only the upper bound matters
+                if actual_val > upper_bound:
+                    logger.warning(
+                        f"⚠ {metric_key:40} | Expected: {expected_val:.4f} (±5%) | "
+                        f"Got: {actual_val:.4f} | Above expected range"
+                    )
+
+        logger.info("=" * 80)
 
     def tear_down(self):
         timer = get_singleton_manager()
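Taken together, the recovery rule is: `compressed / base` when higher is better, `base / compressed` when lower is better, and a zero base counts as full recovery only when the compressed value is also zero. As a standalone sketch with worked checks against the 0.95 default:

```python
def compute_recovery(base: float, compressed: float, higher_is_better: bool = True) -> float:
    # Mirrors the branch structure in _validate_recovery
    if base == 0:
        return 1.0 if compressed == 0 else 0.0
    if higher_is_better:
        return compressed / base
    # Lower-is-better metrics invert the ratio (assumes compressed != 0, as in the diff)
    return base / compressed


# 0.62 against a base of 0.65 recovers ~95.4%, passing the 0.95 default.
assert compute_recovery(0.65, 0.62) >= 0.95
# A lower-is-better metric drifting from 10.0 to 11.0 recovers ~90.9% and would fail.
assert compute_recovery(10.0, 11.0, higher_is_better=False) < 0.95
```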