codegen error: index_map.find(alloc_dom[i]) != index_map.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp" #871

Closed
jjsjann123 opened this issue Sep 12, 2023 · 11 comments

@jjsjann123
Collaborator

jjsjann123 commented Sep 12, 2023

Repro Python script

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, -1, -1], contiguity=[None, True, True], dtype=DataType.Half, is_cpu=False)
    T6 = fd.ops.broadcast_in_dim(T0, shape=[1, 2, 5, 1], broadcast_dims=[0, 1, 2])
    T13 = fd.ops.broadcast_in_dim(T6, shape=[1, 2, 5, 1, 1], broadcast_dims=[0, 1, 2, 4])
    T14 = fd.ops.slice(T13, start_indices=[0, 0, 0, 0, 0], end_indices=[1, 2, 2, 1, 1], strides=[1, 1, 1, 1, 1])
    T15 = fd.ops.slice(T13, start_indices=[0, 0, 2, 0, 0], end_indices=[1, 2, 5, 1, 1], strides=[1, 1, 1, 1, 1])
    T23 = fd.ops.broadcast_in_dim(T14, shape=[1, 2, 2, 1, 1, 1], broadcast_dims=[0, 1, 2, 4, 5])
    T31 = fd.ops.broadcast_in_dim(T23, shape=[1, 2, 2, 2, 1, 1], broadcast_dims=[0, 1, 2, 3, 4, 5])
    T39 = fd.ops.broadcast_in_dim(T15, shape=[1, 2, 3, 1, 1, 1], broadcast_dims=[0, 1, 2, 4, 5])
    T47 = fd.ops.broadcast_in_dim(T39, shape=[1, 2, 3, 1, 1, 1], broadcast_dims=[0, 1, 2, 3, 4, 5])
    T48 = fd.ops.reshape(T31, new_shape=[1, 2, 4, 1, 1])
    T49 = fd.ops.set(T48)
    T50 = fd.ops.reshape(T47, new_shape=[1, 2, 3, 1, 1])
    T51 = fd.ops.set(T50)
    T52 = fd.ops.cat([T49, T51], dim=2)
    fd.add_output(T52)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((10,), dtype=torch.float16, device='cuda:0').as_strided((1, 2, 5), (10, 5, 1)),
]
fd.execute(inputs)
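
For reference, a rough eager-mode equivalent of what this fusion computes (sketch only; the intermediate names are made up and are not part of the original report):

# Rough eager-mode reference for the fusion above (sketch only).
x = inputs[0]                                    # [1, 2, 5], float16
t13 = x.unsqueeze(-1).unsqueeze(3)               # [1, 2, 5, 1, 1]
a = t13[:, :, 0:2]                               # slice -> [1, 2, 2, 1, 1]
b = t13[:, :, 2:5]                               # slice -> [1, 2, 3, 1, 1]
a = a.unsqueeze(3).expand(1, 2, 2, 2, 1, 1).reshape(1, 2, 4, 1, 1)
b = b.unsqueeze(3).reshape(1, 2, 3, 1, 1)
ref = torch.cat([a, b], dim=2)                   # [1, 2, 7, 1, 1]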

Error message

Exception in thread pool task: index_map.find(alloc_dom[i]) != index_map.end() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":2198, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 ) dim: 2 id: iS24{2}rf, loops:  iblockIdx.x154{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )} iUS155{1} iS153{1} ithreadIdx.x151{128}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:2198 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x8d (0x7f176ea3aa59 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::string const&) + 0x53 (0x7f176eb74913 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #2: nvfuser::Index::getNonGlobalConsumerStridedIndices(nvfuser::TensorView const*, std::vector<nvfuser::kir::ForLoop*, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_set<nvfuser::kir::ForLoop*, std::hash<nvfuser::kir::ForLoop*>, std::equal_to<nvfuser::kir::ForLoop*>, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_map<nvfuser::IterDomain*, nvfuser::Val*, std::hash<nvfuser::IterDomain*>, std::equal_to<nvfuser::IterDomain*>, std::allocator<std::pair<nvfuser::IterDomain* const, nvfuser::Val*> > > const&) + 0x11f7 (0x7f176ec110b7 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #3: nvfuser::Index::getConsumerStridedIndices(nvfuser::TensorView*, std::vector<nvfuser::kir::ForLoop*, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_set<nvfuser::kir::ForLoop*, std::hash<nvfuser::kir::ForLoop*>, std::equal_to<nvfuser::kir::ForLoop*>, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_map<int, nvfuser::Val*, std::hash<int>, std::equal_to<int>, std::allocator<std::pair<int const, nvfuser::Val*> > > const&, bool) + 0xf3 (0x7f176ec11413 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #4: nvfuser::Index::getConsumerIndex(nvfuser::TensorView*, std::vector<nvfuser::kir::ForLoop*, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_set<nvfuser::kir::ForLoop*, std::hash<nvfuser::kir::ForLoop*>, std::equal_to<nvfuser::kir::ForLoop*>, std::allocator<nvfuser::kir::ForLoop*> > const&, std::unordered_map<int, nvfuser::Val*, std::hash<int>, std::equal_to<int>, std::allocator<std::pair<int const, nvfuser::Val*> > > const&, bool) + 0x2e (0x7f176ec118de in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #5: nvfuser::IndexLowering::handle(nvfuser::LoadStoreOp const*) + 0x164 (0x7f176ed71f84 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #6: nvfuser::IndexLowering::handle(nvfuser::kir::IfThenElse const*) + 0xcf (0x7f176ed7589f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #7: nvfuser::IndexLowering::handle(nvfuser::kir::ForLoop const*) + 0xcf (0x7f176ed7511f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #8: nvfuser::IndexLowering::handle(nvfuser::kir::ForLoop const*) + 0xcf (0x7f176ed7511f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #9: nvfuser::IndexLowering::handle(nvfuser::kir::ForLoop const*) + 0xcf (0x7f176ed7511f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #10: nvfuser::IndexLowering::handle(nvfuser::kir::IfThenElse const*) + 0xcf (0x7f176ed7589f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #11: nvfuser::IndexLowering::handle(nvfuser::kir::ForLoop const*) + 0xcf (0x7f176ed7511f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #12: nvfuser::IndexLowering::generate(std::vector<nvfuser::Expr*, std::allocator<nvfuser::Expr*> > const&) + 0x2f (0x7f176ed6e8ef in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #13: nvfuser::GpuLower::lower(nvfuser::Fusion*) + 0x261c (0x7f176edcdefc in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #14: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0x66d (0x7f176eb97d1d in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #15: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams) + 0x41f (0x7f176eb89d7f in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #16: nvfuser::FusionKernelRuntime::compileKernel(nvfuser::KernelArgumentHolder const&, nvfuser::SegmentedGroup*) + 0x193 (0x7f176ecc3673 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #17: <unknown function> + 0x41e827 (0x7f176ecc3827 in /usr/local/lib/python3.10/site-packages/torch/lib/libnvfuser_codegen.so)
frame #18: c10::ThreadPool::main_loop(unsigned long) + 0x2b3 (0x7f17d1a781a3 in /usr/local/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #19: <unknown function> + 0xdc253 (0x7f17d12b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #20: <unknown function> + 0x94b43 (0x7f17fac14b43 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #21: <unknown function> + 0x126a00 (0x7f17faca6a00 in /usr/lib/x86_64-linux-gnu/libc.so.6)
@naoyam
Collaborator

naoyam commented Sep 13, 2023

Here's the failing fusion segment:

Inputs:
  T4_g[ iS254{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( (nvfuser_index_t)(5) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iS255{1}, iS253{1}, iS251{128} ], __half
Outputs:
  T16_g[ iblockIdx.x154{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS155{1}, iS153{1}, ithreadIdx.x151{128} ] ca_pos( 2 ) produce_pos( 4 ), __half

%kernel_math {
T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 )
   = slice( T4_g[ iS254{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( (nvfuser_index_t)(5) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iS255{1}, iS253{1}, iS251{128} ], { {0, 1, 1} {0, 2, 1} {0, 2, 1} {0, 1, 1} {0, 1, 1} } )
T7_l[ iblockIdx.x214{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS215{1}, iS213{1}, ithreadIdx.x211{128} ] ca_pos( 4 ) produce_pos( 4 )
   = broadcast( T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 ) )
T8_l[ iblockIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS205{1}, iS203{1}, ithreadIdx.x201{128} ] ca_pos( 4 ) produce_pos( 4 )
   = Set( T7_l[ iblockIdx.x214{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS215{1}, iS213{1}, ithreadIdx.x211{128} ] ca_pos( 4 ) produce_pos( 4 ) )
T9_l[ iblockIdx.x194{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS195{1}, iS193{1}, ithreadIdx.x191{128} ] ca_pos( 4 ) produce_pos( 4 )
   = Set( T8_l[ iblockIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS205{1}, iS203{1}, ithreadIdx.x201{128} ] ca_pos( 4 ) produce_pos( 4 ) )
i10 = (nvfuser_index_t)(2);
i116 = (nvfuser_index_t)(2);
T10_l[ iblockIdx.x184{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS185{1}, iS183{1}, ithreadIdx.x181{128} ] ca_pos( 4 ) produce_pos( 4 ) = expand( T9_l[ iblockIdx.x194{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS195{1}, iS193{1}, ithreadIdx.x191{128} ] ca_pos( 4 ) produce_pos( 4 ), {1, i10, 2, i116, 1, 1} )
T15_l[ iblockIdx.x174{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS175{1}, iS173{1}, ithreadIdx.x171{128} ] ca_pos( 4 ) produce_pos( 4 ) = view( T10_l[ iblockIdx.x184{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS185{1}, iS183{1}, ithreadIdx.x181{128} ] ca_pos( 4 ) produce_pos( 4 ) )
T23_l[ iblockIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS165{1}, iS163{1}, ithreadIdx.x161{128} ] ca_pos( 4 ) produce_pos( 4 )
   = Set( T15_l[ iblockIdx.x174{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS175{1}, iS173{1}, ithreadIdx.x171{128} ] ca_pos( 4 ) produce_pos( 4 ) )
T16_g[ iblockIdx.x154{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS155{1}, iS153{1}, ithreadIdx.x151{128} ] ca_pos( 2 ) produce_pos( 4 )
   = Set( T23_l[ iblockIdx.x164{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * ( (nvfuser_index_t)(2) ) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS165{1}, iS163{1}, ithreadIdx.x161{128} ] ca_pos( 4 ) produce_pos( 4 ) )
}

The rfactor domain of T5 looks like: T5_l[ bS21{1}, iS22{( (nvfuser_index_t)(2) )}rf, iS24{2}rf, bS25{1}, bS26{1} ]. When we try to generate the consumer index of T5, the error happens when indexing iS24, since these are the only domains that get indexed in index_map:

Map: bS216{( 1 * 1 )} -> 0
Map: bS26{1} -> 0
Map: bS25{1} -> 0
Map: iS22{( (nvfuser_index_t)(2) )}rf -> 0
Map: bS21{1} -> 0

No idea why there are so few mappings.

@jacobhinkle
Collaborator

The rfactor of T5 looks like: T5_l[ bS21{1}, iS22{( (nvfuser_index_t)(2) )}rf, iS24{2}rf, bS25{1}, bS26{1} ].

The kernel looks like this:

FOR threadIdx.x in ithreadIdx.x151{128}:
  T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 ) = ALLOCATE(buffer=T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 ), mem_type=register, size=1, zero_init=false)
  IF Manual true:
    T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 )
       = Set.Permute( 0 )
  IF Manual true:
    T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 )
       = slice( T4_g[ iS254{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( (nvfuser_index_t)(5) ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iS255{1}, iS253{1}, iS251{128} ], { {0, 1, 1} {0, 2, 1} {0, 2, 1} {0, 1, 1} {0, 1, 1} } )
  T7_l[ iblockIdx.x214{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS215{1}, iS213{1}, ithreadIdx.x211{128} ] ca_pos( 4 ) produce_pos( 4 ) = ALLOCATE(buffer=T7_l[ iblockIdx.x214{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS215{1}, iS213{1}, ithreadIdx.x211{128} ] ca_pos( 4 ) produce_pos( 4 ), mem_type=register, size=1, zero_init=false)
  IF Manual true:
    T7_l[ iblockIdx.x214{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS215{1}, iS213{1}, ithreadIdx.x211{128} ] ca_pos( 4 ) produce_pos( 4 )
       = broadcast( T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 ) )
  T8_l[ iblockIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS205{1}, iS203{1}, ithreadIdx.x201{128} ] ca_pos( 4 ) produce_pos( 4 ) = ALLOCATE(buffer=T8_l[ iblockIdx.x204{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( ( 2 * 1 ) * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS205{1}, iS203{1}, ithreadIdx.x201{128} ] ca_pos( 4 ) produce_pos( 4 ), mem_type=register, size=1, zero_init=false)
...

The error happens when handling the Set.Permute( 0 ) expr.

@jacobhinkle
Collaborator

jacobhinkle commented Sep 13, 2023

T5_l[ iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128} ] ca_pos( 4 )
 root domain : (bS21{1}, iS22{( (nvfuser_index_t)(2) )}rf, iS23{( (nvfuser_index_t)(5) )}rf, bS25{1}, bS26{1})
  Resize: iS23{( (nvfuser_index_t)(5) )}rf by 0 and ( 2 - ( (nvfuser_index_t)(5) ) ) -> iS24{2}rf
 rfactor domain : (bS21{1}, iS22{( (nvfuser_index_t)(2) )}rf, iS24{2}rf, bS25{1}, bS26{1})
 contiguity: n t t n n
  Merge: bS25{1} and bS26{1} -> bS216{( 1 * 1 )}
  Merge: iS24{2}rf and bS216{( 1 * 1 )} -> iS217{( 2 * ( 1 * 1 ) )}
  Merge: iS22{( (nvfuser_index_t)(2) )}rf and iS217{( 2 * ( 1 * 1 ) )} -> iS218{( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) )}
  Merge: bS21{1} and iS218{( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) )} -> iS219{( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) )}
  Split: iS219{( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) )} by factor 128 -> iS220{( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) )}, ithreadIdx.x221{128}, start offset: 0, stop offset: 0
  Split: iS220{( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) )} by factor 1 -> iS222{( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) )}, iS223{1}, start offset: 0, stop offset: 0
  Split: iS222{( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) )} by factor 1 -> iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, start offset: 0, stop offset: 0
 leaf domain : (iblockIdx.x224{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( ( (nvfuser_index_t)(2) ) * ( 2 * ( 1 * 1 ) ) ) ), 128) ), 1) ), 1) )}, iUS225{1}, iS223{1}, ithreadIdx.x221{128})

@IvanYashchuk
Collaborator

IvanYashchuk commented Nov 7, 2024

I'm sorry for the long reproducer. I'm seeing a "Couldn't find allocation mapping for" error. Is this the same problem? I encountered it while trying to run HF's Qwen 2 model with Thunder (Lightning-AI/lightning-thunder#1406). It's important that we support this model soon. @jacobhinkle or @naoyam, do you have an estimate of how long it could take to fix this bug? Is there any change we could introduce to the fusion definition as a workaround?

RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp":358, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 11:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/index_compute.cpp":1990, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't find allocation mapping for T125_l___bfloat[ iblockIdx.x846{( ceilDiv(2, blockDim.x) )}, ithreadIdx.x847{blockDim.x}, iS855{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(32768, blockDim.y) ), 16) ), 1) ), gridDim.y) )}, iblockIdx.y854{gridDim.y}, ithreadIdx.y849{blockDim.y}, iUS853{1}, iUR851{16}, bS505{1} ] ca_pos( 6 ) dim: 2 id: iS507{2}
Exception raised from getNonGlobalConsumerStridedIndices at /opt/pytorch/nvfuser/csrc/index_compute.cpp:1990 (most recent call first):
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
#  1: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git8b08559
# cuda version: 12.3
# nvfuser version: 0.2.17+gitdf32dce
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id13(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T1 = fd.define_tensor(shape=[1, 32768, 2], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.define_tensor(shape=[1, 4, 32768, 2], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[1, 28, 32768, 2], contiguity=[None, True, False, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[1, 32768, 2], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T6 = fd.define_tensor(shape=[1, 4, 32768, 2], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T13 = fd.ops.reshape(T0, new_shape=[1, 4, 7, 32768, 2])
    T14 = fd.ops.cast(T13, dtype=DataType.Float)
    T15 = fd.ops.sum(T14, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T16 = fd.ops.cast(T15, dtype=DataType.BFloat16)
    T23 = fd.ops.broadcast_in_dim(T16, shape=[1, 4, 1, 32768, 2], broadcast_dims=[1, 3, 4])
    T24 = fd.ops.cast(T23, dtype=DataType.Float)
    T25 = fd.ops.sum(T24, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T26 = fd.ops.cast(T25, dtype=DataType.BFloat16)
    T32 = fd.ops.broadcast_in_dim(T1, shape=[1, 1, 32768, 2], broadcast_dims=[0, 2, 3])
    T38 = fd.ops.broadcast_in_dim(T26, shape=[1, 4, 32768, 2], broadcast_dims=[1, 2, 3])
    T44 = fd.ops.broadcast_in_dim(T32, shape=[1, 28, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    T45 = fd.ops.cast(T38, dtype=DataType.Float)
    T46 = fd.ops.cast(T2, dtype=DataType.Float)
    T52 = fd.ops.broadcast_in_dim(T32, shape=[1, 4, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    T53 = fd.ops.cast(T3, dtype=DataType.Float)
    T54 = fd.ops.cast(T44, dtype=DataType.Float)
    T55 = fd.ops.add(T46, T45)
    T56 = fd.ops.cast(T52, dtype=DataType.Float)
    T63 = fd.ops.reshape(T4, new_shape=[1, 4, 7, 32768, 2])
    T64 = fd.ops.mul(T54, T53)
    T65 = fd.ops.mul(T56, T55)
    T66 = fd.ops.cast(T63, dtype=DataType.Float)
    T67 = fd.ops.cast(T64, dtype=DataType.BFloat16)
    T68 = fd.ops.cast(T65, dtype=DataType.BFloat16)
    T69 = fd.ops.sum(T66, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T85 = fd.ops.slice(T67, start_indices=[0, 0, 0, 0], end_indices=[1, 28, 32768, 1], strides=[1, 1, 1, 1])
    T101 = fd.ops.slice(T68, start_indices=[0, 0, 0, 0], end_indices=[1, 4, 32768, 1], strides=[1, 1, 1, 1])
    T102 = fd.ops.cast(T69, dtype=DataType.BFloat16)
    T103 = fd.ops.cast(T85, dtype=DataType.Float)
    T104 = fd.ops.cast(T101, dtype=DataType.Float)
    T111 = fd.ops.broadcast_in_dim(T102, shape=[1, 4, 1, 32768, 2], broadcast_dims=[1, 3, 4])
    T112 = fd.ops.neg(T103)
    T113 = fd.ops.neg(T104)
    T114 = fd.ops.cast(T111, dtype=DataType.Float)
    T120 = fd.ops.broadcast_in_dim(T5, shape=[1, 1, 32768, 2], broadcast_dims=[0, 2, 3])
    T136 = fd.ops.slice(T67, start_indices=[0, 0, 0, 1], end_indices=[1, 28, 32768, 2], strides=[1, 1, 1, 1])
    T137 = fd.ops.cast(T112, dtype=DataType.BFloat16)
    T153 = fd.ops.slice(T68, start_indices=[0, 0, 0, 1], end_indices=[1, 4, 32768, 2], strides=[1, 1, 1, 1])
    T154 = fd.ops.cast(T113, dtype=DataType.BFloat16)
    T155 = fd.ops.sum(T114, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T161 = fd.ops.broadcast_in_dim(T120, shape=[1, 28, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    S162 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T163 = fd.ops.pad(T136, [0, 1, 0, 0, 0, 0, 0, 0], S162)
    S164 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T165 = fd.ops.pad(T137, [1, 0, 0, 0, 0, 0, 0, 0], S164)
    T171 = fd.ops.broadcast_in_dim(T120, shape=[1, 4, 32768, 2], broadcast_dims=[0, 1, 2, 3])
    S172 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T173 = fd.ops.pad(T153, [0, 1, 0, 0, 0, 0, 0, 0], S172)
    S174 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T175 = fd.ops.pad(T154, [1, 0, 0, 0, 0, 0, 0, 0], S174)
    T176 = fd.ops.cast(T155, dtype=DataType.BFloat16)
    T177 = fd.ops.cast(T161, dtype=DataType.Float)
    T178 = fd.ops.cast(T163, dtype=DataType.Float)
    T179 = fd.ops.cast(T165, dtype=DataType.Float)
    T180 = fd.ops.cast(T171, dtype=DataType.Float)
    T181 = fd.ops.cast(T173, dtype=DataType.Float)
    T182 = fd.ops.cast(T175, dtype=DataType.Float)
    T188 = fd.ops.broadcast_in_dim(T176, shape=[1, 4, 32768, 2], broadcast_dims=[1, 2, 3])
    T189 = fd.ops.mul(T177, T53)
    T190 = fd.ops.add(T179, T178)
    T191 = fd.ops.mul(T180, T55)
    T192 = fd.ops.add(T182, T181)
    T193 = fd.ops.cast(T188, dtype=DataType.Float)
    T194 = fd.ops.cast(T6, dtype=DataType.Float)
    T195 = fd.ops.add(T190, T189)
    T196 = fd.ops.add(T192, T191)
    T197 = fd.ops.add(T194, T193)
    T198 = fd.ops.cast(T195, dtype=DataType.BFloat16)
    T199 = fd.ops.cast(T196, dtype=DataType.BFloat16)
    T200 = fd.ops.cast(T197, dtype=DataType.BFloat16)
    T201 = fd.ops.permute(T198, dims=[0, 2, 1, 3])
    T202 = fd.ops.permute(T199, dims=[0, 2, 1, 3])
    T203 = fd.ops.permute(T200, dims=[0, 2, 1, 3])
    T208 = fd.ops.reshape(T201, new_shape=[1, 32768, 56])
    T213 = fd.ops.reshape(T202, new_shape=[1, 32768, 8])
    T218 = fd.ops.reshape(T203, new_shape=[1, 32768, 8])
    T219 = fd.ops.cast(T208, dtype=DataType.Float)
    T220 = fd.ops.cast(T213, dtype=DataType.Float)
    T221 = fd.ops.cast(T218, dtype=DataType.Float)
    T222 = fd.ops.sum(T219, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T226 = fd.ops.reshape(T208, new_shape=[32768, 56])
    T227 = fd.ops.sum(T220, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T231 = fd.ops.reshape(T213, new_shape=[32768, 8])
    T232 = fd.ops.sum(T221, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T236 = fd.ops.reshape(T218, new_shape=[32768, 8])
    T237 = fd.ops.cast(T222, dtype=DataType.BFloat16)
    T238 = fd.ops.permute(T226, dims=[1, 0])
    T239 = fd.ops.cast(T227, dtype=DataType.BFloat16)
    T240 = fd.ops.permute(T231, dims=[1, 0])
    T241 = fd.ops.cast(T232, dtype=DataType.BFloat16)
    T242 = fd.ops.permute(T236, dims=[1, 0])
    fd.add_output(T236)
    fd.add_output(T242)
    fd.add_output(T241)
    fd.add_output(T231)
    fd.add_output(T240)
    fd.add_output(T239)
    fd.add_output(T226)
    fd.add_output(T238)
    fd.add_output(T237)

with FusionDefinition() as fd:
    nvfuser_fusion_id13(fd)

inputs = [
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.testing.make_tensor((1, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 4, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.randn(7340026, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 28, 32768, 2), (7340032, 262144, 8, 1)),
    torch.testing.make_tensor((1, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 4, 32768, 2), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)

@jacobhinkle
Collaborator

jacobhinkle commented Nov 7, 2024

Here is a slightly smaller repro containing only the failing segment, simplified as much as possible. It is scheduled by the Reduction scheduler:

# CUDA devices:
#  0: NVIDIA H100 80GB HBM3
# torch version: 2.6.0a0+gitffb7a08
# cuda version: 12.6
# nvfuser version: 0.2.22+git6912435
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[28, 32768, 2], contiguity=[True, False, True], dtype=DataType.Float, is_cpu=False)
    T1 = fd.define_tensor(shape=[32768, 2], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
    T2 = fd.define_tensor(shape=[28, 32768, 1], contiguity=[True, False, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T3 = fd.define_tensor(shape=[28, 32768, 1], contiguity=[True, False, None], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T7 = fd.ops.pad(T2, [1, 0], None)
    T11 = fd.ops.pad(T3, [0, 1], None)
    T12 = fd.ops.add(T7, T11)
    T13 = fd.ops.broadcast(T1, is_broadcast_dim=[True, False, False])
    T14 = fd.ops.mul(T13, T0)
    T15 = fd.ops.add(T12, T14)
    T16 = fd.ops.permute(T15, dims=[1, 0, 2])
    T20 = fd.ops.reshape(T16, new_shape=[32768, 56])
    T21 = fd.ops.sum(T20, dims=[0], keepdim=False, dtype=DataType.Float)
    fd.add_output(T21)
    fd.add_output(T20)
    fd.add_output(T13)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((28, 32768, 2), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)

@naoyam
Collaborator

naoyam commented Nov 7, 2024

Thanks for reporting, @IvanYashchuk, and thanks for the small repro, @jacobhinkle. Let me look into it.

@naoyam
Collaborator

naoyam commented Nov 7, 2024

@jjsjann123 Could you update the original repro? It doesn't seem to work with the latest nvfuser.

Traceback (most recent call last):
  File "/home/nmaruyama/nvfuser/debug3/../repro871_original.py", line 22, in <module>
    nvfuser_fusion_id0(fd)
  File "/home/nmaruyama/nvfuser/debug3/../repro871_original.py", line 14, in nvfuser_fusion_id0
    T48 = fd.ops.reshape(T31, original_shape=[1, 2, 2, 2, 1, 1], new_shape=[1, 2, 4, 1, 1])
TypeError: reshape(): incompatible function arguments. The following argument types are supported:
    1. (self: nvfuser._C._FusionDefinition.Operators, arg: nvfuser._C.Tensor, new_shape: nvfuser._C.Vector) -> nvfuser._C.Tensor
    2. (self: nvfuser._C._FusionDefinition.Operators, arg: nvfuser._C.Tensor, new_shape: list) -> nvfuser._C.Tensor
    3. (self: nvfuser._C._FusionDefinition.Operators, arg: nvfuser._C.Tensor, new_shape: tuple) -> nvfuser._C.Tensor

Invoked with: <nvfuser._C._FusionDefinition.Operators object at 0x7f00f8d25330>, Tensor(index=67, ndim=6); kwargs: original_shape=[1, 2, 2, 2, 1, 1], new_shape=[1, 2, 4, 1, 1]

@jjsjann123
Collaborator Author

Updated.
FYI, it's just an API change: reshape no longer takes original_shape.
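
For example, a call written against the old API (as in the traceback above) just drops the original_shape argument:

# Old form (no longer accepted):
# T48 = fd.ops.reshape(T31, original_shape=[1, 2, 2, 2, 1, 1], new_shape=[1, 2, 4, 1, 1])
# Updated form:
T48 = fd.ops.reshape(T31, new_shape=[1, 2, 4, 1, 1])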

@naoyam
Collaborator

naoyam commented Nov 7, 2024

I suspect the original repro has the same root cause as #3299. Ivan's new repro is probably due to a different issue, for which I'm trying a WAR (#3373).

@naoyam
Collaborator

naoyam commented Nov 8, 2024

Moved Ivan's issue to #3374.

naoyam added a commit that referenced this issue Nov 14, 2024
Fixes #3299 and #871 

The legacy indexer fails when an expanded iter domain is involved in
reshape transformations.
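
Distilled from the original repro above, the pattern in question (an expanded broadcast iter domain feeding a reshape) looks roughly like this (a sketch, not a standalone reproducer):

# Sketch of the failing pattern, taken from the original repro (not standalone):
T23 = fd.ops.broadcast_in_dim(T14, shape=[1, 2, 2, 1, 1, 1], broadcast_dims=[0, 1, 2, 4, 5])    # insert a broadcast dim at position 3
T31 = fd.ops.broadcast_in_dim(T23, shape=[1, 2, 2, 2, 1, 1], broadcast_dims=[0, 1, 2, 3, 4, 5]) # expand that dim from 1 to 2
T48 = fd.ops.reshape(T31, new_shape=[1, 2, 4, 1, 1])                                            # reshape merges the expanded dim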
@naoyam
Collaborator

naoyam commented Nov 14, 2024

Should work fine with #3387. Let me know if not.

@naoyam naoyam closed this as completed Nov 14, 2024