23 changes: 17 additions & 6 deletions acestep/training/trainer.py
@@ -384,13 +384,24 @@ def __init__(
             self.lora_info = {}
             logger.warning("PEFT not available, training without LoRA adapters")

-        # Added Torch Compile Logic
-        if hasattr(torch, "compile") and self.device_type == "cuda":
-            logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default") # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
+        # torch.compile: optional perf optimization.
+        # PEFT LoRA wraps the decoder in PeftModelForFeatureExtraction, which is
+        # incompatible with torch.compile/inductor on PyTorch 2.7.x
+        # (AssertionError at the first forward pass, not at compile time).
+        # Only compile when NOT using PEFT adapters.
+        has_peft = bool(self.lora_info)
+        if hasattr(torch, "compile") and self.device_type == "cuda" and not has_peft:
+            try:
+                logger.info("Compiling DiT decoder...")
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")
+                logger.info("torch.compile successful")
+            except Exception as e:
+                logger.warning(f"torch.compile failed ({e}), continuing without compilation")
         else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+            if has_peft:
+                logger.info("Skipping torch.compile (incompatible with PEFT LoRA adapters)")
+            else:
+                logger.info("torch.compile not available on this device/PyTorch version, skipping")

         # Model config for flow matching
         self.config = model.config
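For reference, the guarded-compile pattern introduced in this diff can be exercised in isolation. Below is a minimal sketch, assuming a throwaway nn.Module in place of the DiT decoder; the names TinyDecoder, device_type, and has_peft are illustrative stand-ins, not part of this PR.

    import logging

    import torch
    import torch.nn as nn

    logger = logging.getLogger(__name__)


    class TinyDecoder(nn.Module):
        # Stand-in for the DiT decoder; purely illustrative.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x):
            return torch.relu(self.proj(x))


    decoder = TinyDecoder()
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    has_peft = False  # in the trainer this would be bool(self.lora_info)

    # Compile only when torch.compile exists, a CUDA device is present,
    # and no PEFT adapters wrap the module; otherwise fall back to eager mode.
    if hasattr(torch, "compile") and device_type == "cuda" and not has_peft:
        try:
            decoder = torch.compile(decoder, mode="default")
        except Exception as e:
            logger.warning(f"torch.compile failed ({e}), continuing without compilation")

    out = decoder(torch.randn(4, 16))

Because the try/except wraps only the compile call itself, any failure leaves the original eager-mode module in place, which matches the fallback behavior the diff aims for.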