Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/heretic/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def get_score(self) -> tuple[tuple[float, float], float, int]:
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

refusals_score = refusals / self.base_refusals
refusals_score = (
refusals / self.base_refusals if self.base_refusals > 0 else refusals
)

if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
Expand Down
20 changes: 14 additions & 6 deletions src/heretic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@
)


def obtain_merge_strategy(settings: Settings) -> str | None:
def obtain_merge_strategy(settings: Settings, model: Model) -> str | None:
"""
Prompts the user for how to proceed with saving the model.
If the model is quantized, also informs the user about the memory use of merging.
Returns "merge", "adapter", or None (if cancelled/invalid).
"""

if settings.quantization == QuantizationMethod.BNB_4BIT:
is_quantized = getattr(model.model.config, "quantization_config", None) is not None
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work for models quantized on-the-fly by Heretic, i.e., with bitsandbytes?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes — verified with both models below. The `quantization_config` check on the model config correctly detects BNB models, since Hugging Face stores the `BitsAndBytesConfig` there on load.


if is_quantized:
print()
print(
"Model was loaded with quantization. Merging requires reloading the base model."
Expand Down Expand Up @@ -174,9 +176,15 @@ def run():
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
if torch.cuda.is_available():
count = torch.cuda.device_count()
print(f"Detected [bold]{count}[/] CUDA device(s):")
total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count))
print(
f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):"
)
for i in range(count):
print(f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/]")
vram = torch.cuda.mem_get_info(i)[1] / (1024**3)
print(
f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)"
)
elif is_xpu_available():
count = torch.xpu.device_count()
print(f"Detected [bold]{count}[/] XPU device(s):")
Expand Down Expand Up @@ -747,7 +755,7 @@ def count_completed_trials() -> int:
if not save_directory:
continue

strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue

Expand Down Expand Up @@ -796,7 +804,7 @@ def count_completed_trials() -> int:
)
private = visibility == "Private"

strategy = obtain_merge_strategy(settings)
strategy = obtain_merge_strategy(settings, model)
if strategy is None:
continue

Expand Down
6 changes: 4 additions & 2 deletions src/heretic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def get_merged_model(self) -> PreTrainedModel:
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PeftModel)

# Check if we need special handling for quantized models
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
# Check if we need special handling for quantized models.
# This covers both on-the-fly quantization (e.g. BNB_4BIT) and pre-quantized
# models (e.g. FP8, MXFP4) — both set quantization_config on the model config.
if getattr(self.model.config, "quantization_config", None) is not None:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters

Expand Down