Skip to content

Commit

Permalink
[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) s…
Browse files Browse the repository at this point in the history
…caling (vllm-project#11868)
  • Loading branch information
LucasWilkinson authored and pathorn committed Feb 14, 2025
1 parent 8cb38ed commit 99dc8ec
Show file tree
Hide file tree
Showing 25 changed files with 1,924 additions and 346 deletions.
9 changes: 7 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.6.0
GIT_TAG v3.7.0
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
Expand Down Expand Up @@ -299,7 +299,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
Expand Down
Loading

0 comments on commit 99dc8ec

Please sign in to comment.