QLoRA DDP export #353
Conversation
Walkthrough
Adds QLoRA-aware export and tooling: new export script, exporter flag propagation, quant export utilities updated to remap/skip QLoRA keys and conditionally emit per-layer quant configs, LoRA adapter cleanup, trainer hooks for LoRA best-model loading, and enabling a previously skipped QLoRA NVFP4 test.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant ExportScript as export.py
    participant Model as QLoRA Model
    participant CalibState as modelopt_state_calib.pth
    participant Exporter as _export_hf_checkpoint
    participant Utils as postprocess_state_dict
    User->>ExportScript: run main(args)
    ExportScript->>Model: load base model + LoRA adapters
    ExportScript->>CalibState: load modelopt_state_train.pth (if present)
    CalibState-->>Model: restore calibration + optional quantizer weights
    ExportScript->>Exporter: _export_hf_checkpoint(model, is_modelopt_qlora=True)
    Exporter->>Utils: postprocess_state_dict(state_dict, maxbound, quant, is_modelopt_qlora=True)
    Utils-->>Exporter: processed_state_dict + hf_quant_config
    Exporter-->>ExportScript: return artifacts
    ExportScript->>User: write model, adapters, quant config, tokenizer
```
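To make the flow above concrete, here is a minimal sketch of an export entry point. It is illustrative only: the wrapper function name `export_qlora_checkpoint` is hypothetical, the state-file name follows this PR's convention, and the ModelOpt helpers (`restore_from_modelopt_state`, `set_quantizer_state_dict`, `_export_hf_checkpoint`) are used as described in the review comments below rather than verified against the library.

```python
# Hedged sketch of the QLoRA export flow; not the PR's exact export.py.
import json
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.quantization.utils import set_quantizer_state_dict


def export_qlora_checkpoint(ckpt_path: str, export_path: str) -> None:
    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Restore the quantizer calibration saved by the trainer (file name per this PR).
    state_file = Path(ckpt_path) / "modelopt_state_train.pth"
    if state_file.exists():
        modelopt_state = torch.load(state_file, weights_only=False)
        modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
        restore_from_modelopt_state(model, modelopt_state)
        if modelopt_weights is not None:
            set_quantizer_state_dict(model, modelopt_weights)

    export_dir = Path(export_path)
    (export_dir / "base_model").mkdir(parents=True, exist_ok=True)

    # Quantized base model goes under base_model/, adapters and tokenizer at the root.
    state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True)
    with open(export_dir / "base_model" / "hf_quant_config.json", "w") as f:
        json.dump(hf_quant_config, f, indent=4)
    model.base_model.save_pretrained(export_dir / "base_model", state_dict=state_dict)
    model.save_pretrained(export_dir)
    tokenizer.save_pretrained(export_dir)
```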
```mermaid
sequenceDiagram
    autonumber
    participant Trainer as QATTrainer
    participant Model
    participant FS as Filesystem
    participant Loader as _load_best_model
    Trainer->>Model: calibrate quantizers
    Trainer->>FS: save modelopt_state_calib.pth
    alt load best model (no FSDP, LoRA present)
        Trainer->>Loader: _load_best_model(...)
        Loader-->>Trainer: custom best-model load (handle adapters)
    else
        Trainer->>Loader: delegate to superclass
    end
```
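A hedged sketch of the trainer-side half of this handoff: saving the ModelOpt state together with the quantizer weights right after calibration so the export script can restore them. The helper name and the file name are assumptions based on the discussion in this PR, not a verified trainer API.

```python
# Illustrative only: persist quantizer calibration before compression so export can restore it.
import torch

import modelopt.torch.opt as mto
from modelopt.torch.quantization.utils import get_quantizer_state_dict


def save_modelopt_state_for_export(model, output_dir: str) -> None:
    state = mto.modelopt_state(model)  # architecture / quantizer metadata
    state["modelopt_state_weights"] = get_quantizer_state_dict(model)  # amax/scale buffers
    torch.save(state, f"{output_dir}/modelopt_state_train.pth")
```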
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Force-pushed from 035117f to 6254cad
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```diff
@@            Coverage Diff             @@
##             main     #353      +/-   ##
==========================================
- Coverage   73.79%   73.79%   -0.01%
==========================================
  Files         171      171
  Lines       17591    17591
==========================================
- Hits        12982    12981       -1
- Misses       4609     4610       +1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0

🧹 Nitpick comments (1)

modelopt/torch/export/quant_utils.py (1), lines 743-745: Verify dtype casting behavior and potential precision loss. The dtype casting is applied unconditionally when a dtype is provided, which could lead to precision loss or unexpected behavior if the original weight dtype is more precise than the target dtype.

Consider adding a warning or validation:

```diff
 if dtype:
+    if weight.dtype != dtype:
+        logger.info(f"Converting weight from {weight.dtype} to {dtype}")
     weight = weight.to(dtype)
```
🔇 Additional comments (14)

modelopt/torch/export/quant_utils.py (4)

Line 730: LGTM! Optional dtype parameter added correctly. The addition of the optional `dtype` parameter follows good API design practices with a sensible default of `None`.

Lines 833-836: LGTM! Base layer key mappings added appropriately. The new base layer mappings support QLoRA export workflows by handling the transformation from LoRA-specific keys to standard model keys.

Line 847: LGTM! Proper exclusion of base_layer keys. The condition correctly filters out base_layer keys from the main processing loop, which is consistent with the new mappings approach.

Lines 899-903: LGTM! LoRA adapter cleanup implemented correctly. The LoRA adapter removal logic is properly implemented to clean up adapter-specific parameters from the exported state dict, ensuring a clean base model export (see the sketch below).
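A minimal sketch of what these two behaviors amount to, assuming simplified key handling; the real `postprocess_state_dict` also rewrites quantizer suffixes (amax, scales) that are omitted here, and the helper name is hypothetical:

```python
# Hedged sketch: remap base_layer.* keys and drop LoRA adapter tensors for QLoRA export.
def postprocess_qlora_keys(state_dict: dict, is_modelopt_qlora: bool = True) -> dict:
    replacements = {
        "base_layer.weight": "weight",
        "base_layer.input_scale": "input_scale",
        "base_layer.weight_scale": "weight_scale",
    }
    out = {}
    for key, value in state_dict.items():
        if is_modelopt_qlora and "lora_" in key:
            continue  # adapter weights are exported separately, not in the base checkpoint
        for old_suffix, new_suffix in replacements.items():
            if key.endswith(old_suffix):
                key = key[: -len(old_suffix)] + new_suffix
                break
        out[key] = value
    return out
```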
examples/llm_qat/main.py (1)

Lines 276-278: LGTM! QLoRA export integration implemented correctly. The conditional call to `trainer.export_base_model()` is properly guarded by both the LoRA and compression flags, ensuring the base model is only exported when appropriate for QLoRA workflows.

modelopt/torch/export/unified_export_hf.py (5)

Lines 88-91: LGTM! Early return for LoRA models prevents unnecessary processing. The early return correctly skips processing for LoRA-finetuned models by detecting the presence of a `base_model` attribute, avoiding potential issues with the requantize/resmooth operations.

Lines 329-336: Consistent dtype parameter usage. The non-NVFP4 quantization path correctly passes the dtype parameter to `to_quantized_weight`, maintaining consistency with the NVFP4 path above.

Lines 465-470: Enhanced guard conditions for quantized weight export. The additional check for `hasattr(sub_module, "weight_quantizer")` and `sub_module.weight_quantizer.is_enabled` provides better safety by ensuring quantizers exist and are active before attempting export.

Lines 531-536: LGTM! Proper base model export for QLoRA models. The logic correctly identifies QLoRA models by checking for the `base_model` attribute and exports the underlying base model instead of the wrapper, which is essential for proper deployment (a sketch of this check follows below).

Lines 317-323: Resolved — internal dtype cast remains; no action needed. quant_utils.py performs `if dtype: weight = weight.to(dtype)` (modelopt/torch/export/quant_utils.py lines 743-744), so removing the pre-cast in unified_export_hf.py does not change quantization behavior.
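A hedged illustration of the detection-and-unwrap pattern these comments describe; the attribute check mirrors what the review quotes, while the helper function itself is hypothetical:

```python
# Sketch: PEFT/QLoRA wrappers expose the underlying model via `base_model`,
# so the exporter can unwrap before saving the quantized base checkpoint.
def unwrap_for_export(model):
    if hasattr(model, "base_model"):  # QLoRA/PEFT-wrapped model
        return model.base_model
    return model  # plain QAT model, export as-is
```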
modelopt/torch/quantization/plugins/transformers_trainer.py (3)

Line 31: LGTM! Required import added. The import of `export_hf_checkpoint` is correctly added to support the new export functionality.

Lines 279-290: LoRA-specific best model loading logic implemented correctly. The implementation correctly handles the difference between LoRA and non-LoRA models. For LoRA models, it properly removes and re-loads the adapter from the best checkpoint path. Note: the TODO comment indicates this is temporary until `get_peft_model()` is used, which aligns with the PR description mentioning temporary fixes.

Lines 291-296: Simple and effective base model export. The implementation correctly calls `export_hf_checkpoint` with the appropriate output directory structure, and the main process check ensures only one process performs the export.
examples/llm_qat/README.md (1)

Lines 357-362: Fix vLLM serve example: point --tokenizer to the base/merged model (keep --lora-modules syntax).
- Location: examples/llm_qat/README.md (lines 357-362).
- Replace the example so the served model is the merged or original base model and --tokenizer points to the base model (or HF name). The current --lora-modules adapter=llama3-fp4-qlora is correct.
- Suggested command (angle-bracket placeholders are illustrative; use actual paths/names from your repo):

  vllm serve <base-or-merged-model> --enable-lora --lora-modules adapter=<adapter-path> --tokenizer <base-model-or-hf-name> --tokenizer-mode auto --trust-remote-code --port 8000

- Verify tokenizer.pad_token_id == model.config.pad_token_id (or set pad_token = eos_token) to avoid generation/padding issues. Repo check found no base_model directory; confirm exact paths before committing this change.
Actionable comments posted: 2

Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)

modelopt/torch/export/quant_utils.py (1), lines 732-747: Guard the new dtype pre-cast to prevent fp8/integer pitfalls. Casting to an arbitrary dtype before quantization can raise "Promotion for Float8 Types is not supported" or produce invalid math if an integer/bool dtype is passed. Restrict to safe float dtypes and reject fp8/ints.

Apply this diff:

```diff
 def to_quantized_weight(
     weight: torch.Tensor,
     weights_scaling_factor: torch.Tensor,
     quantization: str,
     weights_scaling_factor2: torch.Tensor | None = None,
     block_size: int | None = None,
     dtype: torch.dtype | None = None,
 ):
@@
-    if dtype:
-        weight = weight.to(dtype)
+    if dtype is not None:
+        # Only allow >=16-bit float dtypes here; fp8 and non-floats break downstream ops.
+        allowed_dtypes = {torch.float32, torch.float16, torch.bfloat16}
+        disallowed = {getattr(torch, "float8_e4m3fn", None)}
+        if dtype in disallowed or dtype not in allowed_dtypes:
+            raise ValueError(f"Unsupported pre-quant cast dtype: {dtype}. Use float32/float16/bfloat16.")
+        weight = weight.to(dtype)
```

🧹 Nitpick comments (1)

modelopt/torch/export/quant_utils.py (1), lines 611-612: Replace commented-out debug print with logger.debug or remove. Keep logs consistent and avoid dead code.

Apply this diff:

```diff
-    # print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
+    logger.debug("Processing layer %s with quantization=%s, block_size=%s", k, v, block_size_value)
```
🔇 Additional comments (1)

modelopt/torch/export/quant_utils.py (1), lines 835-853: Base-layer key remap: avoid silent drops and verify collisions. Unmapped base_layer.* keys are dropped by design. If other base attributes exist (e.g., bias), they'll vanish silently, and remaps may overwrite existing top-level keys.

Please confirm:
- All required base_layer.* fields are covered by replacements.
- Remapped targets (e.g., "weight", "input_scale", "weight_scale") won't collide with already-present keys.

If needed, I can add a warning when a base_layer key is encountered but not remapped (a possible shape of that warning is sketched below). Want a patch?
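A possible shape of that warning, sketched under the assumption that it would sit next to the replacement table inside `postprocess_state_dict`; the helper name is hypothetical:

```python
import warnings


# Sketch: flag base_layer.* keys that no replacement rule covers instead of dropping them silently.
def warn_unmapped_base_layer_keys(state_dict: dict, replacements: dict) -> None:
    for key in state_dict:
        if "base_layer." in key and not any(key.endswith(suffix) for suffix in replacements):
            warnings.warn(f"Unmapped base_layer key dropped during QLoRA export: {key}")
```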
Overall implementation looks good. However, we don't have to combine the export for QLoRA with transformers_trainer. We should do the export via hf_ptq.py.

Okay, I will try refactoring the PR to do that! Thank you!
@sugunav14 sounds good. For example, here is how we support regular QAT deployment: #353 (comment). I am thinking we should have something like:

For deployment of QAT and QLoRA checkpoints, we still need to specify

For QAT/QLoRA, can we support the following usage:
Could you test Phi-4-multimodal-instruct export (FP8 and NVFP4) and make sure the safetensors are the same before and after the change?
To quantize Phi-4-multimodal-instruct, you need to:
- Download https://huggingface.co/microsoft/Phi-4-multimodal-instruct
- modify https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2101 and enforce using InputMode.LANGUAGE.
- use transformers 4.48
- Run FP8 and NVFP4 PTQ
- Compare the generated safetensors with https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP4 and https://huggingface.co/nvidia/Phi-4-multimodal-instruct-FP8. Make sure the tensor keys are the same.
modelopt/torch/export/quant_utils.py (outdated)

```python
# Get the corresponding AWQ block size
block_size_value = layer_config_dict.get(awq_key, 0)

# print(f"DEBUG LOG: Processing layer {k} with quantization {v}, block size {block_size_value}")
```
remove
modelopt/torch/export/quant_utils.py (outdated)

```python
if isinstance(weight, QTensorWrapper):
    return weight.data

if dtype:
```
do we have a case where we need to cast the weights?
Yea, seems like in the current unified export logic we do perform a cast before quantizing the weights
modelopt/torch/export/quant_utils.py (outdated)

```python
keys_to_delete.append(key)

# remove LoRA adapters from state dict
for key, value in post_state_dict.items():
```
what if the original model has lora adapters like phi4-multimodal?
I did not think of that case! Thanks for the catch!
modelopt/torch/export/quant_utils.py (outdated)

```python
layer_config_dict[name + ".quantization"] = quantization_format
layer_config_dict[name + ".awq_block_size"] = block_size
# Handles case if default weight quantizer is not enabled or is None
if block_size != 0:
```
will this impact per tensor quant like fp8?
It will, just updated the condition.
Actionable comments posted: 2
🧹 Nitpick comments (9)
examples/llm_qat/export.py (7)
Lines 32-56: Add error handling for missing modelopt state file. Line 47 loads modelopt_state_calibration.pth without checking if it exists. If the file is missing (e.g., the model was not trained with QLoRA or was trained with an older version), torch.load will raise a FileNotFoundError with an unclear error message. Consider adding explicit validation:
+from pathlib import Path + def get_lora_model( ckpt_path: str, device="cuda", ): """ Loads a QLoRA model that has been trained using modelopt trainer. """ + # Validate modelopt state file exists + modelopt_state_path = Path(ckpt_path) / "modelopt_state_calibration.pth" + if not modelopt_state_path.exists(): + raise FileNotFoundError( + f"modelopt_state_calibration.pth not found in {ckpt_path}. " + "Ensure the model was trained with QLoRA using the modelopt trainer." + ) + device_map = "auto" if device == "cpu": device_map = "cpu" # Load model with adapters model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map) # Restore modelopt state - modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + modelopt_state = torch.load(modelopt_state_path, weights_only=False) restore_from_modelopt_state(model, modelopt_state)
Line 53: Consider using print_rank_0 for distributed consistency. The print statement on line 53 will execute on all ranks in a distributed setting, potentially causing log clutter. If this script may be used in a distributed context, consider using `print_rank_0` from `modelopt.torch.utils.logging` (already used in transformers_trainer.py).
Line 87: Remove or conditionalize debug print of config data. Line 87 prints the entire config_data dictionary, which can be very verbose. This appears to be debug code left in the script. Consider removing or making it conditional:

```diff
-    print(config_data)
+    # Optionally log config for debugging
+    # print(config_data)
```
Lines 98-102: Improve error message clarity. The error message "Cannot export model to the model_config" is unclear. Consider clarifying what failed and providing actionable guidance.

```diff
     warnings.warn(
-        "Cannot export model to the model_config. The modelopt-optimized model state_dict"
-        " can be saved with torch.save for further inspection."
+        f"Failed to export model to {export_dir}. The modelopt-optimized model state_dict "
+        "can be saved with torch.save for further inspection."
     )
```
Line 29: Remove unused RAND_SEED constant. The `RAND_SEED` constant is defined but never used in the script.

```diff
-RAND_SEED = 1234
-
```
Lines 32-35: Add type hints to function signatures. The function lacks type hints, which would improve code clarity and enable static type checking. Consider adding a return type annotation.

```diff
-def get_lora_model(
-    ckpt_path: str,
-    device="cuda",
-):
+def get_lora_model(
+    ckpt_path: str,
+    device: str = "cuda",
+) -> torch.nn.Module:
```
Line 59: Add type hints to main function. Consider adding type hints for better code clarity.

```diff
-def main(args):
+def main(args: argparse.Namespace) -> None:
```

modelopt/torch/export/quant_utils.py (2)
Lines 903-907: Improve LoRA key detection to avoid false positives. The substring check `"lora" in key` on line 906 may inadvertently match keys that contain "lora" as part of a larger word (e.g., "flora", "explorer") or match model-native LoRA adapters that should be preserved. Based on learnings.

Apply this diff to use more precise path-segment matching:

```diff
     # remove LoRA adapters from state dict
     if is_modelopt_qlora:
-        for key in post_state_dict:
-            if "lora" in key and key not in keys_to_delete:
+        for key in list(post_state_dict.keys()):
+            parts = key.split(".")
+            # Check if "lora" appears as a complete path segment or as a prefix (e.g., lora_A, lora_B)
+            if (("lora" in parts or any(p.startswith("lora_") for p in parts))
+                    and key not in keys_to_delete):
                 keys_to_delete.append(key)
```

Note: Also changed to iterate over `list(post_state_dict.keys())` to avoid issues with dictionary modification during iteration.
Lines 1086-1096: Improvement: NVFP4-specific block_size handling partially addresses past concerns. The new logic correctly skips layer config entries for NVFP4 formats when `block_size == 0`, which indicates the weight_quantizer is not enabled. This is an improvement over the previous approach that skipped ALL formats with `block_size == 0`. However, `awq_block_size` is still written for all formats on line 1096, even those that don't use block quantization (e.g., `QUANTIZATION_INT8_SQ`, `QUANTIZATION_FP8`, `QUANTIZATION_FP8_PC_PT`).

Consider only writing `awq_block_size` for formats that actually use it:

```diff
     # Construct per layer config dictionary
     layer_config_dict[name + ".quantization"] = quantization_format
-    layer_config_dict[name + ".awq_block_size"] = block_size
+    # Only write block_size for block-quantized formats
+    if block_size > 0:
+        layer_config_dict[name + ".awq_block_size"] = block_size
```

This avoids polluting the config with unnecessary zero values for per-tensor formats, though the `process_layer_quant_config` function (line 601) already filters these out with `if "awq_block_size" in k: continue`.
🔇 Additional comments (5)
modelopt/torch/quantization/plugins/transformers_trainer.py (3)
212-219
: LGTM! Pre-compression state capture enables correct export.Saving the modelopt state and quantizer weights before compression is essential for the export workflow. The export.py script (line 47) loads this exact file to restore the quantizer configuration needed for proper model export. The timing (post-calibration, pre-compression) ensures the quantization metadata is preserved for deployment.
287-297
: Ignore attribute name inconsistency suggestion. args.lora is the boolean toggle and args.lora_config holds the adapter config; both are used correctly.Likely an incorrect or invalid review comment.
149-149
: Verify default adapter naming inadd_adapter
calls.
- Confirm that
self.model.add_adapter(self.args.lora_config)
(lines 149 & 361) assigns a defaultadapter_name
thatself.model.active_adapter()
(line 293) returns, sodelete_adapter
andload_adapter
operate correctly.- If not, pass an explicit
adapter_name
toadd_adapter
to guarantee consistent lifecycle management.examples/llm_qat/export.py (1)
71-76
: Clarify which quantization config format is saved to hf_quant_config.json.Line 73-74 saves the original
hf_quant_config
(modelopt format), while line 76 converts it to llm-compressor format and embeds it in config.json (line 89). This means two different formats are persisted:
base_model/hf_quant_config.json
: modelopt formatbase_model/config.json
(quantization_config field): llm-compressor formatIf both formats are needed for different consumers, consider adding a comment explaining why both are saved. Otherwise, consider saving only the converted format.
modelopt/torch/export/quant_utils.py (1)
835-846
: Potential logic issue: skip_keys may prevent replacements from being applied.On line 846,
"base_layer"
is appended toskip_keys
. Then on line 852, the code skips any key whereall(skip_key not in key for skip_key in skip_keys)
is true. This means keys containing"base_layer"
will be skipped entirely and never reach the replacement logic (lines 857-882).However, the replacements dictionary (lines 839-845) includes patterns like
"base_layer.weight"
which should be transformed to"weight"
. These replacements won't be applied because the keys are filtered out first.Consider refactoring to apply replacements first, then skip keys that should be removed:
- skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + # Keys to skip entirely (not related to quantizers or base transformations) + skip_patterns = ["output_quantizer"] # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment if is_modelopt_qlora: replacements.update( { "base_layer.weight": "weight", "base_layer.input_scale": "input_scale", "base_layer.weight_scale": "weight_scale", } ) - skip_keys.append("base_layer") post_state_dict = {} for key, value in state_dict.items(): - # Skip keys not related to quantizers - if all(skip_key not in key for skip_key in skip_keys): - post_state_dict[key] = value - continue + # Skip keys that should be entirely filtered out + if any(skip_pattern in key for skip_pattern in skip_patterns): + continue - # Apply replacements if the key matches any suffix in the replacements dict + # Try to apply replacements first + replaced = False for old_suffix, new_suffix in replacements.items(): if key.endswith(old_suffix): prefix = key[: -len(old_suffix)] if "_amax" in key: # ... existing _amax handling ... post_state_dict[prefix + new_suffix] = value + replaced = True break + + # If no replacement was applied and key doesn't contain quantizer-specific suffixes, keep it + if not replaced: + quantizer_suffixes = ["_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + if not any(key.endswith(suffix) for suffix in quantizer_suffixes): + post_state_dict[key] = valueLikely an incorrect or invalid review comment.
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/llm_qat/export.py (4)
29-29
: Remove unused constant.
RAND_SEED
is defined but never used in the script.-RAND_SEED = 1234 -
32-56
: Consider improving robustness and observability.The function lacks error handling for missing checkpoint files and uses print instead of logging. Consider:
- Error handling: Wrap file loading in try-except to provide clear error messages if
modelopt_state_calibration.pth
is missing or corrupted.- Logging: Replace
print("Restoring modelopt weights")
with proper logging (e.g.,logging.info(...)
).- Documentation: Expand the docstring to document parameters, return value, and expected checkpoint structure.
Example improvements:
+import logging + def get_lora_model( ckpt_path: str, device="cuda", ): """ - Loads a QLoRA model that has been trained using modelopt trainer. + Loads a QLoRA model that has been trained using modelopt trainer. + + Args: + ckpt_path: Path to the checkpoint directory containing the model and modelopt state. + device: Device to load the model on ("cuda" or "cpu"). + + Returns: + The loaded model with restored modelopt and quantizer state. """ device_map = "auto" if device == "cpu": device_map = "cpu" # Load model with adapters model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map) # Restore modelopt state + try: - modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_calibration.pth", weights_only=False) + except FileNotFoundError as e: + raise FileNotFoundError( + f"modelopt_state_calibration.pth not found in {ckpt_path}. " + "Ensure the checkpoint was saved correctly during training." + ) from e restore_from_modelopt_state(model, modelopt_state) # Restore modelopt quantizer state dict modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) if modelopt_weights is not None: - print("Restoring modelopt weights") + logging.info("Restoring modelopt quantizer weights") set_quantizer_state_dict(model, modelopt_weights) return model
59-101
: Use Path objects consistently and improve error messages.The function mixes Path objects with f-string concatenation. For consistency and robustness, use Path objects throughout. Also, the exception handler's warning message is generic and doesn't help users diagnose the issue.
Apply this diff:
def main(args): # Load model model = get_lora_model(args.pyt_ckpt_path, args.device) tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path) # Export HF checkpoint export_dir = Path(args.export_path) export_dir.mkdir(parents=True, exist_ok=True) base_model_dir = export_dir / "base_model" base_model_dir.mkdir(parents=True, exist_ok=True) try: post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=True) - with open(f"{export_dir}/base_model/hf_quant_config.json", "w") as file: + with open(base_model_dir / "hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) hf_quant_config = convert_hf_quant_config_format(hf_quant_config) # Save base model - model.base_model.save_pretrained(f"{export_dir}/base_model", state_dict=post_state_dict) + model.base_model.save_pretrained(base_model_dir, state_dict=post_state_dict) # Save adapters model.save_pretrained(export_dir) - config_path = f"{export_dir}/base_model/config.json" + config_path = base_model_dir / "config.json" # In the case of LoRA model.save_pretrained does not save the correct config.json config_data = model.config.to_dict() config_data["quantization_config"] = hf_quant_config with open(config_path, "w") as file: json.dump(config_data, file, indent=4) # Save tokenizer tokenizer.save_pretrained(export_dir) except Exception as e: warnings.warn( - "Cannot export model to the model_config. The modelopt-optimized model state_dict" - " can be saved with torch.save for further inspection." + f"Export failed: {e}. The modelopt-optimized model state_dict " + "can be saved with torch.save for further inspection." ) raise e
85-86
: Clarify the comment about config.json handling.The comment is grammatically incomplete. Consider rephrasing for clarity.
- # In the case of LoRA model.save_pretrained does not save the correct config.json + # For LoRA models, save_pretrained does not include quantization_config in config.json, + # so we manually reconstruct and write it here.
🔇 Additional comments (2)
examples/llm_qat/export.py (2)
104-118
: LGTM!The argument parser is well-structured with appropriate defaults and a required checkpoint path.
1-118
: Acknowledge planned refactor per PR discussion.Based on the PR objectives, this script is temporary. The author has agreed to refactor the export logic to integrate with
hf_ptq.py
per reviewer feedback. This standalone script serves the immediate QLoRA export use case but should be consolidated with the existing export flow in a follow-up.Based on PR objectives summary.
Actionable comments posted: 1
🧹 Nitpick comments (2)
examples/llm_qat/README.md (1)
368-370
: Consider clarifying the adapter name convention.The vLLM command uses
--lora-modules adapter=llama3-fp4-qlora-hf
where "adapter" is a generic name. It would be helpful to clarify whether this name is:
- A fixed convention expected by vLLM
- User-configurable and should match a specific naming pattern
- Related to the export output structure
Adding a brief note about the adapter naming would improve usability, especially given the PR discussion about simplifying the CLI for QLoRA checkpoints.
modelopt/torch/export/quant_utils.py (1)
904-908
: Substring matching may catch unintended keys.The LoRA adapter removal logic uses substring matching (
"lora" in key
), which could inadvertently match keys like "flora", "exploration", or "coloration". While theis_modelopt_qlora
gate reduces this risk, consider more precise matching.Apply this diff for more precise LoRA key detection:
# remove LoRA adapters from state dict if is_modelopt_qlora: for key in post_state_dict: - if "lora" in key and key not in keys_to_delete: + # Match LoRA keys more precisely: check for .lora. or lora_ patterns + parts = key.split(".") + if (any(p.startswith("lora_") or p == "lora" for p in parts)) and key not in keys_to_delete: keys_to_delete.append(key)This ensures matching only keys with "lora" as a complete path component or prefix (e.g., "model.lora_A.weight", "adapter.lora.bias"), not arbitrary substrings.
🔇 Additional comments (6)
examples/llm_qat/README.md (1)
357-370
: Documentation provides clear QLoRA deployment flow.The new export and deployment instructions successfully replace the previous experimental note with actionable guidance. The command examples are well-structured and the reference to vLLM documentation is helpful.
modelopt/torch/export/quant_utils.py (5)
812-827
: LGTM: Parameter documented.The new
is_modelopt_qlora
parameter has been properly documented in the docstring at line 824. The documentation is clear and consistent with the existing style.
836-836
: LGTM: Centralized skip logic.The
skip_keys
list provides a cleaner, more maintainable approach to filtering non-quantizer keys. This addresses the concern from previous reviews about structured key handling.
838-847
: LGTM: LoRA handling properly gated.The LoRA-specific key transformations are now correctly gated behind the
is_modelopt_qlora
flag. This addresses the previous concern about unconditional LoRA adapter removal and ensures the logic only applies to ModelOpt QLoRA exports.
1087-1093
: LGTM: NVFP4 block_size filtering fixed.The guard correctly skips only NVFP4-related formats when
block_size == 0
, indicating the weight quantizer is not enabled. This fixes the previous issue where non-block formats (INT8_SQ, FP8, FP8_PC_PT) were incorrectly dropped. Non-block formats now proceed to lines 1095-1097 regardless of block_size.
851-855
: postprocess_state_dict filtering logic is correct
Non-quantizer keys are retained, and all quantizer/base_layer keys are either transformed per the replacements mapping or removed as intended.
modelopt/torch/export/quant_utils.py (outdated)

```python
block_size = get_weight_block_size(module)

# In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled
if block_size == 0 and quantization_format in [
```
do we have a better flag instead of checking the block_size? E.g. weight_quantizer enabled vs disabled?
let me check and update the PR!
+1
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Actionable comments posted: 0
♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
354-354
: Same adapter naming concern as line 149.The explicit
adapter_name="adapter"
parameter was removed here as well. Ensure consistent adapter naming behavior across both QATTrainer and QADTrainer paths.
280-290
: Add error handling and validation as previously requested.The custom LoRA loading logic still lacks critical safety checks that were flagged in a previous review:
- No validation that
self.state.best_model_checkpoint
exists before attempting to load- No error handling if
load_adapter
fails- Missing compress flag check—your earlier comment mentioned this should only execute "if compress is enabled and fsdp2 is not enabled," but the current code doesn't verify the compress flag
- No logging to indicate success or failure
Apply this diff to add defensive checks:
def _load_best_model(self, *args, **kwargs): """Load the best model for final evaluation.""" is_lora = getattr(self.args, "lora", None) - if is_lora and not self.is_fsdp_enabled: + is_compressed = getattr(self.quant_args, "compress", False) + if is_lora and not self.is_fsdp_enabled and is_compressed: # Custom logic for loading best model with LoRA # TODO: Remove once we migrate to using get_peft_model() + if not self.state.best_model_checkpoint: + print_rank_0("No best model checkpoint found, skipping adapter reload") + return + try: - adapter_name = self.model.active_adapter() - self.model.delete_adapter(adapter_name) - self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) + adapter_name = self.model.active_adapter() + self.model.delete_adapter(adapter_name) + self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) + print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}") + except Exception as e: + print_rank_0(f"Failed to load best adapter: {e}") + raise else: super()._load_best_model(*args, **kwargs)examples/llm_qat/export.py (1)
53-62
: Verify state restoration order matches trainer pattern.A previous review flagged that the state restoration order here differs from the trainer's pattern. The trainer pops
modelopt_state_weights
before callingrestore_from_modelopt_state
, but this code pops after. While the current order may work (sincerestore_from_modelopt_state
only accesses specific keys), aligning with the trainer's proven pattern would ensure consistency and avoid potential edge cases.Run the following script to verify the trainer's restoration pattern:
#!/bin/bash # Description: Find the trainer's state restoration pattern to confirm the correct order # Search for the trainer's _restore_modelopt_state_with_weights method rg -n -A 10 "_restore_modelopt_state_with_weights" --type py
🧹 Nitpick comments (2)
examples/llm_qat/export.py (2)
44-44
: Track the TODO for adapter merging options.The TODO mentions adding support for merging adapters in different precision modes and with quantization. These are valuable deployment options that would give users more flexibility.
Would you like me to open a new issue to track this enhancement, or would you prefer to implement this as part of the current PR?
109-114
: Consider a more specific error message.The warning message is generic and might not help users diagnose export failures. Consider including the exception type or a brief suggestion for troubleshooting.
Apply this diff to improve the error message:
except Exception as e: warnings.warn( - "Cannot export model to the model_config. The modelopt-optimized model state_dict" - " can be saved with torch.save for further inspection." + f"Export failed with {type(e).__name__}: {str(e)}. " + "The modelopt-optimized model state_dict can be saved with torch.save for inspection." ) raise e
🔇 Additional comments (9)
modelopt/torch/quantization/plugins/transformers_trainer.py (2)
212-214
: Clarify ModelOpt state save timing relative to compression.The ModelOpt state is saved after quantization but before compression (line 217). A past review comment questioned whether this should be saved only before compression. If the saved state needs to include compressed weights for export, this placement may be incorrect.
Based on learnings
Please clarify:
- Should the saved state include compressed weights, or only quantization parameters?
- If compressed weights are needed, should this save call move to after line 217?
149-149
: Clarify adapter default naming in add_adapter: The explicitadapter_name
argument was removed inself.model.add_adapter(self.args.lora_config)
(lines 149 & 354). Confirm the default name used matches whatactive_adapter()
retrieves at line 286 to avoid naming conflicts when managing adapters.examples/llm_qat/README.md (2)
306-315
: Clear deployment documentation for QAT.The QAT deployment section clearly documents the export step using
export.py
before deployment. The command syntax is consistent and the flow is straightforward.
346-361
: Well-documented QLoRA deployment workflow.The QLoRA deployment section provides clear step-by-step instructions including export and vLLM deployment commands. The reference to vLLM documentation is helpful for users needing additional context.
examples/llm_qat/export.py (5)
37-50
: Good device mapping logic.The function correctly handles device mapping for both CUDA and CPU deployment scenarios. The conditional logic is clear and the fallback to "auto" is appropriate.
67-81
: Appropriate directory structure logic.The conditional directory structure for QLoRA (base_model subdirectory for the base model, adapters at the root) vs non-QLoRA matches the deployment expectations and aligns with the PR objectives.
82-96
: Export logic correctly handles QLoRA vs QAT paths.The export flow correctly:
- Calls
_export_hf_checkpoint
with theis_modelopt_qlora
flag- Saves the base model separately for QLoRA using
model.base_model.save_pretrained
- Saves adapters using
model.save_pretrained
at the export root- Uses the unified path for non-QLoRA models
This matches the documented behavior and deployment requirements.
117-135
: CLI argument definitions are clear and complete.The argument parser defines all necessary parameters with sensible defaults and help text. The
--export_path
parameter name is consistent with the README documentation.
54-54
: No changes needed: export usesmodelopt_state_train.pth
, matching the training save in modelopt/torch/quantization/plugins/transformers_trainer.py:165.
Force-pushed from 89269f7 to 13ccb6d
Actionable comments posted: 2
♻️ Duplicate comments (1)
examples/llm_qat/export.py (1)
53-62
: Critical: Fix the state restoration order.This ordering issue was flagged in previous reviews. The
modelopt_state_weights
must be popped before callingrestore_from_modelopt_state
to prevent the weights from being passed into the restore function. The current order differs from the trainer's proven pattern.Per the previous review, apply this diff:
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this if hasattr(model, "peft_config"): modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False) + modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) restore_from_modelopt_state(model, modelopt_state) print_rank_0("Restored modelopt state") # Restore modelopt quantizer state dict - modelopt_weights = modelopt_state.pop("modelopt_state_weights", None) if modelopt_weights is not None: set_quantizer_state_dict(model, modelopt_weights) print_rank_0("Restored modelopt quantizer state dict")Based on learnings
🔇 Additional comments (4)
modelopt/torch/export/unified_export_hf.py (2)
341-341
: LGTM!The new
is_modelopt_qlora
parameter enables LoRA-specific export handling while maintaining backward compatibility through its default value.
492-492
: LGTM!Passing the
is_modelopt_qlora
flag topostprocess_state_dict
enables LoRA-specific post-processing, which aligns with the PR objectives.examples/llm_qat/export.py (2)
67-96
: LGTM!The model loading, QLoRA detection, and export directory structure are correct. The separation of base model and adapter saves for QLoRA models aligns with the PR objectives.
106-114
: LGTM!Tokenizer saving and error handling are correctly implemented and consistent with the export patterns used elsewhere in the codebase.
examples/llm_qat/export.py

```python
config_path = f"{base_model_dir}/config.json"

config_data = model.config.to_dict()

config_data["quantization_config"] = hf_quant_config

with open(config_path, "w") as file:
    json.dump(config_data, file, indent=4)
```
Use base model config for QLoRA exports.

For QLoRA models, `model.config` returns the PEFT model's config, but we're writing to `base_model_dir/config.json` (the base model's config file). This overwrites the base model config that was saved by `model.base_model.save_pretrained` at line 92.

Apply this diff:

```diff
     config_path = f"{base_model_dir}/config.json"

-    config_data = model.config.to_dict()
+    if is_qlora:
+        config_data = model.base_model.config.to_dict()
+    else:
+        config_data = model.config.to_dict()

     config_data["quantization_config"] = hf_quant_config
```
modelopt/torch/export/unified_export_hf.py

```python
if (
    is_quantlinear(sub_module)
    and hasattr(sub_module, "weight_quantizer")
    and sub_module.weight_quantizer.is_enabled
):
```
🛠️ Refactor suggestion | 🟠 Major

Use the `_is_enabled_quantizer` helper for consistency.

The direct check `sub_module.weight_quantizer.is_enabled` is inconsistent with the pattern used elsewhere in this file (lines 123-126). The helper function `_is_enabled_quantizer` (lines 76-83) correctly handles both `TensorQuantizer` and `SequentialQuantizer` cases.

Apply this diff:

```diff
 if (
     is_quantlinear(sub_module)
     and hasattr(sub_module, "weight_quantizer")
-    and sub_module.weight_quantizer.is_enabled
+    and _is_enabled_quantizer(sub_module.weight_quantizer)
 ):
```
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Force-pushed from 13ccb6d to 231c147
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/quantization/plugins/transformers_trainer.py (1)

Lines 280-294: Add error handling and validation as previously requested. The custom LoRA loading logic still lacks critical safety checks that were flagged in a previous review:
- No validation that `self.state.best_model_checkpoint` exists before attempting to load
- No error handling if `load_adapter` fails
- Missing compress flag check: your earlier comment mentioned this should only execute "if compress is enabled and fsdp2 is not enabled," but the current code doesn't verify the compress flag
- No logging to indicate success or failure

Apply this diff to add defensive checks:
def _load_best_model(self, *args, **kwargs): """Load the best model for final evaluation.""" is_lora = getattr(self.args, "lora", None) - if is_lora and not self.is_fsdp_enabled: + is_compressed = getattr(self.quant_args, "compress", False) + if is_lora and not self.is_fsdp_enabled and is_compressed: # Custom logic for loading best model with LoRA # TODO: Remove once we migrate to using get_peft_model() # This custom logic only loads best adapters. Ensure base model is frozen + if not self.state.best_model_checkpoint: + print_rank_0("No best model checkpoint found, skipping adapter reload") + return + try: - assert all( - param.requires_grad is False for param in self.model.base_model.parameters() - ), "Base model must be frozen for lora" - adapter_name = self.model.active_adapter() - self.model.delete_adapter(adapter_name) - self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) + assert all( + param.requires_grad is False for param in self.model.base_model.parameters() + ), "Base model must be frozen for lora" + adapter_name = self.model.active_adapter() + self.model.delete_adapter(adapter_name) + self.model.load_adapter(self.state.best_model_checkpoint, adapter_name) + print_rank_0(f"Successfully loaded best adapter from {self.state.best_model_checkpoint}") + except Exception as e: + print_rank_0(f"Failed to load best adapter: {e}") + raise else: super()._load_best_model(*args, **kwargs)
🧹 Nitpick comments (1)

modelopt/torch/export/quant_utils.py (1), lines 910-914: Refine LoRA adapter removal logic. Only remove keys matching adapters actually loaded via `peft_config` (e.g. check against `model.peft_config.target_modules`) instead of blanket "lora" substring filtering (see the sketch below).
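A hedged sketch of that narrower filter, keyed on the adapter names registered in `model.peft_config` (one possible reading of the suggestion; the helper name and key layout are assumptions):

```python
# Sketch: only drop LoRA tensors that belong to adapters registered in model.peft_config.
def lora_keys_to_delete(post_state_dict: dict, model) -> list[str]:
    adapter_names = set(getattr(model, "peft_config", {}) or {})
    keys = []
    for key in post_state_dict:
        # PEFT stores adapter weights as ...lora_A.<adapter_name>.weight / ...lora_B.<adapter_name>.weight
        if any(f"lora_A.{name}." in key or f"lora_B.{name}." in key for name in adapter_names):
            keys.append(key)
    return keys
```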
🔇 Additional comments (6)

modelopt/torch/quantization/plugins/transformers_trainer.py (3)

Line 149: LGTM! Adapter name parameter removed. The removal of the `adapter_name` parameter aligns with the temporary workaround approach until migration to `get_peft_model()` is complete.

Lines 212-214: LGTM! ModelOpt state saved after quantization. Moving the state save to occur after calibration and before potential compression ensures the quantized state is properly captured. The implementation correctly handles distributed training by checking `self.args.should_save`.

Line 358: LGTM! Consistent adapter handling in QADTrainer. The adapter name parameter removal is consistent with the change in QATTrainer.

modelopt/torch/export/quant_utils.py (3)

Lines 818-823: LGTM! Backward-compatible parameter addition. The new `is_modelopt_qlora` parameter with a default value of `False` maintains backward compatibility for existing callers.

Lines 842-853: LGTM! QLoRA-specific key remappings added. The conditional remapping logic correctly handles deployment-specific transformations for ModelOpt-trained QLoRA models, removing the `base_layer` prefix from weight and scale keys.

Lines 1082-1105: LGTM! Fixed inclusion of per-tensor quantization formats. The updated logic correctly determines whether the weight quantizer is enabled by checking both block_size (for block quantization formats) and `weight_quantizer.is_enabled` (for per-tensor formats like FP8/INT8). This resolves the previous issue where INT8/FP8 layers were incorrectly skipped. The logic properly handles:
- Block quantization formats (INT4_AWQ, NVFP4, etc.) via `block_size > 0`
- Per-tensor formats (FP8, INT8_SQ) via `weight_quantizer.is_enabled`
A sketch of this gating follows below.
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
What does this PR do?
Type of change: New example
Overview: This PR provides an e2e example for fine-tuning a model using QLoRA with DDP and exporting checkpoint for deployment using vllm.
Usage
Refer to README.md changes
Testing
Trainer
./launch.sh --model meta-llama/Meta-Llama-3-8B --num_epochs 0.01 --lr 1e-3 --do_train True --output_dir test --quant_cfg FP8_DEFAULT_CFG --compress True --lora True
Export
python export.py --pyt_ckpt_path test --export_dir test-fp8
Deployment
vllm serve test-fp8/base_model --enable-lora --lora-modules sql-lora=test-fp8 --port 8090 --tokenizer test-fp8
e2e unit test
Sanity check weights, dtypes of generated checkpoint
Test phi4
Before your PR is "Ready for review"
Additional Information