codegen error: index_map.find(alloc_dom[i]) != index_map.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp" #871
Here's the failing fusion segment:
The rfactor of
No idea why there are so few mappings.
The kernel looks like this:
The error happens when handling the
I'm sorry for the long reproducer. I'm seeing a "Couldn't find allocation mapping for" error. Is this the same problem? I encountered this error while trying to run HF's Qwen 2 model with Thunder (Lightning-AI/lightning-thunder#1406). This model is important to support soon. @jacobhinkle or @naoyam, do you have an estimate of how much time it could take to fix this bug? Is there any change we could introduce to the fusion definition as a workaround?

RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":358, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.
Error from segmentation group 11: INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1990, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T125_l___bfloat[ iblockIdx.x846{( ceilDiv(2, blockDim.x) )}, ithreadIdx.x847{blockDim.x}, iS855{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(32768, blockDim.y) ), 16) ), 1) ), gridDim.y) )}, iblockIdx.y854{gridDim.y}, ithreadIdx.y849{blockDim.y}, iUS853{1}, iUR851{16}, bS505{1} ] ca_pos( 6 ) dim: 2 id: iS507{2}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:1990 (most recent call first):
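A minimal sketch for surfacing just the failing segment's message, assuming this nvfuser build accepts "parallel_compile" as an NVFUSER_DISABLE option (not verified against this exact version):

import os

# Assumption: "parallel_compile" is a valid NVFUSER_DISABLE option in this
# build. Setting it before nvfuser compiles anything serializes segment
# compilation, so the single failing group's error is reported on its own
# rather than interleaved with messages from other compile threads.
os.environ["NVFUSER_DISABLE"] = "parallel_compile"

# ...then import nvfuser and run the repro below unchanged.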
# CUDA devices:
# 0: NVIDIA RTX 6000 Ada Generation
# 1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git8b08559
# cuda version: 12.3
# nvfuser version: 0.2.17+gitdf32dce
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id13(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 32768, 2], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 4, 32768, 2], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[1, 32768, 2], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[1, 4, 32768, 2], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T13 = fd.ops.reshape(T0, new_shape=[1, 4, 7, 32768, 2])
    T14 = fd.ops.cast(T13, dtype=DataType.Float)
    T15 = fd.ops.sum(T14, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T16 = fd.ops.cast(T15, dtype=DataType.BFloat16)
    T23 = fd.ops.broadcast_in_dim(T16, shape=[1, 4, 1, 32768, 2], broadcast_dims=[1, 3, 4])
    T24 = fd.ops.cast(T23, dtype=DataType.Float)
    T25 = fd.ops.sum(T24, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T26 = fd.ops.cast(T25, dtype=DataType.BFloat16)
    T32 = fd.ops.broadcast_in_dim(T1, shape=[1, 1, 32768, 2], broadcast_dims=[0, 2, 3])
    T38 = fd.ops.broadcast_in_dim(T26, shape=[1, 4, 32768, 2], broadcast_dims=[1, 2, 3])
    T44 = fd.ops.broadcast_in_dim(T32, shape=[1, 28, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    T45 = fd.ops.cast(T38, dtype=DataType.Float)
    T46 = fd.ops.cast(T2, dtype=DataType.Float)
    T52 = fd.ops.broadcast_in_dim(T32, shape=[1, 4, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    T53 = fd.ops.cast(T3, dtype=DataType.Float)
    T54 = fd.ops.cast(T44, dtype=DataType.Float)
    T55 = fd.ops.add(T46, T45)
    T56 = fd.ops.cast(T52, dtype=DataType.Float)
    T63 = fd.ops.reshape(T4, new_shape=[1, 4, 7, 32768, 2])
    T64 = fd.ops.mul(T54, T53)
    T65 = fd.ops.mul(T56, T55)
    T66 = fd.ops.cast(T63, dtype=DataType.Float)
    T67 = fd.ops.cast(T64, dtype=DataType.BFloat16)
    T68 = fd.ops.cast(T65, dtype=DataType.BFloat16)
    T69 = fd.ops.sum(T66, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T85 = fd.ops.slice(T67, start_indices=[0, 0, 0, 0], end_indices=[1, 28, 32768, 1], strides=[1, 1, 1, 1])
    T101 = fd.ops.slice(T68, start_indices=[0, 0, 0, 0], end_indices=[1, 4, 32768, 1], strides=[1, 1, 1, 1])
    T102 = fd.ops.cast(T69, dtype=DataType.BFloat16)
    T103 = fd.ops.cast(T85, dtype=DataType.Float)
    T104 = fd.ops.cast(T101, dtype=DataType.Float)
    T111 = fd.ops.broadcast_in_dim(T102, shape=[1, 4, 1, 32768, 2], broadcast_dims=[1, 3, 4])
    T112 = fd.ops.neg(T103)
    T113 = fd.ops.neg(T104)
    T114 = fd.ops.cast(T111, dtype=DataType.Float)
    T120 = fd.ops.broadcast_in_dim(T5, shape=[1, 1, 32768, 2], broadcast_dims=[0, 2, 3])
    T136 = fd.ops.slice(T67, start_indices=[0, 0, 0, 1], end_indices=[1, 28, 32768, 2], strides=[1, 1, 1, 1])
    T137 = fd.ops.cast(T112, dtype=DataType.BFloat16)
    T153 = fd.ops.slice(T68, start_indices=[0, 0, 0, 1], end_indices=[1, 4, 32768, 2], strides=[1, 1, 1, 1])
    T154 = fd.ops.cast(T113, dtype=DataType.BFloat16)
    T155 = fd.ops.sum(T114, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T161 = fd.ops.broadcast_in_dim(T120, shape=[1, 28, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    S162 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T163 = fd.ops.pad(T136, [0, 1, 0, 0, 0, 0, 0, 0], S162)
    S164 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T165 = fd.ops.pad(T137, [1, 0, 0, 0, 0, 0, 0, 0], S164)
    T171 = fd.ops.broadcast_in_dim(T120, shape=[1, 4, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    S172 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T173 = fd.ops.pad(T153, [0, 1, 0, 0, 0, 0, 0, 0], S172)
    S174 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T175 = fd.ops.pad(T154, [1, 0, 0, 0, 0, 0, 0, 0], S174)
    T176 = fd.ops.cast(T155, dtype=DataType.BFloat16)
    T177 = fd.ops.cast(T161, dtype=DataType.Float)
    T178 = fd.ops.cast(T163, dtype=DataType.Float)
    T179 = fd.ops.cast(T165, dtype=DataType.Float)
    T180 = fd.ops.cast(T171, dtype=DataType.Float)
    T181 = fd.ops.cast(T173, dtype=DataType.Float)
    T182 = fd.ops.cast(T175, dtype=DataType.Float)
    T188 = fd.ops.broadcast_in_dim(T176, shape=[1, 4, 32768, 2], broadcast_dims=[1, 2, 3])
    T189 = fd.ops.mul(T177, T53)
    T190 = fd.ops.add(T179, T178)
    T191 = fd.ops.mul(T180, T55)
    T192 = fd.ops.add(T182, T181)
    T193 = fd.ops.cast(T188, dtype=DataType.Float)
    T194 = fd.ops.cast(T6, dtype=DataType.Float)
    T195 = fd.ops.add(T190, T189)
    T196 = fd.ops.add(T192, T191)
    T197 = fd.ops.add(T194, T193)
    T198 = fd.ops.cast(T195, dtype=DataType.BFloat16)
    T199 = fd.ops.cast(T196, dtype=DataType.BFloat16)
    T200 = fd.ops.cast(T197, dtype=DataType.BFloat16)
    T201 = fd.ops.permute(T198, dims=[0, 2, 1, 3])
    T202 = fd.ops.permute(T199, dims=[0, 2, 1, 3])
    T203 = fd.ops.permute(T200, dims=[0, 2, 1, 3])
    T208 = fd.ops.reshape(T201, new_shape=[1, 32768, 56])
    T213 = fd.ops.reshape(T202, new_shape=[1, 32768, 8])
    T218 = fd.ops.reshape(T203, new_shape=[1, 32768, 8])
    T219 = fd.ops.cast(T208, dtype=DataType.Float)
    T220 = fd.ops.cast(T213, dtype=DataType.Float)
    T221 = fd.ops.cast(T218, dtype=DataType.Float)
    T222 = fd.ops.sum(T219, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T226 = fd.ops.reshape(T208, new_shape=[32768, 56])
    T227 = fd.ops.sum(T220, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T231 = fd.ops.reshape(T213, new_shape=[32768, 8])
    T232 = fd.ops.sum(T221, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T236 = fd.ops.reshape(T218, new_shape=[32768, 8])
    T237 = fd.ops.cast(T222, dtype=DataType.BFloat16)
    T238 = fd.ops.permute(T226, dims=[1, 0])
    T239 = fd.ops.cast(T227, dtype=DataType.BFloat16)
    T240 = fd.ops.permute(T231, dims=[1, 0])
    T241 = fd.ops.cast(T232, dtype=DataType.BFloat16)
    T242 = fd.ops.permute(T236, dims=[1, 0])
    fd.add_output(T236)
    fd.add_output(T242)
    fd.add_output(T241)
    fd.add_output(T231)
    fd.add_output(T240)
    fd.add_output(T239)
    fd.add_output(T226)
    fd.add_output(T238)
    fd.add_output(T237)
with FusionDefinition() as fd:
    nvfuser_fusion_id13(fd)
inputs = [
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.testing.make_tensor((1, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 4, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.testing.make_tensor((1, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 4, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
Slightly smaller repro containing only the failing segment, simplified as much as possible; it is scheduled by the Reduction scheduler.

# CUDA devices:
# 0: NVIDIA H100 80GB HBM3
# torch version: 2.6.0a0+gitffb7a08
# cuda version: 12.6
# nvfuser version: 0.2.22+git6912435
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[28, 32768, 2], contiguity=[True, False, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[32768, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.define_tensor(shape=[28, 32768, 1], contiguity=[True, False, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[28, 32768, 1], contiguity=[True, False, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.ops.pad(T2, [1, 0], None)
    T11 = fd.ops.pad(T3, [0, 1], None)
    T12 = fd.ops.add(T7, T11)
    T13 = fd.ops.broadcast(T1, is_broadcast_dim=[True, False, False])
    T14 = fd.ops.mul(T13, T0)
    T15 = fd.ops.add(T12, T14)
    T16 = fd.ops.permute(T15, dims=[1, 0, 2])
    T20 = fd.ops.reshape(T16, new_shape=[32768, 56])
    T21 = fd.ops.sum(T20, dims=[0], keepdim=False, dtype=DataType.Float)
    fd.add_output(T21)
    fd.add_output(T20)
    fd.add_output(T13)
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
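For reference, here is a plain-PyTorch sketch of what this simplified segment computes, which may be useful for checking the fusion's outputs once the bug is fixed. This is my reading of the definition above, not an authoritative translation; the [..., :1] slices are an assumption mirroring the non-contiguous [28, 32768, 1] views declared for T2 and T3, and all variable names are mine.

import torch
import torch.nn.functional as F

t0 = torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0')
t1 = torch.testing.make_tensor((32768, 2), dtype=torch.float32, device='cuda:0')
# The fusion sees T2/T3 as [28, 32768, 1] tensors with a non-contiguous middle
# dimension, i.e. the first column of a contiguous [28, 32768, 2] tensor.
t2 = torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0')[..., :1]
t3 = torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0')[..., :1]

t12 = F.pad(t2, (1, 0)) + F.pad(t3, (0, 1))    # pad innermost dim back to size 2
t13 = t1.unsqueeze(0)                          # broadcast T1 to [1, 32768, 2]
t15 = t12 + t13 * t0                           # elementwise, broadcasting over dim 0
t20 = t15.permute(1, 0, 2).reshape(32768, 56)  # [32768, 28, 2] -> [32768, 56]
t21 = t20.sum(dim=0)                           # outputs correspond to T21, T20, T13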
Thanks for reporting @IvanYashchuk and thanks for the small repro @jacobhinkle. Let me look into it.
@jjsjann123 Could you update the original repro? It doesn't seem to work with the latest nvfuser.
Updated.
Moved Ivan's issue to #3374. |
Should work fine with #3387. Let me know if not. |