diff --git a/acestep/training/trainer.py b/acestep/training/trainer.py
index 1f037bbc..e7318770 100644
--- a/acestep/training/trainer.py
+++ b/acestep/training/trainer.py
@@ -384,13 +384,24 @@ def __init__(
             self.lora_info = {}
             logger.warning("PEFT not available, training without LoRA adapters")

-        # Added Torch Compile Logic
-        if hasattr(torch, "compile") and self.device_type == "cuda":
-            logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
+        # torch.compile: optional perf optimization.
+        # PEFT LoRA wraps the decoder in PeftModelForFeatureExtraction which is
+        # incompatible with torch.compile/inductor on PyTorch 2.7.x
+        # (AssertionError at first forward pass, not at compile time).
+        # Only compile when NOT using PEFT adapters.
+        has_peft = bool(self.lora_info)
+        if hasattr(torch, "compile") and self.device_type == "cuda" and not has_peft:
+            try:
+                logger.info("Compiling DiT decoder...")
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")
+                logger.info("torch.compile successful")
+            except Exception as e:
+                logger.warning(f"torch.compile failed ({e}), continuing without compilation")
         else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+            if has_peft:
+                logger.info("Skipping torch.compile (incompatible with PEFT LoRA adapters)")
+            else:
+                logger.info("torch.compile not available on this device/PyTorch version, skipping")

         # Model config for flow matching
         self.config = model.config
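
Side note (not part of the patch): a minimal standalone sketch of the same gating logic. It assumes the `peft` package may or may not be installed and detects the wrapper via `isinstance(decoder, PeftModel)` rather than the trainer's `self.lora_info`; the function name `maybe_compile_decoder` and its arguments are illustrative only.

    import logging

    import torch

    logger = logging.getLogger(__name__)


    def maybe_compile_decoder(decoder: torch.nn.Module, device_type: str) -> torch.nn.Module:
        """Return a compiled decoder when it is safe to do so, else the original module."""
        try:
            # PeftModelForFeatureExtraction subclasses PeftModel, so this catches it too.
            from peft import PeftModel
            has_peft = isinstance(decoder, PeftModel)
        except ImportError:
            has_peft = False

        if not hasattr(torch, "compile") or device_type != "cuda" or has_peft:
            logger.info("Skipping torch.compile")
            return decoder

        try:
            return torch.compile(decoder, mode="default")
        except Exception as exc:  # compilation failures should not abort training
            logger.warning("torch.compile failed (%s), continuing without compilation", exc)
            return decoder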