|
18 | 18 | from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Sized, Tuple, Union |
19 | 19 |
|
20 | 20 | from neural_compressor.common.base_config import BaseConfig |
21 | | -from neural_compressor.common.utils import TuningLogger, logger |
| 21 | +from neural_compressor.common.utils import Statistics, TuningLogger, logger |
22 | 22 |
|
23 | 23 | __all__ = [ |
24 | 24 | "Evaluator", |
@@ -423,6 +423,47 @@ def add_trial_result(self, trial_index: int, trial_result: Union[int, float], qu |
423 | 423 | trial_record = _TrialRecord(trial_index, trial_result, quant_config) |
424 | 424 | self.tuning_history.append(trial_record) |
425 | 425 |
|
| 426 | + # Print tuning results table |
| 427 | + self._print_trial_results_table(trial_index, trial_result) |
| 428 | + |
| 429 | + def _print_trial_results_table(self, trial_index: int, trial_result: Union[int, float]) -> None: |
| 430 | + """Print trial results in a formatted table using Statistics class.""" |
| 431 | + baseline_val = self.baseline if self.baseline is not None else 0.0 |
| 432 | + baseline_str = f"{baseline_val:.4f}" if self.baseline is not None else "N/A" |
| 433 | + target_threshold_str = ( |
| 434 | + f"{baseline_val * (1 - self.tuning_config.tolerable_loss):.4f}" if self.baseline is not None else "N/A" |
| 435 | + ) |
| 436 | + |
| 437 | + # Calculate relative loss if baseline is available |
| 438 | + relative_loss_val = 0.0 |
| 439 | + relative_loss_str = "N/A" |
| 440 | + if self.baseline is not None: |
| 441 | + relative_loss_val = (baseline_val - trial_result) / baseline_val |
| 442 | + relative_loss_str = f"{relative_loss_val*100:.2f}%" |
| 443 | + |
| 444 | + # Get best result so far |
| 445 | + best_result = max(record.trial_result for record in self.tuning_history) |
| 446 | + |
| 447 | + # Status indicator with emoji |
| 448 | + if self.baseline is not None and trial_result >= (baseline_val * (1 - self.tuning_config.tolerable_loss)): |
| 449 | + status = "✅ PASSED" |
| 450 | + else: |
| 451 | + status = "❌ FAILED" |
| 452 | + |
| 453 | + # Prepare data for Statistics table with combined fields |
| 454 | + field_names = ["📊 Metric", "📈 Value"] |
| 455 | + output_data = [ |
| 456 | + ["Trial / Progress", f"{len(self.tuning_history)}/{self.tuning_config.max_trials}"], |
| 457 | + ["Baseline / Target", f"{baseline_str} / {target_threshold_str}"], |
| 458 | + ["Current / Status", f"{trial_result:.4f} | {status}"], |
| 459 | + ["Best / Relative Loss", f"{best_result:.4f} / {relative_loss_str}"], |
| 460 | + ] |
| 461 | + |
| 462 | + # Use Statistics class to print the table |
| 463 | + Statistics( |
| 464 | + output_data, header=f"🎯 Auto-Tune Trial #{trial_index} Results", field_names=field_names |
| 465 | + ).print_stat() |
| 466 | + |
426 | 467 | def set_baseline(self, baseline: float): |
427 | 468 | """Set the baseline value for auto-tune. |
428 | 469 |
|
@@ -488,4 +529,10 @@ def init_tuning(tuning_config: TuningConfig) -> Tuple[ConfigLoader, TuningLogger |
488 | 529 | config_loader = ConfigLoader(config_set=tuning_config.config_set, sampler=tuning_config.sampler) |
489 | 530 | tuning_logger = TuningLogger() |
490 | 531 | tuning_monitor = TuningMonitor(tuning_config) |
| 532 | + |
| 533 | + # Update max_trials based on actual number of available configurations |
| 534 | + actual_config_count = len(config_loader.config_set) |
| 535 | + if tuning_config.max_trials > actual_config_count: |
| 536 | + tuning_config.max_trials = actual_config_count |
| 537 | + |
491 | 538 | return config_loader, tuning_logger, tuning_monitor |
0 commit comments