
Inlining error in Hopper matmul with AxisMapping and grid swizzling #3671

Open
jacobhinkle opened this issue Jan 6, 2025 · 6 comments

@jacobhinkle
Collaborator

The inlining logic for MmaOp with AxisMapping checks that unmapped dimensions are Broadcast. We expect to have something like this:

```
t0 {
  logical: [ iS0{M} iS1{K} ]
  loop: [ iS2{ ceilDiv(M, 256) }  bS5{1}  iS3{256} bS6{256} ... ]
    Split iS0 by 256 -> iS2, iS3
    Split bS4 by 256 -> bS5, bS6
      ...
  additional ids: bS4{1}
}
t1 {
  logical: [ iS10{N} iS11{K} ]
  loop: [ bS13{1} iS15{ ceilDiv(N, 256) } bS14{256} iS16{256} ... ]
    Split bS12 by 256 -> bS13, bS14
    Split iS10 by 256 -> iS15, iS16
      ...
  additional ids: bS12{1}
}
```

In this case, we are able to inline the mma operation that consumes these two tensors, but the inlining logic first checks that the unmapped IDs (bS5, bS6, bS13, and bS14) are Broadcast and that the operation is an MmaOp.
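To make the rule concrete, here is a toy Python model of the check as described above. `IterDomain` and `can_inline_into_mma` are hypothetical stand-ins written for this illustration, not nvFuser's C++ classes or functions; they only capture the "unmapped IDs must be Broadcast and the consumer must be an MmaOp" condition.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical stand-in for an nvFuser IterDomain."""
    name: str
    extent: int
    is_broadcast: bool


def can_inline_into_mma(unmapped_ids, consumer_is_mma):
    """Toy version of the AxisMapping inlining check: inlining is allowed only
    if the consumer is an MmaOp and every operand loop ID without a
    counterpart in the MmaOp output is Broadcast."""
    return consumer_is_mma and all(i.is_broadcast for i in unmapped_ids)


# The unmapped loop IDs of t0 and t1 in the example above.
unmapped = [
    IterDomain("bS5", 1, True),
    IterDomain("bS6", 256, True),
    IterDomain("bS13", 1, True),
    IterDomain("bS14", 256, True),
]
print(can_inline_into_mma(unmapped, consumer_is_mma=True))   # True
print(can_inline_into_mma(unmapped, consumer_is_mma=False))  # False
```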

With grid swizzling by a factor of 4, we do some further scheduling here. For example, we will have:

```
t0 {
  logical: [ iS0{M} iS1{K} ]
  loop: [ iS20{ ceilDiv( ceilDiv(M, 256), 4) } iS22{4} iS3{256} bS6{256} ... ]
    Split iS0 by 256 -> iS2, iS3
    Split bS4 by 256 -> bS5, bS6
    Split iS2 by 4 -> iS20, iS21{4}
    Merge bS5 with iS21{4} -> iS22{4}
      ...
  additional ids: bS4{1}
}
```

The swizzle mixes the two outermost dimensions: what used to be a simple split of a loop broadcast (bS5) is now an Iteration ID, iS22{4}, produced by the merge.
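A toy Python model of these transforms shows why the check now fails. It assumes, as the t0 example suggests, that a split keeps its input's Broadcast/Iteration type and that a merge yields a Broadcast ID only when both inputs are Broadcast. The class and helper functions are hypothetical, and M = 4096 is an arbitrary example size.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical stand-in for an nvFuser IterDomain."""
    name: str
    extent: int
    is_broadcast: bool


def ceil_div(a, b):
    return -(-a // b)


def split(i, factor, outer_name, inner_name):
    # Both outputs keep the input's Broadcast/Iteration type.
    return (IterDomain(outer_name, ceil_div(i.extent, factor), i.is_broadcast),
            IterDomain(inner_name, factor, i.is_broadcast))


def merge(outer, inner, name):
    # The result is Broadcast only if both inputs are Broadcast.
    return IterDomain(name, outer.extent * inner.extent,
                      outer.is_broadcast and inner.is_broadcast)


M = 4096  # arbitrary example size
iS0 = IterDomain("iS0", M, False)
bS4 = IterDomain("bS4", 1, True)             # additional (loop-only) broadcast ID

iS2, iS3 = split(iS0, 256, "iS2", "iS3")     # iS2{16}, iS3{256}
bS5, bS6 = split(bS4, 256, "bS5", "bS6")     # bS5{1}, bS6{256}
iS20, iS21 = split(iS2, 4, "iS20", "iS21")   # grid swizzle: iS20{4}, iS21{4}
iS22 = merge(bS5, iS21, "iS22")              # iS22{4}

print(bS5.is_broadcast, iS22.is_broadcast)   # True False
```

Under these assumptions iS22 is an Iteration ID of extent 4, so it no longer passes a check that requires every unmapped ID to be Broadcast.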

I am not sure yet how to address this. I don't think we can just inline here without other changes, since when I disable this check I get errors in expression sorting.

@jacobhinkle jacobhinkle self-assigned this Jan 6, 2025
jacobhinkle added a commit that referenced this issue Jan 8, 2025
This updates the default (non-plugin) matmul heuristic to support Hopper
matmuls. This change means that we can now run matmuls on Hopper
similarly to how we do it on Ampere and Turing, including using the
Python interface.

I tried to make the default heuristic somewhat thoughtful and not just a
placeholder. Here are some notes about the Hopper heuristic in its
current form:
- I set the macro to Hopper_64_64_16. I intended to always use the
largest macro for which the N size divided the problem's N, but this led
to lower perf on the handful of examples I looked at. We should
benchmark more and find out why this is once we have warp specialization
and register stealing fully plumbed in, but for the time being I simply
left it at N=64.
- Once the instruction tile is set, we set the warp tile equal to the
instruction tile (we can revisit this in the future). Then, to find the
CTA tile, we double the instruction tile in the M or N dimension until we
run out of registers.
- We start with 8 circular buffering stages and decrease until the
circular buffers fit into smem.
- We use `use_smem_epilogue` when possible. Whenever that is possible we
_always_ use `promote_prologue_smem_reuse` even if it's not needed. This
is to try and avoid bugs like #3602.
- I set the tile rasterization order so that the fast axis is the axis
with the fewest tiles, which should encourage more L2 hits unless there
are tons of tiles in each dimension.
- I cannot yet set grid swizzling due to #3671, but I placed a TODO
comment and some code to do the proper swizzling.

---------

Co-authored-by: Ryan Spring <rspring@nvidia.com>
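The tile, stage, and rasterization steps in this commit message can be sketched in Python as below. This is a hedged illustration, not the code from the commit. The register budget is modeled as a cap on fp32 accumulator elements per CTA (`max_acc_elems`). Which dimension to double is an assumption (the smaller one), since the text only says "M or N". The smem epilogue buffer is ignored. The CTA K size, element size, and shared memory limit in the usage lines are example values.

```python
def ceil_div(a, b):
    return -(-a // b)


def pick_cta_tile(inst_m, inst_n, max_acc_elems):
    """Start from the instruction tile (warp tile == instruction tile) and
    double M or N while the fp32 accumulator still fits the budget."""
    cta_m, cta_n = inst_m, inst_n
    while True:
        # Assumption: grow the smaller dimension to keep the tile square-ish.
        cand_m, cand_n = (2 * cta_m, cta_n) if cta_m <= cta_n else (cta_m, 2 * cta_n)
        if cand_m * cand_n > max_acc_elems:
            return cta_m, cta_n
        cta_m, cta_n = cand_m, cand_n


def pick_stages(cta_m, cta_n, cta_k, elem_bytes, smem_bytes, max_stages=8):
    """Start at 8 circular-buffering stages and decrease until the A and B
    buffers fit in shared memory; None if even one stage does not fit."""
    stage_bytes = (cta_m + cta_n) * cta_k * elem_bytes
    for stages in range(max_stages, 0, -1):
        if stages * stage_bytes <= smem_bytes:
            return stages
    return None


def fast_axis(m, n, cta_m, cta_n):
    """The axis with fewer tiles becomes the fast-varying rasterization axis."""
    return "M" if ceil_div(m, cta_m) <= ceil_div(n, cta_n) else "N"


# Hopper_64_64_16 instruction tile; budget and sizes are example values.
cta_m, cta_n = pick_cta_tile(64, 64, max_acc_elems=128 * 256)
print(cta_m, cta_n)                                  # 256 128
print(pick_stages(cta_m, cta_n, 64, 2, 227 * 1024))  # 4
print(fast_axis(4096, 8192, cta_m, cta_n))           # M
```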
@jacobhinkle
Collaborator Author

This error also appears whenever we try to do persistent kernel scheduling on a fusion that contains a translated MatmulOp or LinearOp node. In that case, commenting out the isBroadcast check leads to an error in SyncMap:

INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":798, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV3 (T3_s___bfloat[iS27{( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ), 132) )}, iblockIdx.x28{132}, iS20{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, bS23{256}, iS29{1}, iB31{16}, iB37{2}, iB34{4}, iB38{2}, iB36{8}] ca_pos( 3 )) and TV6(T6_l_float[iS65{( ceilDiv(( ( ceilDiv(( (( (( getMetaData(T1) )).logical_size ))[0] ), 256) ) * ( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) ) ), 132) )}, iblockIdx.x66{132}, rS58{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, ithreadIdx.y77{2}, iS69{1}, iS73{1}, rS75{1}, iMMA70{64}, iMMA74{256}, rMMA76{16}] ca_pos( 2 ) produce_pos( 3 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.x threadIdx.y)

I think the problem is in the AxisMapping approach itself; it feels brittle to me. We have already had to carve out special cases for MmaOp in inlining and predicate elimination, and now we see that this is not enough, since the missing axes are sometimes merged with Iteration domains, as the swizzle example above shows. Ideally this case would behave just like the fusedMultiplySum case, in which the fusion inputs contain Broadcast axes. There, when we merge, e.g., a Broadcast N axis with an Iteration M axis, we get an Iteration ID that is permissive-mapped both to the corresponding merged Iteration N/Broadcast M ID in the other operand and to the corresponding ID in the result.

I am now wondering whether #3372 was a mistake and whether we should revisit something like #3366 instead. What do you think, @naoyam?

@naoyam
Collaborator

naoyam commented Feb 3, 2025

So the issue is that we can't inline iS22. What are we trying to inline it into?

@jacobhinkle
Collaborator Author

Here is a diagram of grid swizzling for matmul:

[Image: diagram of grid swizzling for matmul]

Here we need to inline iS22, iS13, and iS38, i.e., the outermost dimension in the scheduled MmaOp. The broadcast dimensions bS4 and bS2 (red) are `additional_ids_` in this case. If they were logical IDs, the permissive map would map all of these IDs to one another.
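The index arithmetic behind this kind of swizzle can be sketched in Python. The function below inverts the two swizzle transforms from the t0 example in the issue description: split the M-tile ID by the swizzle factor, then merge the N-tile ID (outer) with the inner piece of that split. It assumes merge(outer, inner) produces the index outer * extent(inner) + inner; the function names are hypothetical, and the exact grid layout nvFuser generates may differ.

```python
def ceil_div(a, b):
    return -(-a // b)


def swizzled_grid_shape(m_tiles, n_tiles, factor):
    # Outer ID: ceilDiv(m_tiles, factor); merged ID: n_tiles * factor.
    return ceil_div(m_tiles, factor), n_tiles * factor


def swizzled_to_tile(outer, merged, factor):
    """Map a swizzled (outer, merged) loop index back to (m_tile, n_tile)."""
    n_tile, m_inner = divmod(merged, factor)  # undo merge(n_tile, m_inner)
    m_tile = outer * factor + m_inner         # undo split(m_tile, factor)
    return m_tile, n_tile


m_tiles, n_tiles, factor = 8, 3, 4
rows, cols = swizzled_grid_shape(m_tiles, n_tiles, factor)
order = [swizzled_to_tile(r, c, factor) for r in range(rows) for c in range(cols)]
print(order[:5])  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1)]
```

Consecutive merged indices visit `factor` M tiles that share one N tile before moving to the next N tile, so nearby CTAs share operand tiles. When `m_tiles` is not a multiple of `factor`, some (outer, merged) pairs yield `m_tile >= m_tiles` and would have to be predicated out.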

Inlining is one thing: we can essentially repeat the permissive-map logic to build an inlining check that allows this. However, when I currently bypass that check, I hit an error in SyncMap (inconsistent parallelization). That error does not occur when the IDs are permissive-mapped. It should be possible to create another carve-out for MmaOp there, just as we did for inlining and predicate elimination.

Another option (besides #3366) would be to make the additional IDs broadcast-mapped to the corresponding consumer dimension, which would make all of these IDs permissive-mapped. For example, we could add an option to `TensorView::broadcast` such as `int64_t map_to_consumer_root_pos`. Inside the pairwise logical domain map, we would check for that flag, verify that the corresponding position in the consumer is not already mapped to any producer ID, and broadcast-map it to the new additional ID.
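The proposed check can be sketched as a toy Python function. `map_to_consumer_root_pos` is only a proposal in this comment, not an existing option. `pairwise_map` and the consumer ID names used below (iS30, iS31, rS32 for the MmaOp output's M, N, and reduction K roots) are made up for illustration.

```python
def pairwise_map(logical_map, consumer_root, additional_ids):
    """logical_map: existing producer->consumer ID mapping.
    additional_ids: {additional_id: consumer_root_pos}, i.e. the position a
    hypothetical map_to_consumer_root_pos option would record."""
    mapping = dict(logical_map)
    mapped_targets = set(mapping.values())
    for add_id, pos in additional_ids.items():
        target = consumer_root[pos]
        # Refuse to map if some producer ID already maps to this position.
        if target in mapped_targets:
            raise ValueError(f"{target} is already mapped to a producer ID")
        mapping[add_id] = target  # broadcast-map the additional ID
        mapped_targets.add(target)
    return mapping


consumer_root = ["iS30", "iS31", "rS32"]  # hypothetical MmaOp output: [M, N, rK]
# t0 = [iS0{M}, iS1{K}] with additional broadcast bS4 standing in for N.
t0_map = pairwise_map({"iS0": "iS30", "iS1": "rS32"}, consumer_root, {"bS4": 1})
# t1 = [iS10{N}, iS11{K}] with additional broadcast bS12 standing in for M.
t1_map = pairwise_map({"iS10": "iS31", "iS11": "rS32"}, consumer_root, {"bS12": 0})
print(t0_map["bS4"], t1_map["bS12"])  # iS31 iS30
```

In this toy model, bS4 and bS12 end up mapped to the output's N and M IDs, which is the connection the permissive map would need in order to relate them to the other operand's Iteration IDs.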

@naoyam
Collaborator

naoyam commented Feb 4, 2025

Thanks. I think I understand what's happening and what the problem is.

I wonder if we could just get rid of the offending BroadcastOp during lowering. #3366 may work. I'd also try scheduleLoopDomainsLike, but it seems the easiest and quickest way is to keep the BroadcastOps there in the Fusion IR and just remove them once they are no longer necessary. Would that work?

@jacobhinkle
Collaborator Author

> I wonder if we could just get rid of the offending BroadcastOp during lowering. #3366 may work. I'd also try scheduleLoopDomainsLike, but it seems the easiest and quickest way is to keep the BroadcastOps there in the Fusion IR and just remove them once they are no longer necessary. Would that work?

If we left the BroadcastOps in and scheduled them similarly to how we do on Ampere, then we'd have something like this:

```
tv0_g[ i0, i2 ]
tv1_g[ i1, i2 ]
tv2_s[ i0, i2 ] = set(tv0_g)
tv3_s[ i1, i2 ] = set(tv1_g)
tv4_l[ i0, 1, i2 ] = broadcast(tv2_s)
tv5_l[ 1, i1, i2 ] = broadcast(tv3_s)
tv6_l[ i0, i1, r{i2} ] = mma(tv4_l, tv5_l)
```

If I understand correctly, you're saying we could potentially remove those broadcasts during lowering (say, in the indexing pass) and replace tv4_l and tv5_l with tv2_s and tv3_s in the mma instruction there. Yes, that might work! I will give it a shot.
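A toy version of that rewrite, over a made-up list-of-expressions IR rather than nvFuser's kernel IR classes, might look like the following. It drops each BroadcastOp whose only users are MmaOps and rewires those MmaOps to read the broadcast's input directly; everything else a real lowering pass would have to handle is left out.

```python
from dataclasses import dataclass, field


@dataclass
class Expr:
    """Hypothetical toy IR node: out = op(*inputs)."""
    op: str
    out: str
    inputs: list = field(default_factory=list)


def remove_mma_operand_broadcasts(exprs):
    # Collect the users of every value.
    users = {}
    for e in exprs:
        for i in e.inputs:
            users.setdefault(i, []).append(e)
    # Map each removable broadcast output to the broadcast's input.
    replace = {}
    for e in exprs:
        out_users = users.get(e.out, [])
        if e.op == "broadcast" and out_users and all(u.op == "mma" for u in out_users):
            replace[e.out] = e.inputs[0]
    # Drop the broadcasts and rewire their users; the input list is not mutated.
    return [Expr(e.op, e.out, [replace.get(i, i) for i in e.inputs])
            for e in exprs if e.out not in replace]


fusion = [
    Expr("set", "tv2_s", ["tv0_g"]),
    Expr("set", "tv3_s", ["tv1_g"]),
    Expr("broadcast", "tv4_l", ["tv2_s"]),
    Expr("broadcast", "tv5_l", ["tv3_s"]),
    Expr("mma", "tv6_l", ["tv4_l", "tv5_l"]),
]
for e in remove_mma_operand_broadcasts(fusion):
    print(e.out, "=", e.op, e.inputs)
# tv2_s = set ['tv0_g']
# tv3_s = set ['tv1_g']
# tv6_l = mma ['tv2_s', 'tv3_s']
```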

@naoyam
Collaborator

naoyam commented Feb 4, 2025

Some of the localized KIR changes are also done after indexing, e.g., https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/lower2device.cpp#L283

Removing the ops in KIR may feel ad hoc, but I suspect #3366 would be a significant change.
