From 75866838289bfadb9d1d184dd524e8a224b8d022 Mon Sep 17 00:00:00 2001 From: yangjianfeng01 Date: Sun, 4 Jan 2026 09:52:59 +0800 Subject: [PATCH] opt w4afp8 --- .../gpu_ops/w4afp8_gemm/kernel_traits.h | 77 ++++++-- custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 78 +++++--- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 184 +++++++++++++----- .../utils/auto_gen_w4afp8_gemm_kernel.py | 5 +- 4 files changed, 242 insertions(+), 102 deletions(-) diff --git a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h index 04d97eb141e..003f3fadb5b 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h @@ -46,7 +46,9 @@ struct SharedStorage { }; template , Int, Int>; - + using TileShape_MNK1 = Shape, Int, Int>; + using TileShape_MNK2 = Shape, Int, Int>; + using TileShape_MNK3 = Shape, Int, Int>; static constexpr int kClusterM = kClusterM_; using ClusterShape_MNK = Shape, _1, _1>; @@ -91,9 +96,17 @@ struct Kernel_traits { using AtomLayoutMNK = Layout, _1, _1>>; - using TiledMma = decltype(cute::make_tiled_mma( + using TiledMma1 = decltype(cute::make_tiled_mma( + cute::GMMA:: + rs_op_selector(), + AtomLayoutMNK{})); + using TiledMma2 = decltype(cute::make_tiled_mma( + cute::GMMA:: + rs_op_selector(), + AtomLayoutMNK{})); + using TiledMma3 = decltype(cute::make_tiled_mma( cute::GMMA:: - rs_op_selector(), + rs_op_selector(), AtomLayoutMNK{})); using SmemLayoutAtomA = @@ -107,27 +120,53 @@ struct Kernel_traits { SmemLayoutAtomA{}, make_shape(Int{}, Int{}, Int{}))); - using SmemLayoutAtomB = + using SmemLayoutAtomB1 = + decltype(cutlass::gemm::collective::detail::rs_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<1>(TileShape_MNK1{})), + decltype(cute::get<2>(TileShape_MNK1{}))>()); + + using SmemLayoutB1 = + decltype(tile_to_shape(SmemLayoutAtomB1{}, + make_shape(shape<1>(TileShape_MNK1{}), + shape<2>(TileShape_MNK1{}), + Int{}))); + + using SmemLayoutAtomB2 = + decltype(cutlass::gemm::collective::detail::rs_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<1>(TileShape_MNK2{})), + decltype(cute::get<2>(TileShape_MNK2{}))>()); + + using SmemLayoutB2 = + decltype(tile_to_shape(SmemLayoutAtomB2{}, + make_shape(shape<1>(TileShape_MNK2{}), + shape<2>(TileShape_MNK2{}), + Int{}))); + + using SmemLayoutAtomB3 = decltype(cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, Element, - decltype(cute::get<1>(TileShape_MNK{})), - decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK3{})), + decltype(cute::get<2>(TileShape_MNK3{}))>()); - using SmemLayoutB = - decltype(tile_to_shape(SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape_MNK{}), - shape<2>(TileShape_MNK{}), + using SmemLayoutB3 = + decltype(tile_to_shape(SmemLayoutAtomB3{}, + make_shape(shape<1>(TileShape_MNK3{}), + shape<2>(TileShape_MNK3{}), Int{}))); using SmemLayoutAtomC = decltype(cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, ElementOutput, - decltype(cute::get<0>(TileShape_MNK{})), - decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<0>(TileShape_MNK1{})), + decltype(cute::get<1>(TileShape_MNK1{}))>()); - using SmemLayoutC = - decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{}))); + using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, + select<0, 1>(TileShape_MNK1{}))); using SmemCopyAtomAB = Copy_Atom; using SmemCopyAtomC = Copy_Atom; @@ -138,7 +177,7 @@ struct Kernel_traits { Element, ElementOutput, SmemLayoutA, - 
SmemLayoutB, + SmemLayoutB1, SmemLayoutC, SmemLayoutScale>; @@ -146,7 +185,7 @@ struct Kernel_traits { using PipelineState = typename cutlass::PipelineState; static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); - static constexpr int kNumThreadsPerRow = kBlockN / kNumVecElem; + static constexpr int kNumThreadsPerRow = kBlockN1 / kNumVecElem; // static_assert(NumMmaThreads % kNumThreadsPerRow == 0); static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; using TiledCopyCAtom = diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index 0c3c71678aa..893bbd84c2b 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -32,13 +32,17 @@ template struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; using ElementOutput = typename Ktraits::ElementOutput; - using TileShape_MNK = typename Ktraits::TileShape_MNK; + using TileShape_MNK1 = typename Ktraits::TileShape_MNK1; + using TileShape_MNK2 = typename Ktraits::TileShape_MNK2; + using TileShape_MNK3 = typename Ktraits::TileShape_MNK3; using ClusterShape = typename Ktraits::ClusterShape_MNK; using ElementAccum = typename Ktraits::ElementAccum; static constexpr int kStages = Ktraits::kStages; static constexpr int kBlockM = Ktraits::kBlockM; - static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kBlockN1 = Ktraits::kBlockN1; + static constexpr int kBlockN2 = Ktraits::kBlockN2; + static constexpr int kBlockN3 = Ktraits::kBlockN3; static constexpr int kBlockK = Ktraits::kBlockK; static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int kTiles = Ktraits::kTiles; @@ -50,7 +54,10 @@ struct CollectiveMainloopFwd { using GmemTiledCopy = cute::SM90_TMA_LOAD; using SmemLayoutA = typename Ktraits::SmemLayoutA; - using SmemLayoutB = typename Ktraits::SmemLayoutB; + using SmemLayoutB1 = typename Ktraits::SmemLayoutB1; + using SmemLayoutB2 = typename Ktraits::SmemLayoutB2; + using SmemLayoutB3 = typename Ktraits::SmemLayoutB3; + using SmemLayoutC = typename Ktraits::SmemLayoutC; using SmemLayoutScale = typename Ktraits::SmemLayoutScale; @@ -76,8 +83,8 @@ struct CollectiveMainloopFwd { make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeT{}, StrideT{}), - take<0, 2>(SmemLayoutB{}), - select<1, 2>(TileShape_MNK{}), + take<0, 2>(SmemLayoutB1{}), + select<1, 2>(TileShape_MNK1{}), size<0>(ClusterShape{}))); using TMA_Scale = decltype(make_tma_copy( @@ -89,7 +96,7 @@ struct CollectiveMainloopFwd { select<0>(Shape>{}), size<0>(ClusterShape{}))); - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma1{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; using PipelineState = typename MainloopPipeline::PipelineState; @@ -100,7 +107,7 @@ struct CollectiveMainloopFwd { static constexpr uint32_t TmaTransactionBytesA = static_cast( size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesB = static_cast( - size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v / 8); + size(take<0, 2>(SmemLayoutB1{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesScale = static_cast( size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v / 8); @@ -141,8 +148,8 @@ struct CollectiveMainloopFwd { Tensor mB = make_tensor(make_gmem_ptr(args.ptr_B), 
args.layout_B); TMA_B tma_load_B = make_tma_copy(GmemTiledCopy{}, mB, - SmemLayoutB{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), + SmemLayoutB1{}(_, _, _0{}), + select<1, 2>(TileShape_MNK1{}), size<0>(ClusterShape{})); Tensor mScale = make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale); @@ -176,7 +183,10 @@ struct CollectiveMainloopFwd { } } - template + template CUTLASS_DEVICE void store(Params const& mainloop_params, FrgTensorO& tOrO, SharedStorage& shared_storage, @@ -252,7 +262,7 @@ struct CollectiveMainloopFwd { cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - constexpr int k_copy_times = kBlockN / 16; + constexpr int k_copy_times = CUR_N / 16; #pragma unroll for (int i = 0; i < k_copy_times; i++) { @@ -273,15 +283,15 @@ struct CollectiveMainloopFwd { const int expert_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize; ElementOutput* store_c = mainloop_params.ptr_C + expert_idx + - bidn * (M * kBlockN) + bidm * kBlockM; + bidn * (M * kBlockN1) + bidm * kBlockM; - const int reamin_tokens = tokens - bidn * kBlockN; + const int reamin_tokens = tokens - bidn * kBlockN1; const int col = tidx % 2; constexpr int kPackSize = 16 / sizeof(ElementOutput); constexpr int kNumVecElem = kBlockM / kPackSize; - constexpr int copy_len = kBlockN * kNumVecElem; + constexpr int copy_len = CUR_N * kNumVecElem; #pragma unroll for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) { const int idx_div2 = idx / 2; @@ -307,7 +317,7 @@ struct CollectiveMainloopFwd { auto g_tensor = domain_offset(make_coord(pre_fix_token, _0{}), mB(_, _, 0)); Tensor gB = local_tile( - g_tensor, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + g_tensor, select<1, 2>(TileShape_MNK1{}), make_coord(bidn, _)); return gB; } @@ -324,8 +334,8 @@ struct CollectiveMainloopFwd { const int tidx) { Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); - Tensor sB = - make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), + SmemLayoutB1{}); Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutScale{}); @@ -387,7 +397,7 @@ struct CollectiveMainloopFwd { mB(_, _, bidb).data(), make_layout(cute::make_shape(tokens, size<1>(mB)), mB.stride())); Tensor gB = local_tile( - mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + mB_this_expert, select<1, 2>(TileShape_MNK1{}), make_coord(bidn, _)); auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout{}, @@ -421,7 +431,10 @@ struct CollectiveMainloopFwd { } } - template + template CUTLASS_DEVICE void mma(Params const& mainloop_params, TiledMma tiled_mma, MainloopPipeline pipeline, @@ -429,10 +442,15 @@ struct CollectiveMainloopFwd { SharedStorage& shared_storage, FrgTensorO& tSrS, const int tidx) { + using sMemBLayout = std::conditional_t< + CUR_N == kBlockN1, + SmemLayoutB1, + std::conditional_t>; + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); Tensor sB = - make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{}); tiled_mma.accumulate_ = GMMA::ScaleOut::One; auto threadMma = tiled_mma.get_thread_slice(tidx); @@ -447,6 +465,7 @@ struct CollectiveMainloopFwd { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; + constexpr int B_STEPS = kBlockN1 / CUR_N; #pragma 
unroll for (int kiter = 0; kiter < kTiles; ++kiter) { Tensor tSsA = @@ -455,7 +474,7 @@ struct CollectiveMainloopFwd { gemm(tiled_mma, tSrA, tSsA, - tSrB(_, _, _, smem_pipe_read.index()), + tSrB(_, _, _, smem_pipe_read.index() * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A); @@ -464,7 +483,10 @@ struct CollectiveMainloopFwd { } } - template + template CUTLASS_DEVICE void mma_pipeline(Params const& mainloop_params, TiledMma tiled_mma, MainloopPipeline pipeline, @@ -472,10 +494,14 @@ struct CollectiveMainloopFwd { SharedStorage& shared_storage, FrgTensorO& tSrS, const int tidx) { + using sMemBLayout = std::conditional_t< + CUR_N == kBlockN1, + SmemLayoutB1, + std::conditional_t>; Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); Tensor sB = - make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{}); float2* weight_scale = reinterpret_cast(shared_storage.smem_scale.data()) + tidx / 4; @@ -501,7 +527,7 @@ struct CollectiveMainloopFwd { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - + constexpr int B_STEPS = kBlockN1 / CUR_N; __half2 scale1, scale2, scale3, scale4; float2 scale_cur_k; #pragma unroll @@ -516,7 +542,7 @@ struct CollectiveMainloopFwd { gemm(tiled_mma, tSrA, tSsA1, - tSrB(_, _, _, smem_pipe_read.index()), + tSrB(_, _, _, smem_pipe_read.index() * B_STEPS), tSrS1, smem_tiled_copy_A, smem_thr_copy_A); @@ -545,7 +571,7 @@ struct CollectiveMainloopFwd { gemm(tiled_mma, tSrA, tSsA2, - tSrB(_, _, _, smem_pipe_read.index()), + tSrB(_, _, _, smem_pipe_read.index() * B_STEPS), tSrS2, smem_tiled_copy_A, smem_thr_copy_A); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 36f1ab0c9bc..3b7d0fbb54a 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -35,12 +35,16 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, using Element = typename Ktraits::Element; static_assert(cutlass::sizeof_bits_v == 8); - using TileShape_MNK = typename Ktraits::TileShape_MNK; + using TileShape_MNK1 = typename Ktraits::TileShape_MNK1; + using TileShape_MNK2 = typename Ktraits::TileShape_MNK2; + using TileShape_MNK3 = typename Ktraits::TileShape_MNK3; using ClusterShape = typename Ktraits::ClusterShape_MNK; - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma1{}); static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kBlockN1 = Ktraits::kBlockN1; + static constexpr int kBlockN2 = Ktraits::kBlockN2; + static constexpr int kBlockN3 = Ktraits::kBlockN3; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int M = Ktraits::M; static constexpr int K = Ktraits::K; @@ -109,7 +113,9 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, ? 
mainloop_params.tokens[bidb] - pre_fix_tokens : mainloop_params.tokens[bidb]; - if (bidn * kBlockN >= tokens) { + const int block_compute_tokens = tokens - bidn * kBlockN1; + + if (block_compute_tokens <= 0) { return; } @@ -139,21 +145,23 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, cutlass::arch::warpgroup_reg_alloc(); PipelineState smem_pipe_read; - typename Ktraits::TiledMma tiled_mma; + typename Ktraits::TiledMma1 tiled_mma1; + typename Ktraits::TiledMma2 tiled_mma2; + typename Ktraits::TiledMma3 tiled_mma3; const int mma_tidx = tidx - NumCopyThreads; if (is_need_input_scale) { if constexpr (TokenPackSize == 0) { - const int input_scale_idx = pre_fix_tokens + bidn * kBlockN; + const int input_scale_idx = pre_fix_tokens + bidn * kBlockN1; if (mma_tidx < tokens) { reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; } } else { - const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN; - if (mma_tidx < kBlockN / 4) { + const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN1; + if (mma_tidx < kBlockN1 / 4) { reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; @@ -168,62 +176,130 @@ void __global__ __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; } - Tensor tSrS = - partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - if constexpr (WeightScaleGroup == K) { - collective_mainloop.mma(mainloop_params, - tiled_mma, - pipeline, - smem_pipe_read, - shared_storage, - tSrS, - mma_tidx); + if (block_compute_tokens > kBlockN2) { + Tensor tSrS = + partition_fragment_C(tiled_mma1, select<0, 1>(TileShape_MNK1{})); + if constexpr (WeightScaleGroup == K) { + collective_mainloop.mma(mainloop_params, + tiled_mma1, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } else { + collective_mainloop.mma_pipeline(mainloop_params, + tiled_mma1, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } + collective_mainloop.store( + mainloop_params, + tSrS, + shared_storage, + tiled_mma1, + reinterpret_cast(&weight_scale), + input_scale, + tokens, + pre_fix_tokens, + bidm, + bidn, + bidb, + mma_tidx); + } else if (block_compute_tokens > kBlockN3) { + Tensor tSrS = + partition_fragment_C(tiled_mma2, select<0, 1>(TileShape_MNK2{})); + + if constexpr (WeightScaleGroup == K) { + collective_mainloop.mma(mainloop_params, + tiled_mma2, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } else { + collective_mainloop.mma_pipeline(mainloop_params, + tiled_mma2, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } + collective_mainloop.store( + mainloop_params, + tSrS, + shared_storage, + tiled_mma2, + reinterpret_cast(&weight_scale), + input_scale, + tokens, + pre_fix_tokens, + bidm, + bidn, + bidb, + mma_tidx); } else { - collective_mainloop.mma_pipeline(mainloop_params, - tiled_mma, - pipeline, - smem_pipe_read, - shared_storage, - tSrS, - mma_tidx); + Tensor tSrS = + partition_fragment_C(tiled_mma3, select<0, 1>(TileShape_MNK3{})); + + if constexpr (WeightScaleGroup == K) { + collective_mainloop.mma(mainloop_params, + tiled_mma3, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } else { + collective_mainloop.mma_pipeline(mainloop_params, + tiled_mma3, + pipeline, + smem_pipe_read, + shared_storage, + tSrS, + mma_tidx); + } + 
collective_mainloop.store( + mainloop_params, + tSrS, + shared_storage, + tiled_mma3, + reinterpret_cast(&weight_scale), + input_scale, + tokens, + pre_fix_tokens, + bidm, + bidn, + bidb, + mma_tidx); } - - collective_mainloop.store(mainloop_params, - tSrS, - shared_storage, - tiled_mma, - reinterpret_cast(&weight_scale), - input_scale, - tokens, - pre_fix_tokens, - bidm, - bidn, - bidb, - mma_tidx); } } template auto get_gmem_layout(const int Rows, const int Cols) { - return make_layout( - make_shape(static_cast(Rows), - static_cast(Cols), - static_cast(Experts)), - make_stride(static_cast(Cols), - cute::_1{}, - static_cast(Rows) * static_cast(Cols))); + return make_layout(make_shape(static_cast(Rows), + static_cast(Cols), + static_cast(Experts)), + make_stride(static_cast(Cols), + cute::_1{}, + static_cast(Rows * Cols))); } template auto get_scale_layout(const int Rows, const int Cols) { - return make_layout( - make_shape(static_cast(Cols), - static_cast(Rows), - static_cast(Experts)), - make_stride(cute::_1{}, - static_cast(Cols), - static_cast(Rows) * static_cast(Cols))); + return make_layout(make_shape(static_cast(Cols), + static_cast(Rows), + static_cast(Experts)), + make_stride(cute::_1{}, + static_cast(Cols), + static_cast(Rows * Cols))); } template ; int smem_size = sizeof(typename Kernel_traits::SharedStorage) + - Kernel_traits::kBlockN * sizeof(float); + Kernel_traits::kBlockN1 * sizeof(float); if (smem_size >= 48 * 1024) { cudaFuncSetAttribute( diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index 194da2bdde6..bb47a65f9ce 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -14,8 +14,7 @@ import os import re -script_dir = os.path.dirname(os.path.abspath(__file__)) -file_dir = os.path.join(script_dir, "..", "gpu_ops", "w4afp8_gemm") + os.sep +file_dir = "./gpu_ops/w4afp8_gemm/" gemm_template_head = """ #pragma once @@ -76,7 +75,7 @@ constexpr int kTiles = K / kBlockK; using Kernel_traits = Kernel_traits< - kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles, + kBlockM, kBlockN, 128, 64, kBlockK, kNWarps, kStages, kTiles, M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t, {cutlass_type}>; run_gemm
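
The core of the change is that Kernel_traits now carries three N-tile variants (TileShape_MNK1/2/3, TiledMma1/2/3 and SmemLayoutB1/2/3 built from kBlockN1, kBlockN2, kBlockN3, which the generator appears to wire up as kBlockN, 128 and 64), and each thread block picks one variant from how many tokens it still has to cover. As a reading aid, here is a minimal plain-C++ sketch of that dispatch; TileConfig, pick_n_tile and the kBlockN1 = 256 value are illustrative assumptions, not names or numbers taken from the patch.

// Illustrative sketch, not part of the patch: the per-block N-tile dispatch that
// w4afp8_gemm_kernel.hpp performs when choosing TiledMma1/TiledMma2/TiledMma3.
// TileConfig, pick_n_tile and the kBlockN1 value are hypothetical stand-ins.
#include <cstdio>

constexpr int kBlockN1 = 256;  // assumed here; the generator passes kBlockN for this slot
constexpr int kBlockN2 = 128;  // the 128 passed by auto_gen_w4afp8_gemm_kernel.py
constexpr int kBlockN3 = 64;   // the 64 passed by auto_gen_w4afp8_gemm_kernel.py

struct TileConfig {
  int cur_n;    // CUR_N: N extent of the MMA tile and of the smem B view
  int b_steps;  // kBlockN1 / CUR_N: stage-index multiplier used by the mainloop
};

// Mirrors the branch on block_compute_tokens in the kernel: full-width tile while
// more than kBlockN2 tokens remain for this block, half-width while more than
// kBlockN3 remain, and the narrowest tile for the final ragged tail.
bool pick_n_tile(int tokens, int bidn, TileConfig* out) {
  const int block_compute_tokens = tokens - bidn * kBlockN1;
  if (block_compute_tokens <= 0) return false;  // nothing to compute; kernel returns early
  if (block_compute_tokens > kBlockN2) {
    *out = {kBlockN1, kBlockN1 / kBlockN1};
  } else if (block_compute_tokens > kBlockN3) {
    *out = {kBlockN2, kBlockN1 / kBlockN2};
  } else {
    *out = {kBlockN3, kBlockN1 / kBlockN3};
  }
  return true;
}

int main() {
  const int sample_tokens[] = {300, 200, 130, 70, 40};
  TileConfig cfg;
  for (int tokens : sample_tokens) {
    if (pick_n_tile(tokens, /*bidn=*/0, &cfg)) {
      std::printf("tokens=%d -> CUR_N=%d, B_STEPS=%d\n", tokens, cfg.cur_n, cfg.b_steps);
    }
  }
  return 0;
}

With the wide tile, a fully populated block behaves as before; the narrower variants appear to exist so that tail blocks issue less MMA work and less epilogue copying (store() is also templated on CUR_N and sizes its copy loops from it) instead of computing columns past the last valid token.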
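
In mainloop_fwd.h, mma() and mma_pipeline() receive the same CUR_N as a template parameter, pick the matching shared-memory B layout with a nested std::conditional_t, and scale the pipeline stage index by B_STEPS = kBlockN1 / CUR_N when slicing the partitioned B fragment. Below is a small compile-time model of that selection under the same assumed tile sizes; LayoutB1/2/3 and b_stage_index are toy stand-ins for the CUTLASS layouts, and the reading of B_STEPS as a stage remapping into the kBlockN1-sized buffer is an interpretation of the diff, not something it states.

// Compile-time selection of a B layout from CUR_N, modelled after the sMemBLayout
// alias added to mma()/mma_pipeline(). LayoutB1/2/3 are toy stand-ins.
#include <type_traits>

constexpr int kBlockN1 = 256, kBlockN2 = 128, kBlockN3 = 64;  // assumed values

struct LayoutB1 { static constexpr int n_per_stage = kBlockN1; };
struct LayoutB2 { static constexpr int n_per_stage = kBlockN2; };
struct LayoutB3 { static constexpr int n_per_stage = kBlockN3; };

template <int CUR_N>
using SmemLayoutB = std::conditional_t<
    CUR_N == kBlockN1, LayoutB1,
    std::conditional_t<CUR_N == kBlockN2, LayoutB2, LayoutB3>>;

// The shared-memory B buffer is still declared and TMA-filled with SmemLayoutB1,
// i.e. kBlockN1 columns per pipeline stage, so a narrower view packs
// kBlockN1 / CUR_N of its own stages into each physical stage; scaling the
// pipeline index keeps the consumer's reads on the data the producer wrote
// (the B_STEPS multiplier in the patch).
template <int CUR_N>
constexpr int b_stage_index(int pipe_stage) {
  return pipe_stage * (kBlockN1 / CUR_N);
}

static_assert(std::is_same_v<SmemLayoutB<128>, LayoutB2>, "CUR_N=128 selects LayoutB2");
static_assert(std::is_same_v<SmemLayoutB<64>, LayoutB3>, "CUR_N=64 selects LayoutB3");
static_assert(b_stage_index<64>(1) == 4, "pipeline stage 1 maps to slot 4 of the narrow view");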