Description

During rCM distillation training, the critic training step crashes inside flash-attn's Triton rotary-embedding kernel when the transformer block runs under activation checkpointing: PyTorch raises NotImplementedError because no rule is registered for the triton_kernel_wrapper_mutation higher-order op under torch.utils.checkpoint._CachingTorchDispatchMode. Full traceback from rank 1:
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/scripts/train.py", line 47, in launch
[rank1]: trainer.train(model, dataloader_train, dataloader_val)
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/trainers/trainer_distillation.py", line 196, in train
[rank1]: output_batch, loss, grad_accum_iter = self.training_step(
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/trainers/trainer_distillation.py", line 270, in training_step
[rank1]: out_i, loss_i = closure() # loss_i must be scalar
[rank1]: ^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 781, in
[rank1]: yield "critic", lambda: self.training_step_critic(x0_B_C_T_H_W, condition, uncondition, condition_state, iteration)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 748, in training_step_critic
[rank1]: x0_theta_fake_B_C_T_H_W = self.denoise(D_xt_theta_B_C_T_H_W, condition_state, D_time_B_T, condition, net_type="fake_score").x0
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 506, in denoise
[rank1]: net_output_B_C_T_H_W = net(
[rank1]: ^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 703, in forward
[rank1]: x_B_L_D = block(x_B_L_D, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank1]: return self.checkpoint_fn( # type: ignore[misc]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
[rank1]: return disable_fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank1]: ret = function(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 404, in forward
[rank1]: y = self.self_attn((self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, freqs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 269, in forward
[rank1]: x = self.attn_op(rope_apply(q, freqs), rope_apply(k, freqs), v)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 176, in rope_apply
[rank1]: rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/layers/rotary.py", line 121, in apply_rotary_emb
[rank1]: return ApplyRotaryEmb.apply(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
[rank1]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/layers/rotary.py", line 51, in forward
[rank1]: out = apply_rotary(
[rank1]: ^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/ops/triton/rotary.py", line 159, in apply_rotary
[rank1]: torch.library.wrap_triton(rotary_kernel)[grid](
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 2038, in call
[rank1]: return tracing_triton_hopifier_singleton.call_triton_kernel(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1882, in call_triton_kernel
[rank1]: return self.call_HOP(variable, grids, combined_args_raw, tx)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1992, in call_HOP
[rank1]: return triton_kernel_wrapper_mutation(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 973, in call
[rank1]: return super().call(
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 524, in call
[rank1]: return wrapper()
[rank1]: ^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 520, in wrapper
[rank1]: return self.dispatch(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 418, in dispatch
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: There was no rule registered for HOP triton_kernel_wrapper_mutation and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x7fccb85b5100>. We recommend filing an issue.
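Below is a minimal sketch of what the traceback suggests is the failing combination: flash-attn's Triton rotary kernel (launched through torch.library.wrap_triton) invoked inside a non-reentrant checkpoint that uses selective activation checkpointing, which is what installs the _CachingTorchDispatchMode named in the error. The tensor shapes, the trivial policy function, and the standalone setup are assumptions for illustration, not code from the repository, so this may need adjusting to actually reproduce the crash.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)
from flash_attn.layers.rotary import apply_rotary_emb


def save_everything(ctx, op, *args, **kwargs):
    # Trivial policy: the error is about the caching dispatch mode itself,
    # not about which ops the policy chooses to save or recompute.
    return CheckpointPolicy.MUST_SAVE


def rope_block(x, cos, sin):
    # Same call pattern as rope_apply in wan2pt1.py (float32 input, interleaved RoPE).
    return apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False)


B, S, H, D = 1, 128, 8, 64  # assumed shapes, chosen only to make the example runnable
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
cos = torch.randn(S, D // 2, device="cuda", dtype=torch.float32)
sin = torch.randn(S, D // 2, device="cuda", dtype=torch.float32)

# Running the flash-attn Triton rotary under a selective-checkpoint context on
# torch 2.8 + flash_attn 2.8.3 should hit the same NotImplementedError as above.
out = checkpoint(
    rope_block, x, cos, sin,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, save_everything),
)
out.sum().backward()
```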
Environment:
flash_attn 2.8.3
spas_sage_attn 0.1.0
torch 2.8.0
triton 3.4.0
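Not part of the original report, but while this is investigated one possible workaround is to bypass the flash-attn Triton kernel in rope_apply with a plain PyTorch implementation of interleaved RoPE, so the checkpoint dispatch mode never sees the Triton higher-order op. The function name and the assumption that cos/sin have shape (seqlen, head_dim // 2) mirror the flash_apply_rotary_emb call in the traceback and are assumptions about wan2pt1.py, not confirmed details.

```python
import torch


def rope_apply_torch(x, cos, sin):
    # Hypothetical drop-in for the flash_apply_rotary_emb(..., interleaved=True) call,
    # assuming x is (B, S, H, D) with D even and cos/sin are (S, D // 2) in float32,
    # and using flash-attn's interleaved (GPT-J style) rotation convention.
    x = x.to(torch.float32)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos = cos[None, :, None, :]  # broadcast over batch and head dimensions
    sin = sin[None, :, None, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

This is slower than the fused Triton kernel but stays entirely in ordinary PyTorch ops, which the checkpoint dispatch modes already handle.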