
[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling #11868

Merged (25 commits) on Jan 31, 2025

Conversation

@LucasWilkinson (Collaborator) commented Jan 8, 2025:

Currently only supports scale_a block shapes of 1x128 and scale_b block shapes of 128x128 (for DeepSeek V3).
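For intuition, the arithmetic these block shapes imply can be written out in plain PyTorch. This is a minimal reference sketch only, not the CUTLASS kernel (which applies the scales inside the GEMM mainloop rather than materializing dequantized operands); the function name and exact signature are illustrative assumptions:

import torch

def blockwise_scaled_mm_ref(a_fp8, b_fp8, scale_a, scale_b,
                            block_k=128, block_n=128):
    # a_fp8: (M, K) fp8,  scale_a: (M, K // 128)        -> one scale per 1x128 group of A
    # b_fp8: (K, N) fp8,  scale_b: (K // 128, N // 128) -> one scale per 128x128 block of B
    a = a_fp8.to(torch.float32) * scale_a.repeat_interleave(block_k, dim=1)
    b = b_fp8.to(torch.float32) * (scale_b.repeat_interleave(block_k, dim=0)
                                          .repeat_interleave(block_n, dim=1))
    return (a @ b).to(torch.float16)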

Shout-out to @manishucsd and @soundOfDestiny for the kernel, which was adapted from: https://github.com/soundOfDestiny/cutlass/tree/f8_blockwise_scaling_pr_branch

This PR also splits up scaled_mm_c3x.cu to help parallelize building the kernels.

TODO:

  • Clean-up
  • Benchmarking
  • (future PR) See if we can support K-major a_scales, or update per_token_group_quant_fp8 to output M-major
  • (future PR) Transpose A and B to allow for smaller tensor core instructions

Benchmarking (Scroll horizontally to see cutlass_fp8_fp8_fp16_scaled_mm_blockwise):

== Results torch.float8_e4m3fn mistralai/Mistral-7B-v0.1-TP1 ====
[--------------------------------------------------------- scaled-torch.float8_e4m3fn-gemm ----------------------------------------------------------]
                            |  cutlass_fp8_fp8_fp16_scaled_mm  |  triton_fp8_fp8_fp16_scaled_mm_blockwise  |  cutlass_fp8_fp8_fp16_scaled_mm_blockwise
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------
      MKN=(1x4096x6144)     |               12.1               |                   327.3                   |                    23.9                  
      MKN=(1x4096x4096)     |                9.1               |                    69.2                   |                    23.6                  
      MKN=(1x4096x28672)    |               51.8               |                   131.6                   |                    68.8                  
      MKN=(1x14336x4096)    |               29.6               |                   229.4                   |                    69.6                  
      MKN=(16x4096x6144)    |               12.1               |                    69.1                   |                    23.9                  
      MKN=(16x4096x4096)    |                9.2               |                    69.0                   |                    23.6                  
      MKN=(16x4096x28672)   |               53.0               |                   131.5                   |                    62.0                  
      MKN=(16x14336x4096)   |               30.1               |                   231.5                   |                    71.8                  
      MKN=(32x4096x6144)    |               11.6               |                    85.6                   |                    24.0                  
      MKN=(32x4096x4096)    |                9.1               |                    85.2                   |                    23.6                  
      MKN=(32x4096x28672)   |               53.8               |                   135.7                   |                    62.1                  
      MKN=(32x14336x4096)   |               30.1               |                   291.6                   |                    71.7                  
      MKN=(64x4096x6144)    |               10.9               |                    37.8                   |                    24.0                  
      MKN=(64x4096x4096)    |                9.1               |                    39.1                   |                    23.7                  
      MKN=(64x4096x28672)   |               53.5               |                    67.4                   |                    64.0                  
      MKN=(64x14336x4096)   |               30.7               |                   130.1                   |                    71.9                  
      MKN=(128x4096x6144)   |               12.1               |                   133.5                   |                    23.7                  
      MKN=(128x4096x4096)   |               11.9               |                   124.3                   |                    23.7                  
      MKN=(128x4096x28672)  |               59.5               |                   301.8                   |                    62.0                  
      MKN=(128x14336x4096)  |               35.4               |                   512.9                   |                    72.5                  
      MKN=(256x4096x6144)   |               17.3               |                   138.2                   |                    24.9                  
      MKN=(256x4096x4096)   |               16.5               |                   126.2                   |                    24.0                  
      MKN=(256x4096x28672)  |               72.9               |                   576.5                   |                    92.2                  
      MKN=(256x14336x4096)  |               48.8               |                   500.6                   |                    73.2                  
      MKN=(512x4096x6144)   |               30.7               |                   267.5                   |                    48.0                  
      MKN=(512x4096x4096)   |               18.2               |                   137.8                   |                    26.1                  
      MKN=(512x4096x28672)  |              118.7               |                   983.2                   |                   163.2                  
      MKN=(512x14336x4096)  |               56.2               |                   500.0                   |                    81.1                  

Times are in microseconds (us).

== Results torch.float8_e4m3fn meta-llama/Llama-2-7b-hf-TP1 ====
[--------------------------------------------------------- scaled-torch.float8_e4m3fn-gemm ----------------------------------------------------------]
                            |  cutlass_fp8_fp8_fp16_scaled_mm  |  triton_fp8_fp8_fp16_scaled_mm_blockwise  |  cutlass_fp8_fp8_fp16_scaled_mm_blockwise
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------
      MKN=(1x4096x12288)    |               26.0               |                    68.7                   |                    36.8                  
      MKN=(1x4096x4096)     |                9.2               |                    68.6                   |                    23.7                  
      MKN=(1x4096x22016)    |               41.8               |                   120.7                   |                    61.8                  
      MKN=(1x11008x4096)    |               21.9               |                   177.8                   |                    55.6                  
      MKN=(16x4096x12288)   |               26.5               |                    68.4                   |                    33.2                  
      MKN=(16x4096x4096)    |                9.1               |                    68.8                   |                    23.6                  
      MKN=(16x4096x22016)   |               43.2               |                   123.6                   |                    57.9                  
      MKN=(16x11008x4096)   |               22.2               |                   178.7                   |                    56.7                  
      MKN=(32x4096x12288)   |               26.6               |                    89.7                   |                    34.0                  
      MKN=(32x4096x4096)    |               10.4               |                    85.1                   |                    23.7                  
      MKN=(32x4096x22016)   |               43.9               |                   141.8                   |                    58.6                  
      MKN=(32x11008x4096)   |               24.7               |                   224.8                   |                    57.2                  
      MKN=(64x4096x12288)   |               27.3               |                    42.9                   |                    34.3                  
      MKN=(64x4096x4096)    |                9.1               |                    39.1                   |                    23.6                  
      MKN=(64x4096x22016)   |               44.9               |                    63.2                   |                    59.1                  
      MKN=(64x11008x4096)   |               25.0               |                   354.3                   |                    57.3                  
      MKN=(128x4096x12288)  |               29.6               |                   152.4                   |                    30.7                  
      MKN=(128x4096x4096)   |               12.0               |                   124.3                   |                    23.7                  
      MKN=(128x4096x22016)  |               45.9               |                   297.8                   |                    58.6                  
      MKN=(128x11008x4096)  |               28.6               |                   393.8                   |                    57.3                  
      MKN=(256x4096x12288)  |               37.2               |                   286.5                   |                    48.7                  
      MKN=(256x4096x4096)   |               16.5               |                   126.1                   |                    24.1                  
      MKN=(256x4096x22016)  |               56.2               |                   439.6                   |                    71.7                  
      MKN=(256x11008x4096)  |               39.2               |                   383.3                   |                    57.6                  
      MKN=(512x4096x12288)  |               52.2               |                   421.5                   |                    72.5                  
      MKN=(512x4096x4096)   |               18.2               |                   137.9                   |                    25.9                  
      MKN=(512x4096x22016)  |               94.1               |                   829.5                   |                   135.7                  
      MKN=(512x11008x4096)  |               44.5               |                   385.4                   |                    62.8                  

Times are in microseconds (us).

== Results torch.float8_e4m3fn meta-llama/Llama-3-8b-TP1 ====
[--------------------------------------------------------- scaled-torch.float8_e4m3fn-gemm ----------------------------------------------------------]
                            |  cutlass_fp8_fp8_fp16_scaled_mm  |  triton_fp8_fp8_fp16_scaled_mm_blockwise  |  cutlass_fp8_fp8_fp16_scaled_mm_blockwise
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------
      MKN=(1x4096x6144)     |               12.3               |                    68.7                   |                    24.1                  
      MKN=(1x4096x4096)     |                9.1               |                    68.6                   |                    23.8                  
      MKN=(1x4096x28672)    |               51.8               |                   130.1                   |                    69.1                  
      MKN=(1x14336x4096)    |               29.6               |                   229.6                   |                    69.5                  
      MKN=(16x4096x6144)    |               12.1               |                    69.1                   |                    24.0                  
      MKN=(16x4096x4096)    |                9.1               |                    68.9                   |                    23.6                  
      MKN=(16x4096x28672)   |               53.9               |                   131.8                   |                    62.1                  
      MKN=(16x14336x4096)   |               29.9               |                   231.6                   |                    71.8                  
      MKN=(32x4096x6144)    |               11.6               |                    85.7                   |                    24.0                  
      MKN=(32x4096x4096)    |                9.1               |                    85.0                   |                    23.6                  
      MKN=(32x4096x28672)   |               53.8               |                   135.8                   |                    62.1                  
      MKN=(32x14336x4096)   |               30.1               |                   291.7                   |                    71.9                  
      MKN=(64x4096x6144)    |               10.9               |                    38.1                   |                    24.1                  
      MKN=(64x4096x4096)    |                9.2               |                    39.2                   |                    23.7                  
      MKN=(64x4096x28672)   |               54.0               |                    67.3                   |                    64.1                  
      MKN=(64x14336x4096)   |               30.7               |                   130.2                   |                    72.0                  
      MKN=(128x4096x6144)   |               12.1               |                   133.5                   |                    23.8                  
      MKN=(128x4096x4096)   |               12.0               |                   124.3                   |                    23.7                  
      MKN=(128x4096x28672)  |               59.5               |                   301.6                   |                    61.9                  
      MKN=(128x14336x4096)  |               35.4               |                   513.6                   |                    72.6                  
      MKN=(256x4096x6144)   |               17.4               |                   138.2                   |                    25.1                  
      MKN=(256x4096x4096)   |               16.5               |                   126.3                   |                    24.1                  
      MKN=(256x4096x28672)  |               71.2               |                   576.7                   |                    94.1                  
      MKN=(256x14336x4096)  |               48.8               |                   500.2                   |                    73.3                  
      MKN=(512x4096x6144)   |               30.8               |                   267.7                   |                    47.8                  
      MKN=(512x4096x4096)   |               18.2               |                   137.9                   |                    25.9                  
      MKN=(512x4096x28672)  |              110.7               |                   982.6                   |                   160.9                  
      MKN=(512x14336x4096)  |               56.7               |                   500.2                   |                    82.3                  

Times are in microseconds (us).

== Results torch.float8_e4m3fn meta-llama/Llama-2-13b-hf-TP1 ====
[--------------------------------------------------------- scaled-torch.float8_e4m3fn-gemm ----------------------------------------------------------]
                            |  cutlass_fp8_fp8_fp16_scaled_mm  |  triton_fp8_fp8_fp16_scaled_mm_blockwise  |  cutlass_fp8_fp8_fp16_scaled_mm_blockwise
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------
      MKN=(1x5120x15360)    |               38.3               |                     83.9                  |                    46.8                  
      MKN=(1x5120x5120)     |               13.2               |                     85.9                  |                    28.5                  
      MKN=(1x5120x27648)    |               62.1               |                    161.3                  |                    83.8                  
      MKN=(1x13824x5120)    |               34.3               |                    219.8                  |                    68.7                  
      MKN=(16x5120x15360)   |               39.5               |                     83.5                  |                    43.9                  
      MKN=(16x5120x5120)    |               12.9               |                     85.9                  |                    28.4                  
      MKN=(16x5120x27648)   |               63.7               |                    163.2                  |                    72.2                  
      MKN=(16x13824x5120)   |               34.8               |                    223.6                  |                    70.2                  
      MKN=(32x5120x15360)   |               39.6               |                    111.5                  |                    44.3                  
      MKN=(32x5120x5120)    |               12.6               |                    105.9                  |                    28.6                  
      MKN=(32x5120x27648)   |               63.9               |                    167.9                  |                    73.5                  
      MKN=(32x13824x5120)   |               34.8               |                    281.7                  |                    70.0                  
      MKN=(64x5120x15360)   |               40.1               |                     55.8                  |                    45.7                  
      MKN=(64x5120x5120)    |               12.0               |                     46.3                  |                    28.6                  
      MKN=(64x5120x27648)   |               64.1               |                     82.4                  |                    75.8                  
      MKN=(64x13824x5120)   |               35.6               |                    124.0                  |                    70.1                  
      MKN=(128x5120x15360)  |               40.4               |                    190.1                  |                    42.6                  
      MKN=(128x5120x5120)   |               13.5               |                    169.7                  |                    28.6                  
      MKN=(128x5120x27648)  |               70.6               |                    373.5                  |                    72.0                  
      MKN=(128x13824x5120)  |               37.7               |                    492.4                  |                    70.4                  
      MKN=(256x5120x15360)  |               47.9               |                    370.6                  |                    61.0                  
      MKN=(256x5120x5120)   |               19.7               |                    163.0                  |                    30.0                  
      MKN=(256x5120x27648)  |               87.3               |                    723.4                  |                   113.5                  
      MKN=(256x13824x5120)  |               49.9               |                    482.6                  |                    73.5                  
      MKN=(512x5120x15360)  |               80.1               |                    708.8                  |                   115.3                  
      MKN=(512x5120x5120)   |               36.0               |                    329.4                  |                    59.2                  
      MKN=(512x5120x27648)  |              142.8               |                   1242.3                  |                   201.9                  
      MKN=(512x13824x5120)  |               90.0               |                    948.0                  |                   144.9                  

Times are in microseconds (us).

== Results torch.float8_e4m3fn meta-llama/Llama-2-70b-hf-TP1 ====
[--------------------------------------------------------- scaled-torch.float8_e4m3fn-gemm ----------------------------------------------------------]
                            |  cutlass_fp8_fp8_fp16_scaled_mm  |  triton_fp8_fp8_fp16_scaled_mm_blockwise  |  cutlass_fp8_fp8_fp16_scaled_mm_blockwise
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------
      MKN=(1x8192x10240)    |               40.3               |                    130.6                  |                    57.3                  
      MKN=(1x8192x8192)     |               35.3               |                    130.8                  |                    47.7                  
      MKN=(1x8192x57344)    |              189.1               |                    414.5                  |                   257.9                  
      MKN=(1x28672x8192)    |              107.3               |                    433.9                  |                   134.3                  
      MKN=(16x8192x10240)   |               40.5               |                    129.3                  |                    46.3                  
      MKN=(16x8192x8192)    |               36.6               |                    130.2                  |                    45.4                  
      MKN=(16x8192x57344)   |              190.5               |                    419.3                  |                   236.6                  
      MKN=(16x28672x8192)   |              107.9               |                    434.6                  |                   138.6                  
      MKN=(32x8192x10240)   |               40.7               |                    171.4                  |                    46.7                  
      MKN=(32x8192x8192)    |               36.5               |                    170.1                  |                    45.4                  
      MKN=(32x8192x57344)   |              192.1               |                    533.1                  |                   239.7                  
      MKN=(32x28672x8192)   |              109.9               |                    567.1                  |                   138.5                  
      MKN=(64x8192x10240)   |               41.7               |                     77.2                  |                    47.3                  
      MKN=(64x8192x8192)    |               37.1               |                     76.5                  |                    45.5                  
      MKN=(64x8192x57344)   |              194.0               |                    253.1                  |                   243.3                  
      MKN=(64x28672x8192)   |              112.5               |                    253.3                  |                   138.9                  
      MKN=(128x8192x10240)  |               49.1               |                    296.7                  |                    48.3                  
      MKN=(128x8192x8192)   |               35.7               |                    295.8                  |                    46.2                  
      MKN=(128x8192x57344)  |              204.4               |                   1197.5                  |                   221.5                  
      MKN=(128x28672x8192)  |              101.2               |                   1023.2                  |                   142.9                  
      MKN=(256x8192x10240)  |               62.1               |                    574.5                  |                    89.4                  
      MKN=(256x8192x8192)   |               40.4               |                    297.2                  |                    51.3                  
      MKN=(256x8192x57344)  |              249.8               |                   2046.2                  |                   314.3                  
      MKN=(256x28672x8192)  |              122.7               |                   1024.6                  |                   163.6                  
      MKN=(512x8192x10240)  |               91.5               |                    856.1                  |                   133.3                  
      MKN=(512x8192x8192)   |               63.9               |                    580.1                  |                    93.5                  
      MKN=(512x8192x57344)  |              449.4               |                   4034.0                  |                   632.6                  
      MKN=(512x28672x8192)  |              220.0               |                   2008.9                  |                   316.7                  

Times are in microseconds (us).

github-actions bot commented Jan 8, 2025:

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

mergify bot added the ci/build label on Jan 8, 2025
@LucasWilkinson changed the title from [WIP][Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling to [Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling on Jan 10, 2025
@LucasWilkinson marked this pull request as ready for review on January 10, 2025 19:25
@mgoin self-requested a review on January 10, 2025 22:07
@mgoin (Member) left a comment:

Really nice performance! When you are ready for e2e testing lmk and we can hook these up for a full dsv3 eval


// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
            b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
@mgoin (Member): Why would these not be contiguous coming in?

@LucasWilkinson (Collaborator, Author): Currently for the blockwise scaling, a_scales and b_scales must be column-major. This is something we may need to relax for a_scales (but likely won't for b_scales, since we can just transpose it offline). Figured I'd save that for a future PR, though.
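For illustration, here is one way to re-lay-out a scales tensor in PyTorch so that stride(0) == 1 (column-major / M-major); the helper name is hypothetical, and the exact layout convention the kernel expects is defined by the checks in the C++ code above:

import torch

def to_column_major(x: torch.Tensor) -> torch.Tensor:
    # Transpose, force the transposed view contiguous, transpose back:
    # same shape and values, but now laid out column-major (stride(0) == 1).
    return x.t().contiguous().t()

a_scales = torch.rand(512, 32)          # e.g. (M, K // 128), row-major by default
a_scales_cm = to_column_major(a_scales)
assert a_scales_cm.stride(0) == 1       # column-major layout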

Comment on lines 40 to 48
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

using ElementD = OutType;
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
Collaborator: Are these alignments right? I think these are used widely in the CUTLASS examples, but we use different/smaller alignments in the vLLM dense cutlass gemm kernels.

Also wondering about the difference between AlignmentC and AlignmentD

@LucasWilkinson (Collaborator, Author) commented Jan 20, 2025:

We don't actually use smaller alignments in the dense cutlass GEMM kernels; we just "hardcoded" them in terms of elements. For example, here for fp8/int8 we "hardcoded" the A and B alignment to 16 elements:

ElementAB, cutlass::layout::RowMajor, 16,
ElementAB, cutlass::layout::ColumnMajor, 16,

which is the same as the 128 bits / 8 bits used here. I'm not really opposed to either; you are correct that I was working off an example, so I went with this style. Happy to change it to the "hardcoded" style to be more terse.

> Also wondering about the difference between AlignmentC and AlignmentD

We don't use C (the type is void), so I just set the alignment to a dummy value (float32 alignment: 128 bits / 32 bits = 4), but I updated it to just be the same as AlignmentD to avoid confusion 👍 (probably a better dummy value anyway)
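The alignment values under discussion are simply 128 bits divided by the element width, i.e. how many elements fit in one 16-byte vectorized access. A quick sanity check of the numbers quoted above:

# AlignmentX = 128 / sizeof_bits<ElementX>, expressed in elements:
for name, bits in [("fp8 / int8", 8), ("fp16 / bf16", 16), ("fp32", 32)]:
    print(f"{name}: {128 // bits} elements per 16-byte access")
# fp8 / int8  -> 16 elements (the value hardcoded in the dense kernels)
# fp16 / bf16 ->  8 elements
# fp32        ->  4 elements (the original "dummy" AlignmentC value)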

@LucasWilkinson force-pushed the lwilkinson/blockwise-scaled-mm branch from 17990c0 to 211f663 on January 23, 2025 16:01
@tlrmchlsmth (Collaborator) left a comment:

A few minor comments, but looks good to me!

mergify bot commented Jan 27, 2025:

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Jan 27, 2025
sahelib25 pushed a commit to krai/vllm that referenced this pull request Feb 3, 2025
…DeepSeekV3 (vllm-project#12587)

Integrates the block-quantized kernels introduced in
vllm-project#11868 for use in linear
layers.

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@yizhang2077 commented Feb 8, 2025:

Hi @LucasWilkinson, thanks for your awesome work. I am doing similar work but using the latest CUTLASS version, and it does not work correctly. I noticed you made some other adaptations in sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp, and I wonder whether you faced the same problem? I would be grateful for your help, thanks!

@Andy0422 commented Feb 9, 2025:

> Currently only supports scale_a block shapes of 1x128 and scale_b block shapes of 128x128 (for DeepSeek V3) […]

@LucasWilkinson Just wondering: which device did you test these results on? Thank you!

@LucasWilkinson (Collaborator, Author) commented Feb 10, 2025:

> Hi @LucasWilkinson, thanks for your awesome work. I am doing similar work but using the latest CUTLASS version, and it does not work correctly. I noticed you made some other adaptations in sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp, and I wonder whether you faced the same problem? I would be grateful for your help, thanks!

@yizhang2077 Sorry, I haven't gotten around to playing with the latest CUTLASS yet, but I'm hoping to get to it soon. I do remember running into issues with the original kernels, but I don't remember exactly what they were (it has been a little over a month since I last touched the kernel code). I think it was something along the lines of not all threads participating in the copy and the copy not being predicated correctly. I think this commit: d963eb4 encompasses most of the changes relative to @soundOfDestiny's original kernel; I'm not sure about the one that ultimately landed in upstream CUTLASS. Hope this helps!

@yizhang2077 commented:

@LucasWilkinson Thanks for your reply. I also noticed that d963eb4 solved my problem, but I don't know why (I am not familiar with CUTLASS kernels). Anyway, thank you again!

@LucasWilkinson (Collaborator, Author) commented:

@Andy0422 This was an H100.

@LucasWilkinson (Collaborator, Author) commented Feb 10, 2025:

> @LucasWilkinson Thanks for your reply. I also noticed that d963eb4 solved my problem, but I don't know why (I am not familiar with CUTLASS kernels). Anyway, thank you again!

@yizhang2077 No worries. After reviewing the commit, it's coming back to me a bit. The original implementation was pretty heavily bugged, so I ended up heavily modifying it. The main issue was that the cp.async calls were only being issued by one thread (since they were bundled with the TMA code, which intentionally issues from a single thread). This meant that a lot of the scales were not getting copied into shared memory correctly, and even once the calls were issued from multiple threads, there were issues with predication (to avoid reading illegal memory addresses) and barrier counts that had to be resolved.

I'll open a PR to fix this on upstream CUTLASS soon.

@yizhang2077 commented Feb 10, 2025:

> I'll open a PR to fix this on upstream CUTLASS soon.

@LucasWilkinson Nice work! I also raised an issue, NVIDIA/cutlass#2087; maybe it can help your PR.

@ProphetPeng commented:
Hi, is there any plan to support smaller tensor core instructions? @LucasWilkinson

@LucasWilkinson (Collaborator, Author) commented:

> Hi, is there any plan to support smaller tensor core instructions? @LucasWilkinson

I'm contemplating it, given the latest CUTLASS updates and NVIDIA/cutlass#2095. Is this something you can help with? If so, and you can handle the vLLM side, I can update NVIDIA/cutlass#2095 to support partial N tiles.

@ProphetPeng commented:
> Hi, is there any plan to support smaller tensor core instructions? @LucasWilkinson
>
> I'm contemplating it, given the latest CUTLASS updates and NVIDIA/cutlass#2095. Is this something you can help with? If so, and you can handle the vLLM side, I can update NVIDIA/cutlass#2095 to support partial N tiles.

Thanks for the reply. We are going to use DeepGEMM and may support smaller tensor core instructions based on it.

Labels: ci/build, ready