Improvements for: Groupwise scaling along M for FP8 gemm #2095
base: main
Conversation
@LucasWilkinson, we upstreamed our changes to the groupwise scaling kernels, so there are some conflicts in this PR that need to be resolved. Our changes are mainly:
Apologies for the delay; the PR has been updated. Currently I am still vectorizing the loads of the B scales along N.
@@ -280,8 +278,11 @@ struct CollectiveMma<
  implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
  constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
  implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
  /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */
  implementable = implementable && (args.mma_promotion_interval % 4 == 0);
Are there any problems when transposing A and transposing B?
Currently this assumes full tiles in N and K, so if this is used for inference, where activations may have partial tiles, and you transpose the problem to Y^T = WX^T, it may report "not implementable". I think I'm going to update this, since ideally in vLLM we'd like to transpose so we can use smaller tensor core instructions; we do lose vectorization on the loads then, though.
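To make the "full tiles in N and K" point concrete, here is a standalone sketch (illustrative only; the tile extents and helper name are assumptions, not the PR's actual check) of why a transposed problem whose new N is a ragged batch dimension would be rejected:

```cpp
#include <cstdio>

// Assumed CTA tile extents, chosen only for illustration.
constexpr int TileN = 128;
constexpr int TileK = 128;

// Hypothetical stand-in for the full-tile part of a can_implement() check.
bool full_tiles_in_n_and_k(int N, int K) {
  return (N % TileN == 0) && (K % TileK == 0);
}

int main() {
  // Weight-shaped N and K: full tiles, so implementable.
  printf("N=4096, K=7168 -> %d\n", full_tiles_in_n_and_k(4096, 7168));
  // After transposing to Y^T = W X^T, the new N is the batch size,
  // which may be ragged (e.g. 100) and would therefore be rejected.
  printf("N=100,  K=7168 -> %d\n", full_tiles_in_n_and_k(100, 7168));
  return 0;
}
```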
@@ -575,7 +607,7 @@ struct CollectiveMma<
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

-   GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA));
+   GmmaFP8Accumulation accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
Maybe still use ScalePromotionInterval here, and move the size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}) computation into the can_implement check?
Hmm, I'm not sure I see ScalePromotionInterval. What would be the motivation for not determining this at compile time? It seems a bit unnecessarily burdensome on the user to have them set mma_promotion_interval manually.
In any case, hoisting this into a constexpr near the top would be better for readability:
static constexpr int ScalePromotionInterval = size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{});
and then using that here?
@hwu36, so this will be 4 for TileShapeK = 128 and InstructionShapeK = 32, which is the original case; for TileShapeK = 64 it will be 2. Will that not be supported?
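A standalone sketch of the arithmetic under discussion (the helper name is illustrative, not the PR's code); the values mirror the two cases mentioned above:

```cpp
// Computes the scale/MMA promotion interval as TileShapeK / InstructionShapeK,
// i.e. the expression proposed as ScalePromotionInterval above.
#include <cstdio>

template <int TileShapeK, int InstructionShapeK>
constexpr int scale_promotion_interval() {
  static_assert(TileShapeK % InstructionShapeK == 0,
                "K tile must be a multiple of the instruction K");
  return TileShapeK / InstructionShapeK;
}

int main() {
  // Original case: 128 / 32 = 4, which satisfies the existing
  // "promotion interval must be a multiple of 4" check.
  printf("TileShapeK=128: %d\n", scale_promotion_interval<128, 32>());
  // The case asked about: 64 / 32 = 2, which would fail that check.
  printf("TileShapeK=64:  %d\n", scale_promotion_interval<64, 32>());
  return 0;
}
```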
@@ -147,9 +147,10 @@ struct CollectiveMma<
    cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

    // Block scaling gmem-to-smem copy atom
-   using BlockScaleCopyTypeA = cute::uint_byte_t<cute::min(static_cast<int>(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>;
+   // we can have partial tiles in M, so don't vectorize those loads
Why is this restriction only for M and not for N? dim-M usually maps to the batch count, while dim-N will be the model dimension, a nice multiple of 2, correct?
If this is an A_row * B_col groupwise GEMM, it is sometimes required that we transpose and swap the operands, creating an underlying GEMM of B_row * A_row and swapping M <-> N. This is typically helpful for (a) mixed-input BF16 * FP8, which doesn't apply here, and (b) small M, say 64, where we can swap and transpose to get better performance; I have seen that give more performance for small M.
Does vectorizing scale_copy_b versus not vectorizing it give any performance improvement? If not, I would suggest we make this kernel symmetric in M and N, to allow the user to apply the swap-and-transpose trick to it.
I was mostly just trying to keep it as close to the original as possible to minimize the chances of perf regressions, but I agree this is much less confusing. And I think we will want to transpose in vLLM in order to use smaller instructions for smaller batch sizes.
pushed an update that enables partial tiles in N
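Since the swap-and-transpose trick came up above, here is a minimal numeric sketch in plain C++ (not CUTLASS code; shapes and values are arbitrary) showing that Y = X * W^T can be run as Y^T = W * X^T, which swaps the roles of M (batch) and N (model dimension) in the underlying GEMM while producing identical results:

```cpp
#include <cassert>
#include <vector>

// Naive row-major GEMM: C[M x N] = A[M x K] * B[K x N].
static void gemm(int M, int N, int K,
                 const std::vector<float>& A,
                 const std::vector<float>& B,
                 std::vector<float>& C) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
}

// Row-major transpose of a rows x cols matrix.
static std::vector<float> transpose(int rows, int cols, const std::vector<float>& A) {
  std::vector<float> At(A.size());
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) At[c * rows + r] = A[r * cols + c];
  return At;
}

int main() {
  const int M = 3, N = 5, K = 4;            // M: batch, N: model dim, K: reduction
  std::vector<float> X(M * K), Wt(K * N);   // activations X and W^T (K x N)
  for (size_t i = 0; i < X.size(); ++i)  X[i]  = float(i % 7) * 0.25f;
  for (size_t i = 0; i < Wt.size(); ++i) Wt[i] = float(i % 5) * 0.5f;

  std::vector<float> Y(M * N), Yt(N * M);
  gemm(M, N, K, X, Wt, Y);                                       // Y   = X * W^T  (GEMM sees M x N)
  gemm(N, M, K, transpose(K, N, Wt), transpose(M, K, X), Yt);    // Y^T = W * X^T  (GEMM sees N x M)

  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      assert(Y[m * N + n] == Yt[n * M + m]);  // identical results, M and N roles swapped
  return 0;
}
```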
  copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
  copy_if(scale_copy_b, tBpB_ScaleB, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage));
Can you make sure that this copy_if is issued by only 32 threads? The thread layout of shape 32 (created above) won't be tiled over the entire tile by make_tiled_copy; please just confirm with a simple printf.
Ran
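// debug print: only block (0,0,0) prints the ids of the threads that reach this point, right before the producer_commit below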
if ((!blockIdx.x && !blockIdx.y && !blockIdx.z)) printf("%d ", threadIdx.x);
if (thread0()) printf("\n");
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
and got:
...
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
...
I think we should be good 👍
  Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});    // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});    // (BLK_N,BLK_K,PIPE)
Should the TMA-related tensor constructions be inside the lane_predicate as before? There's no need for all the threads to construct these, even in this implementation.
I'm not sure. I didn't think this was a big deal, since if you look at the 3.6.0 diff that improved the mixed-input GEMM (we were told 3.6 had perf improvements for mixed input) in include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp, you can see it was updated to have all the threads compute the TMA tensors. I'm not sure what the recommended approach is, or whether this particular change had any impact. Would appreciate some guidance!
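To make the trade-off concrete, here is a toy CUDA sketch (purely illustrative; build_tile_views() is a hypothetical stand-in for constructing the TMA-partitioned tensors, not a CUTLASS API) contrasting the two patterns discussed above:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for the per-tile setup work; it has no side effects
// outside the calling thread, just like building local tensor views.
__device__ int build_tile_views(int tile) { return tile * 2 + 1; }

__global__ void setup_patterns(int tile) {
  // Pattern A: only one predicated thread does the setup, as in the
  // pre-change mainloop that guarded this with a lane predicate.
  if (threadIdx.x == 0) {
    int views = build_tile_views(tile);
    printf("pattern A: thread 0 built %d\n", views);
  }
  __syncthreads();

  // Pattern B: every thread redundantly does the setup, as in the updated
  // mixed-input mainloop referenced above. The result is identical; the
  // only question is whether the redundant work costs anything.
  int views = build_tile_views(tile);
  if (threadIdx.x == 0) {
    printf("pattern B: all threads built %d\n", views);
  }
}

int main() {
  setup_patterns<<<1, 32>>>(3);
  cudaDeviceSynchronize();
  return 0;
}
```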
Various improvements to "Groupwise scaling along M" (#2037), namely to address #2087 (context: vllm-project/vllm#11868 (comment)).
Improvements:
- This PR moves to a standard M-major layout for the scales, making it much easier to integrate into inference libraries.
These improvements were part of vLLM's adoption of this kernel (https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp, PR: vllm-project/vllm#11868), and it is currently in wide-scale use. Our goal is to rely on the CUTLASS implementation, but that is currently not possible given the issues above.
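As a rough illustration of what a "standard M-major" scale layout means for the scale tensor (the granularities and the helper below are assumptions chosen for the example, not the PR's exact definitions), the scales form an ordinary 2-D tensor with M as the contiguous dimension:

```cpp
#include <cstdio>

// Assumed scale granularities, for illustration only.
constexpr int ScaleGranularityM = 1;    // one scale per row of A (assumption)
constexpr int ScaleGranularityK = 128;  // one scale per 128-wide K block (assumption)

// M-major indexing: scales for the same K block are contiguous along M.
int scale_index(int m, int k, int M) {
  int scale_m = m / ScaleGranularityM;
  int scale_k = k / ScaleGranularityK;
  int scales_per_k_block = M / ScaleGranularityM;  // assume M divisible, for simplicity
  return scale_k * scales_per_k_block + scale_m;
}

int main() {
  const int M = 256;
  // Adjacent rows within the same K block map to adjacent scale entries:
  printf("%d %d %d\n", scale_index(0, 0, M), scale_index(1, 0, M), scale_index(2, 0, M));
  // Moving to the next K block jumps by a whole column of M scales:
  printf("%d\n", scale_index(0, 128, M));
  return 0;
}
```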