Skip to content

NotImplementedError #97

@chengby2359

Description

@chengby2359

[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/scripts/train.py", line 47, in launch
[rank1]: trainer.train(model, dataloader_train, dataloader_val)
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/trainers/trainer_distillation.py", line 196, in train
[rank1]: output_batch, loss, grad_accum_iter = self.training_step(
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/trainers/trainer_distillation.py", line 270, in training_step
[rank1]: out_i, loss_i = closure() # loss_i must be scalar
[rank1]: ^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 781, in
[rank1]: yield "critic", lambda: self.training_step_critic(x0_B_C_T_H_W, condition, uncondition, condition_state, iteration)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 748, in training_step_critic
[rank1]: x0_theta_fake_B_C_T_H_W = self.denoise(D_xt_theta_B_C_T_H_W, condition_state, D_time_B_T, condition, net_type="fake_score").x0
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/models/t2v_model_distill_rcm.py", line 506, in denoise
[rank1]: net_output_B_C_T_H_W = net(
[rank1]: ^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 703, in forward
[rank1]: x_B_L_D = block(x_B_L_D, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
[rank1]: return inner()
[rank1]: ^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1827, in inner
[rank1]: result = forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank1]: return self.checkpoint_fn( # type: ignore[misc]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
[rank1]: return disable_fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 495, in checkpoint
[rank1]: ret = function(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 404, in forward
[rank1]: y = self.self_attn((self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, freqs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 269, in forward
[rank1]: x = self.attn_op(rope_apply(q, freqs), rope_apply(k, freqs), v)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/DATA/cby/codes/TurboDiffusion/turbodiffusion/rcm/networks/wan2pt1.py", line 176, in rope_apply
[rank1]: rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/layers/rotary.py", line 121, in apply_rotary_emb
[rank1]: return ApplyRotaryEmb.apply(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/autograd/function.py", line 576, in apply
[rank1]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/layers/rotary.py", line 51, in forward
[rank1]: out = apply_rotary(
[rank1]: ^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/flash_attn/ops/triton/rotary.py", line 159, in apply_rotary
[rank1]: torch.library.wrap_triton(rotary_kernel)[grid](
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 2038, in call
[rank1]: return tracing_triton_hopifier_singleton.call_triton_kernel(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1882, in call_triton_kernel
[rank1]: return self.call_HOP(variable, grids, combined_args_raw, tx)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 1992, in call_HOP
[rank1]: return triton_kernel_wrapper_mutation(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_higher_order_ops/triton_kernel_wrap.py", line 973, in call
[rank1]: return super().call(
[rank1]: ^^^^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 524, in call
[rank1]: return wrapper()
[rank1]: ^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 520, in wrapper
[rank1]: return self.dispatch(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/home/gpu/miniconda3/envs/turbodiffusion/lib/python3.12/site-packages/torch/_ops.py", line 418, in dispatch
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: There was no rule registered for HOP triton_kernel_wrapper_mutation and mode <torch.utils.checkpoint._CachingTorchDispatchMode object at 0x7fccb85b5100>. We recommend filing an issue.

Environment:
flash_attn 2.8.3
spas_sage_attn 0.1.0
torch 2.8.0
triton 3.4.0

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions