diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index f2a8a258..c2debc7f 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -110,7 +110,9 @@ def get_score(self) -> tuple[tuple[float, float], float, int]: kl_divergence_scale = self.settings.kl_divergence_scale kl_divergence_target = self.settings.kl_divergence_target - refusals_score = refusals / self.base_refusals + refusals_score = ( + refusals / self.base_refusals if self.base_refusals > 0 else refusals + ) if kl_divergence >= kl_divergence_target: kld_score = kl_divergence / kl_divergence_scale diff --git a/src/heretic/main.py b/src/heretic/main.py index 016c3920..f18c829c 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -54,14 +54,16 @@ ) -def obtain_merge_strategy(settings: Settings) -> str | None: +def obtain_merge_strategy(settings: Settings, model: Model) -> str | None: """ Prompts the user for how to proceed with saving the model. Provides info to the user if the model is quantized on memory use. Returns "merge", "adapter", or None (if cancelled/invalid). """ - if settings.quantization == QuantizationMethod.BNB_4BIT: + is_quantized = getattr(model.model.config, "quantization_config", None) is not None + + if is_quantized: print() print( "Model was loaded with quantization. Merging requires reloading the base model." @@ -174,9 +176,15 @@ def run(): # Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py if torch.cuda.is_available(): count = torch.cuda.device_count() - print(f"Detected [bold]{count}[/] CUDA device(s):") + total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count)) + print( + f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):" + ) for i in range(count): - print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]") + vram = torch.cuda.mem_get_info(i)[1] / (1024**3) + print( + f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)" + ) elif is_xpu_available(): count = torch.xpu.device_count() print(f"Detected [bold]{count}[/] XPU device(s):") @@ -747,7 +755,7 @@ def count_completed_trials() -> int: if not save_directory: continue - strategy = obtain_merge_strategy(settings) + strategy = obtain_merge_strategy(settings, model) if strategy is None: continue @@ -796,7 +804,7 @@ def count_completed_trials() -> int: ) private = visibility == "Private" - strategy = obtain_merge_strategy(settings) + strategy = obtain_merge_strategy(settings, model) if strategy is None: continue diff --git a/src/heretic/model.py b/src/heretic/model.py index 58300b16..582a5447 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -223,8 +223,10 @@ def get_merged_model(self) -> PreTrainedModel: # Guard against calling this method at the wrong time. assert isinstance(self.model, PeftModel) - # Check if we need special handling for quantized models - if self.settings.quantization == QuantizationMethod.BNB_4BIT: + # Check if we need special handling for quantized models. + # This covers both on-the-fly quantization (e.g. BNB_4BIT) and pre-quantized + # models (e.g. FP8, MXFP4) — both set quantization_config on the model config. + if getattr(self.model.config, "quantization_config", None) is not None: # Quantized models need special handling - we must reload the base model # in full precision to merge the LoRA adapters