From 581e3d69fd5a9a134f1e1a93275a099b306b9de9 Mon Sep 17 00:00:00 2001
From: FeelTheFonk
Date: Thu, 19 Feb 2026 09:27:18 +0100
Subject: [PATCH] fix(training): skip torch.compile when PEFT LoRA adapters are active

torch.compile succeeds at init but crashes at first forward pass when
PEFT wraps the decoder in PeftModelForFeatureExtraction. The inductor
backend raises AssertionError on PyTorch 2.7.x.

Skip compilation when PEFT adapters are detected. Add a try/except as a
safety net for future non-PEFT compile failures.

Fixes LoRA training on PyTorch 2.7.x + CUDA.
---
 acestep/training/trainer.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/acestep/training/trainer.py b/acestep/training/trainer.py
index 1f037bbc..e7318770 100644
--- a/acestep/training/trainer.py
+++ b/acestep/training/trainer.py
@@ -384,13 +384,24 @@ def __init__(
             self.lora_info = {}
             logger.warning("PEFT not available, training without LoRA adapters")
 
-        # Added Torch Compile Logic
-        if hasattr(torch, "compile") and self.device_type == "cuda":
-            logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
+        # torch.compile: optional perf optimization.
+        # PEFT LoRA wraps the decoder in PeftModelForFeatureExtraction which is
+        # incompatible with torch.compile/inductor on PyTorch 2.7.x
+        # (AssertionError at first forward pass, not at compile time).
+        # Only compile when NOT using PEFT adapters.
+        has_peft = bool(self.lora_info)
+        if hasattr(torch, "compile") and self.device_type == "cuda" and not has_peft:
+            try:
+                logger.info("Compiling DiT decoder...")
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")
+                logger.info("torch.compile successful")
+            except Exception as e:
+                logger.warning(f"torch.compile failed ({e}), continuing without compilation")
         else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+            if has_peft:
+                logger.info("Skipping torch.compile (incompatible with PEFT LoRA adapters)")
+            else:
+                logger.info("torch.compile not available on this device/PyTorch version, skipping")
 
         # Model config for flow matching
         self.config = model.config
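
Note (not part of the patch): a minimal standalone sketch of the same guard pattern,
assuming the peft package is installed. PeftModelForFeatureExtraction subclasses
peft.PeftModel, so an isinstance() check is an alternative to the self.lora_info flag
the diff uses; the maybe_compile helper name is hypothetical.

    import torch
    from peft import PeftModel

    def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
        # Skip compilation for PEFT-wrapped modules: inductor is known to fail
        # at the first forward pass on PyTorch 2.7.x in this configuration.
        if isinstance(module, PeftModel):
            return module
        # Only attempt compilation when torch.compile exists and CUDA is usable.
        if not (hasattr(torch, "compile") and torch.cuda.is_available()):
            return module
        try:
            return torch.compile(module, mode="default")
        except Exception:
            # Safety net: fall back to the uncompiled module on any compile failure.
            return module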