fix(training): skip torch.compile when PEFT LoRA adapters are active #640
Conversation
torch.compile succeeds at init but crashes at first forward pass when PEFT wraps the decoder in PeftModelForFeatureExtraction. The inductor backend raises AssertionError on PyTorch 2.7.x. Skip compilation when PEFT adapters are detected. Add try/except as safety net for future non-PEFT compile failures. Fixes LoRA training on PyTorch 2.7.x + CUDA.
📝 Walkthrough
This change adds conditional guards to the torch.compile optimization for the DiT decoder in PreprocessedLoRAModule. Compilation is now attempted only when CUDA is available, LoRA adapters are inactive, and torch.compile is supported, with graceful fallback and informational logging for incompatible configurations.
🧹 Nitpick comments (1)
acestep/training/trainer.py (1)
387-404: LGTM - Guard logic is correct and fixes the PEFT + torch.compile incompatibility.
The fix correctly skips torch.compile when PEFT LoRA adapters are active (`bool(self.lora_info)`), addressing the inductor AssertionError on PyTorch 2.7.x. The try/except safety net is reasonable given torch.compile's diverse failure modes. Two minor observations:
Static analysis (Ruff BLE001): The bare `except Exception` is flagged, but it is acceptable here given the stated "safety net" intent; torch.compile can raise various exception types (RuntimeError, AssertionError, etc.), and graceful fallback is the correct behavior.
Log message accuracy (line 404): The message "torch.compile not available on this device/PyTorch version" is slightly imprecise; torch.compile is available on MPS/CPU, it is just not used here by design. Consider:
Optional: more accurate log message

```diff
 if has_peft:
     logger.info("Skipping torch.compile (incompatible with PEFT LoRA adapters)")
+elif self.device_type != "cuda":
+    logger.info("Skipping torch.compile (only enabled for CUDA devices)")
 else:
-    logger.info("torch.compile not available on this device/PyTorch version, skipping")
+    logger.info("torch.compile not available on this PyTorch version, skipping")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/trainer.py` around lines 387 - 404, Update the informational log in the branch where torch.compile is not being used (the else that runs when not has_peft and either torch.compile missing or device_type != "cuda") to be more accurate: change the logger.info message that currently reads "torch.compile not available on this device/PyTorch version, skipping" to something reflecting that compilation is intentionally not being performed here (e.g., mention non-CUDA device or that torch.compile isn't being invoked), referencing the existing checks torch.compile, self.device_type and has_peft and updating the logger.info call accordingly.
Summary
`torch.compile(mode="default")` was added to `PreprocessedLoRAModule.__init__()` in PR #422. The compile call succeeds at init time, but crashes at the first forward pass when PEFT wraps the decoder in `PeftModelForFeatureExtraction`: the inductor backend raises an AssertionError. This makes all LoRA training non-functional on PyTorch 2.7.x + CUDA.
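A rough reproduction sketch of this failure mode, under the assumptions above (CUDA, PyTorch 2.7.x, peft installed). `TinyDecoder`, the LoRA hyperparameters, and the tensor shapes are hypothetical stand-ins, and a generic PEFT wrapper is used instead of the trainer's actual `PeftModelForFeatureExtraction` to keep the example self-contained:

```python
# Hypothetical minimal sketch of the reported failure mode (not the actual ACE-Step code):
# compilation succeeds at construction time, but the first forward pass through a
# PEFT-wrapped module is what was reported to crash on PyTorch 2.7.x + CUDA.
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyDecoder(nn.Module):
    """Stand-in for the DiT decoder; names and sizes are illustrative only."""
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj_out(torch.nn.functional.gelu(self.proj_in(x)))

decoder = TinyDecoder().cuda()
lora_cfg = LoraConfig(r=8, lora_alpha=16, target_modules=["proj_in", "proj_out"])
decoder = get_peft_model(decoder, lora_cfg)        # decoder is now a PEFT-wrapped module

compiled = torch.compile(decoder, mode="default")  # succeeds at init time
x = torch.randn(4, 64, device="cuda")
out = compiled(x)  # the PR reports an inductor AssertionError at this first forward pass
```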
Fix: Skip `torch.compile` when PEFT LoRA adapters are active (`bool(self.lora_info)`). Added a try/except as a safety net for potential future compile failures in non-PEFT scenarios. A sketch of the approach follows.
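A minimal sketch of this guard, assuming standard-library logging and illustrative parameter names (`lora_info`, `device_type`); it illustrates the approach rather than reproducing the exact code added to the trainer:

```python
import logging
import torch

logger = logging.getLogger(__name__)

def maybe_compile_decoder(decoder, lora_info, device_type: str):
    """Hypothetical helper mirroring the guard added in __init__; names are illustrative."""
    if bool(lora_info):  # PEFT LoRA adapters active: inductor crashes at the first forward
        logger.info("Skipping torch.compile (incompatible with PEFT LoRA adapters)")
        return decoder
    if device_type != "cuda" or not hasattr(torch, "compile"):
        logger.info("torch.compile not used on this device/PyTorch version, skipping")
        return decoder
    try:
        return torch.compile(decoder, mode="default")
    except Exception:  # safety net: fall back to the eager decoder on any compile failure
        logger.warning("torch.compile failed; falling back to eager execution", exc_info=True)
        return decoder
```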
Scope
`acestep/training/trainer.py` (1 file)

Risk and Compatibility
The `has_peft` guard only affects the compile decision in `PreprocessedLoRAModule.__init__()`; no other code paths are touched. The non-PEFT compile path behaves exactly as before, with an added try/except safety net.
Regression Checks
Reviewer Notes