From d29b20cd1664e769811f67d013b00ecf6f99519e Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 19 Jun 2025 14:51:38 +0000 Subject: [PATCH 1/5] support group_gemm_offset, group_gemm_offset_swapAB --- deep_gemm/__init__.py | 1 + deep_gemm/include/deep_gemm/fp8_gemm.cuh | 868 +++++++++++++++++++++- deep_gemm/include/deep_gemm/mma_utils.cuh | 24 + deep_gemm/include/deep_gemm/scheduler.cuh | 275 ++++++- deep_gemm/jit_kernels/__init__.py | 3 +- deep_gemm/jit_kernels/gemm.py | 135 ++-- deep_gemm/jit_kernels/m_grouped_gemm.py | 168 ++++- deep_gemm/jit_kernels/runtime.py | 155 +++- deep_gemm/jit_kernels/utils.py | 3 + tests/test_core.py | 275 ++----- 10 files changed, 1652 insertions(+), 255 deletions(-) diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 8e6b2996..f3af4c55 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -5,6 +5,7 @@ gemm_fp8_fp8_bf16_nt, m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, m_grouped_gemm_fp8_fp8_bf16_nt_masked, + m_grouped_gemm_fp8_fp8_bf16_nt_offset, wgrad_gemm_fp8_fp8_fp32_nt, k_grouped_wgrad_gemm_fp8_fp8_fp32_nt, ceil_div, diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 5c11cd3d..8419f7d8 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -438,7 +438,873 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); #endif } +template +static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block, + __nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset, + uint32_t const shape_n, uint32_t const ld_output) +{ + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + constexpr int int4_per_tile_line = BLOCK_N * sizeof(__nv_bfloat16) / sizeof(int4); + int int4_per_global_line = shape_n * sizeof(__nv_bfloat16) / sizeof(int4); + constexpr auto num_lines = BLOCK_M; + constexpr auto num_warps = NUM_WARPS_PER_BLOCK; + int4 const* smem_d_int4 = reinterpret_cast(smem_d); + bool is_last_n_block = n_offset + BLOCK_N > shape_n; + int int4_per_line = is_last_n_block ? int4_per_global_line % int4_per_tile_line : int4_per_tile_line; + for (int line_idx = warp_idx; line_idx < num_lines; line_idx += num_warps) + { + if (m_offset + line_idx >= m_boundary) + { + break; + } + for (int elem_idx = lane_idx; elem_idx < int4_per_line; elem_idx += 32) + { + uint64_t idx = (uint64_t) line_idx * ld_output + n_offset; + int4* g_data_addr = reinterpret_cast(&gmem_d_this_block[idx]) + elem_idx; + int4 const* s_data_addr = &smem_d_int4[line_idx * (int4_per_tile_line) + elem_idx]; + *g_data_addr = *s_data_addr; + } + __syncwarp(); + } +} + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) + fp8_gemm_offset_kernel(__nv_bfloat16* gmem_d, float* scales_b, int64_t* offsets, + __grid_constant__ const CUtensorMap tensor_map_a, __grid_constant__ const CUtensorMap tensor_map_b, + __grid_constant__ const CUtensorMap tensor_map_scales_a, __grid_constant__ const CUtensorMap tensor_map_d) +{ +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ == 900)) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + + InputType problem_input; + problem_input.problem_m_offsets = offsets; + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_B_SIZE + = ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) + * sizeof(Barrier); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + uint32_t const warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + uint32_t const lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) + { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE); + } + smem_scales_b = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) + { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK + { + }; + + struct NotDivisibleK + { + }; + + auto launch_k_iterations = [](auto const& func) + { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) + { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) + func(k_iter, DivisibleK{}); + } + else + { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) + { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) + { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) + { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast); + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) + { + tma_copy(&tensor_map_scales_a, + reinterpret_cast(&full_barrier), smem_scales_a[s], + scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); + } + else + { + tma_copy(&tensor_map_scales_a, + reinterpret_cast(&full_barrier), smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_scales_a_idx(k_idx / BLOCK_K), kNumTMAMulticast); + } + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[s], k_idx, + scheduler.get_global_n_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx), 1); + full_barrier.arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) + { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } + else + { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + auto const r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) + { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) + { + auto num_previous_lines + = scheduler.get_global_scales_b_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + ; + auto local_scales_b + = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) + { + if constexpr (kNumTMAMulticast == 1) + { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } + else + { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) + { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1 = 1.0f; + // NOTES: even some blocks do not need to read the second row, but we still load one to align + // with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled + // block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), + scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) + { + auto desc_a + = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) + { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) + { + SM90_U32x2_STSM_N::copy(__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn( + {final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16); + } + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) + { + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) + { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, + scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N); + } + } + else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched) + { + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + __nv_bfloat16* gmem_d_this_block; + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d + + (m_block_idx * BLOCK_M) * problem_input.ld_d; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, + scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, problem_input.ld_d); + } + else + { + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_d, smem_d, n_block_idx * BLOCK_N, scheduler.get_global_m_idx(m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +__global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M), 1) + fp8_gemm_offset_kernel_swapAB(__nv_bfloat16* gmem_d, float* scales_a, int64_t* offsets, + const __grid_constant__ CUtensorMap tensor_map_a, // weight (previously act) + const __grid_constant__ CUtensorMap tensor_map_b, // act (previously weight) + const __grid_constant__ CUtensorMap tensor_map_scales_b, // act scales (previously tensor_map_scales_a) + const __grid_constant__ CUtensorMap tensor_map_d) +{ +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_M, BLOCK_K) == 1, "Too much A scales in a single block"); + + InputType problem_input; + problem_input.problem_n_offsets = offsets; + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + DG_STATIC_ASSERT(BLOCK_K % BLOCK_M == 0, "BLOCK_M should be 64 or 128 and BLOCK_K should be 128"); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_N * BLOCK_M * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float); // B matrix (act) scales + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE_PADDED + = ceil_div(BLOCK_N * sizeof(float), 128) * 128; // B matrix (act) scales, 128B aligned + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_A_SIZE = ceil_div(SHAPE_K_SCALES * sizeof(float), sizeof(Barrier)) + * sizeof(Barrier); // renamed to A (weight) + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) + { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_scales_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; // weight + __nv_fp8_e4m3* smem_b[kNumStages]; // act + float* smem_scales_b[kNumStages]; // act scales + float* smem_scales_a; // weight scales + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_b[i] = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_B_SIZE_PER_STAGE_PADDED); + } + smem_scales_a = reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE_PADDED)); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_scales_a) + SMEM_SCALES_A_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) + { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) + { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK + { + }; + + struct NotDivisibleK + { + }; + + auto launch_k_iterations = [](auto const& func) + { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) + { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) + func(k_iter, DivisibleK{}); + } + else + { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) + func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) + { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) + { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) + { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A (weight) now without broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), smem_a[s], k_idx, + scheduler.get_global_m_idx(SHAPE_M, BLOCK_M, m_block_idx, n_block_idx), 1); + + // Issue TMA B (act) with broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx), kNumTMAMulticast); + + // Issue TMA scales_b (act scales) for B matrix + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) + { + tma_copy(&tensor_map_scales_b, + reinterpret_cast(&full_barrier), smem_scales_b[s], + scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); + } + else + { + tma_copy(&tensor_map_scales_b, + reinterpret_cast(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N, + scheduler.get_global_scales_b_idx(k_idx / BLOCK_K), kNumTMAMulticast); + } + + full_barrier.arrive_and_expect_tx( + SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) + { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } + else + { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + + // Each thread loads consecutive 2 scales + const uint32_t scale_offset = (lane_idx % 4) * 2; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) + { + // Load weight scales (scales_a) - these are associated with tensor_map_a (weight) + // Decide the number of scales A to load + DG_STATIC_ASSERT(SHAPE_M % 8 == 0, "Invalid shape M"); + uint32_t num_scales_a = SHAPE_K_SCALES; + + // Load A scales with math warp-groups (weight scales) + if (threadIdx.x >= 32) + { + auto num_previous_lines + = scheduler.get_global_scales_a_idx(ceil_div(SHAPE_M, BLOCK_K), 0, 0, n_block_idx); + auto local_scales_a + = scales_a + (num_previous_lines + ((m_block_idx * BLOCK_M) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_a; i += kNumMathThreads - 32) + st_shared(smem_scales_a + i, __ldg(local_scales_a + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) + { + if constexpr (kNumTMAMulticast == 1) + { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } + else + { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations( + [&](int k_iter, auto type) + { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages + = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) + { + // Read weight scales (A scales) + float scale_a_0 = ld_shared(smem_scales_a + k_iter * kNumStages + s); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled + // block polluting the results + // Each thread reads consecutive two b scales, each thread needs to read WGMMA::N / 4 * 2 b + // scales + float scale_0_0[WGMMA::kNumAccum / 4], scale_0_1[WGMMA::kNumAccum / 4]; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + float2 scale_b + = ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + scale_offset)); + scale_0_0[i] = scale_a_0 * scale_b.x; + scale_0_1[i] = scale_a_0 * scale_b.y; + } + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) + { + auto desc_a + = make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + +// Promote with scales +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + { + final_accum[i * 4 + 0] += scale_0_0[i] * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1[i] * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_0_0[i] * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_0_1[i] * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) + { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + int tid = 0; + if (lane_idx < 8) + { + tid = lane_idx * BLOCK_M; + } + else if (lane_idx < 16) + { + tid = (lane_idx - 8) * BLOCK_M + 8; + } + else if (lane_idx < 24) + { + tid = (lane_idx - 8) * BLOCK_M; + } + else + { + tid = (lane_idx - 16) * BLOCK_M + 8; + } +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) + { + SM90_U32x4_STSM_T::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + warp_idx * 16 + i * 16 * BLOCK_M + tid); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) + { + SM90_U32x2_STSM_T::copy(__float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn( + {final_accum[WGMMA::kNumAccum / 8 * 8 + 2], final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + warp_idx * 16 + WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid); + } + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) + { + auto n_global_idx = scheduler.get_global_n_idx(n_block_idx); + bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) + { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, n_global_idx, + scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M); + } + } + else + { + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) + { + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} }; // namespace deep_gemm -#pragma clang diagnostic pop \ No newline at end of file +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index 85b2ccc0..4fc7f4fa 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -32,6 +32,30 @@ struct SM90_U32x4_STSM_N { } }; +template +struct SM90_U32x2_STSM_T +{ + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst) + { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]), + "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T +{ + __device__ __forceinline__ static void copy( + dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) + { + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + __forceinline__ __device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 69ea2160..622be39b 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -7,7 +7,8 @@ namespace deep_gemm { enum class GemmType { Normal, GroupedContiguous, - GroupedMasked + GroupedMasked, + GroupedWithOffset }; #pragma clang diagnostic push @@ -158,6 +159,278 @@ struct Scheduler { } }; + +template +__device__ __forceinline__ void offset_get_swizzled_block_idx( + const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) +{ + DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; + auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; +} + + + +struct GroupedWithOffsetSchedulerInput +{ + uint32_t shape_m; + int64_t* problem_m_offsets; +}; + +struct GroupedWithOffsetSchedulerInputSwapAB +{ + uint32_t shape_m; + int64_t* problem_n_offsets; +}; + +struct StridedBatchedSchedulerInput +{ + uint32_t shape_m; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + +struct StridedBatchedSchedulerInputSwapAB +{ + uint32_t shape_n; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + + +// Need to keep the same as the one in tests/unittest/_torch/thop/deep_gemm_tests.py +template +__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx) +{ + // This formulation ensures that padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i]. + constexpr T_offset alignment = 32; + return (offset + problem_idx * (alignment - 1)) / alignment * alignment; +} + +template +struct GroupedWithOffsetScheduler +{ + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t m_offset; + int64_t m_padded_4_offset; + int64_t m_boundary; + int64_t* problem_m_offsets; + + using Input = GroupedWithOffsetSchedulerInput; + Input input; + + GroupedWithOffsetScheduler() {} + + __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input) + { + this->problem_m_offsets = input.problem_m_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) + { + return m_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx( + uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) + { + return m_padded_4_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx( + uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) + { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_m_blocks; + while (true) + { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + m_offset = __ldg(problem_m_offsets + curr_group_idx); + m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1); + m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx); + auto m = m_boundary - m_offset; + // Within current group + num_m_blocks = ceil_div(m, static_cast(BLOCK_M)); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) + break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_m_block_cumsum; + } + + offset_get_swizzled_block_idx( + num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct GroupedWithOffsetSchedulerSwapAB +{ + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t n_offset; + int64_t n_padded_4_offset; + int64_t n_boundary; + int64_t* problem_n_offsets; + + using Input = GroupedWithOffsetSchedulerInputSwapAB; + Input input; + + GroupedWithOffsetSchedulerSwapAB() {} + + __device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input) + { + this->problem_n_offsets = input.problem_n_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + // weight + __device__ __forceinline__ uint32_t get_global_m_idx( + const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + // act + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx) + { + return n_offset + block_idx * BLOCK_N; + } + + // act scales + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx) + { + return n_padded_4_offset + block_idx * BLOCK_N; + } + + // weight scales + __device__ __forceinline__ uint32_t get_global_scales_a_idx( + const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0) + { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) + { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_n_blocks; + while (true) + { + // End of the task + if (curr_group_idx == kNumGroups) + return false; + n_offset = __ldg(problem_n_offsets + curr_group_idx); + n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1); + n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx); + auto n = n_boundary - n_offset; + // Within current group + num_n_blocks = ceil_div(n, static_cast(BLOCK_N)); + auto current_n_block_cumsum = curr_cumsum + num_n_blocks; + if (next_block_idx < current_n_block_cumsum * kNumMBlocks) + break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_n_block_cumsum; + } + + offset_get_swizzled_block_idx( + num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx); + return true; + } +}; + +template +struct SchedulerSelector +{ + static constexpr auto select_type() + { + if constexpr (GT == GemmType::Normal) + return NormalScheduler(); + if constexpr (GT == GemmType::GroupedContiguous) + return GroupedContiguousScheduler(); + if constexpr (GT == GemmType::GroupedMasked) + return GroupedMaskedScheduler(); + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetScheduler(); + if constexpr (GT == GemmType::StridedBatched) + return StridedBatchedScheduler(); + } + + using type = decltype(select_type()); +}; + +template +struct SchedulerSelectorSwapAB +{ + static constexpr auto select_type() + { + static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal, + "Only GroupedWithOffset and Normal are supported for SwapAB"); + if constexpr (GT == GemmType::Normal) + return NormalSchedulerSwapAB(); + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetSchedulerSwapAB(); + } + + using type = decltype(select_type()); +}; + #pragma clang diagnostic pop } // namespace deep_gemm diff --git a/deep_gemm/jit_kernels/__init__.py b/deep_gemm/jit_kernels/__init__.py index f1fa7bb2..839a3a19 100644 --- a/deep_gemm/jit_kernels/__init__.py +++ b/deep_gemm/jit_kernels/__init__.py @@ -1,7 +1,8 @@ from .gemm import gemm_fp8_fp8_bf16_nt from .m_grouped_gemm import ( m_grouped_gemm_fp8_fp8_bf16_nt_contiguous, - m_grouped_gemm_fp8_fp8_bf16_nt_masked + m_grouped_gemm_fp8_fp8_bf16_nt_masked, + m_grouped_gemm_fp8_fp8_bf16_nt_offset ) from .wgrad_gemm import ( wgrad_gemm_fp8_fp8_fp32_nt, diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 574f821f..2a2cc316 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -34,42 +34,71 @@ def get_block_n_padding_for_smem_d(block_n: int) -> int: def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128, - is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]: + is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> Tuple[int, int, int]: assert block_k == 128 - # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d( - block_n) if swizzle_mode == 0 else 0 - - # NOTES: `scales_b` in a total manner or per-stage manner - smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) - smem_a_per_stage = block_m * block_k - smem_scales_a_per_stage = block_m * 4 - smem_b_per_stage = block_n * block_k - smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 - smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 - smem_barrier = num_stages * 8 * 2 - - smem_size = 0 - smem_size += smem_d - smem_size += num_stages * smem_a_per_stage - smem_size += num_stages * smem_scales_a_per_stage - smem_size += num_stages * smem_b_per_stage - smem_size += num_stages * smem_scales_b_per_stage - smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 - smem_size += smem_barrier - - # Swizzle and padding are not compatible - assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 - - return smem_size, swizzle_mode, block_n_padding + if not is_swap_ab: + # Try swizzle first, as it does not waste shared memory + swizzle_mode = get_swizzle_mode(block_n) + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 + + # NOTES: `scales_b` in a total manner or per-stage manner + smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = block_m * 4 + smem_b_per_stage = block_n * block_k + smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 + smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_a_per_stage + smem_size += num_stages * smem_b_per_stage + smem_size += num_stages * smem_scales_b_per_stage + smem_size += ceil_div(smem_scales_b * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += smem_barrier + + # Swizzle and padding are not compatible + assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 + + return smem_size, swizzle_mode, block_n_padding + else: + # Try swizzle first, as it does not waste shared memory + swizzle_mode = get_swizzle_mode(block_n) + block_n_padding = get_block_n_padding_for_smem_d( + block_n) if swizzle_mode == 0 else 0 + + # NOTES: `scales_b` in a total manner or per-stage manner + smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) + smem_a_per_stage = block_m * block_k + smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales + smem_b_per_stage = block_n * block_k + smem_scales_b_per_stage = 0 # swap_ab not support wgrad + smem_scales_b = ceil_div(block_n * 4, 128) * 128 # swap_ab not support wgrad + smem_barrier = num_stages * 8 * 2 + + smem_size = 0 + smem_size += smem_d + smem_size += num_stages * smem_a_per_stage + smem_size += num_stages * smem_scales_b + smem_size += num_stages * smem_b_per_stage + smem_size += num_stages * smem_scales_b_per_stage + smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += smem_barrier + + # Swizzle and padding are not compatible + assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 + + return smem_size, swizzle_mode, block_n_padding @lru_cache(maxsize=None) def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, is_grouped_contiguous: bool = False, is_grouped_masked: bool = False, - is_fp32_out: bool = False, is_wgrad: bool = False) -> \ + is_fp32_out: bool = False, is_wgrad: bool = False, is_swap_ab: bool = False) -> \ Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]: if not is_grouped_contiguous: block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) @@ -119,7 +148,7 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Unrolling both stages and `num_former_iters` will cause large code size stage_candidates = tuple(filter(lambda s: s <= max(k // 128, 1), (4, 3, 2, 1))) for num_stages in stage_candidates: - best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad) + best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad, is_swap_ab = is_swap_ab) if best_smem_config[0] <= sm90_capacity: best_num_stages = num_stages break @@ -131,21 +160,39 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, # Try to multicast on the larger block side first # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even - is_multicast_legal = { - 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), - 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, - } - for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): - if m >= 512 and is_multicast_legal[i]: - best_tma_multicast_config = (2, i == 'A') - break - # Recompute the minimal number of SMs required - # NOTES: less L2 cache usage and less GPU frequency drop - num_waves = get_num_waves(best_block_m, best_block_n) - num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) - num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] - assert num_min_sms <= num_sms + if not is_swap_ab: + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked), + 'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked, + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if m >= 512 and is_multicast_legal[i]: + best_tma_multicast_config = (2, i == 'A') + break + + # Recompute the minimal number of SMs required + # NOTES: less L2 cache usage and less GPU frequency drop + num_waves = get_num_waves(best_block_m, best_block_n) + num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms + else: + is_multicast_legal = { + 'A': is_tma_multicast_legal(n, best_block_m, 2, num_sms), + 'B': is_tma_multicast_legal(m, best_block_n, 2, num_sms), + } + for i in ('A', 'B') if best_block_m > best_block_n else ('B', 'A'): + if n >= 512 and is_multicast_legal[i]: + best_tma_multicast_config = (2, i == 'B') + break + + # Recompute the minimal number of SMs required + # NOTES: less L2 cache usage and less GPU frequency drop + num_waves = get_num_waves(best_block_n, best_block_m) + num_min_sms = ceil_div(ceil_div(n, best_block_m) * ceil_div(m, best_block_n) * num_groups, num_waves) + num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0] + assert num_min_sms <= num_sms return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index ca2fc79a..92384688 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -4,10 +4,12 @@ from ..jit import build from .gemm import get_best_configs from .runtime import ( - FP8GemmRuntime, GemmType, + FP8GemmRuntime, FP8GemmOffsetRuntime, GemmType, make_2d_tma_a_desc, make_2d_tma_b_desc, - make_2d_tma_d_desc, make_2d_tma_scales_desc) -from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms + make_2d_tma_d_desc, make_2d_tma_scales_desc, + make_2d_tma_scales_a_offset_desc, + make_2d_tma_a_offset_desc_swapAB, make_2d_tma_b_offset_desc_swapAB, make_2d_tma_d_offset_desc_swapAB, make_2d_tma_scales_b_offset_desc_swapAB) +from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms, compute_padded_offset def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor], @@ -203,3 +205,163 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor] code = FP8GemmRuntime.generate(kwargs) runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs) runtime(**kwargs) + + +def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor], + rhs: Tuple[torch.Tensor, torch.Tensor], + offsets: torch.Tensor, + out: torch.Tensor, expected_m: int) -> None: + """ + GroupedWithOffset from TensorRT-LLM + """ + + lhs, lhs_scales = lhs + rhs, rhs_scales = rhs + m, k = lhs.shape + num_groups, n, k_ = rhs.shape + m_, n_ = out.shape + + + print("expected_m: ",expected_m) + print("A shape: ",lhs.shape) + print("A scale shape: ",lhs_scales.shape) + print("B shape: ",rhs.shape) + print("B scale shape: ",rhs_scales.shape) + print("out shape: ",out.shape) + + + # Type and shape checks + assert m == m_ and n == n_ and k == k_ + + max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4 + max_shape_m_32_align_padded = compute_padded_offset(m, num_groups) + + assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0 + + + # if compute_padded_offset ? + #assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128)) + assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) + assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 + assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert out.dtype == torch.bfloat16 + assert lhs.is_contiguous() and rhs.is_contiguous() + assert out.is_contiguous() + + # LHS scales must be transposed for TMA load, but not for RHS scales + lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + assert rhs_scales.is_contiguous() + + # Auto-tuning with compilation + num_sms = get_num_sms() + + if num_sms==78: + m_per_expert_threshold = 64 # H20 + else: + m_per_expert_threshold = 32 # H100 + + if expected_m>= m_per_expert_threshold: + + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False) + + # Extra checks for TMA store + if num_groups > 1 and m > block_m: + assert m % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' + + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0) # none swizzle + tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle + + + kwargs = { + # Templated arguments + 'KERNEL_NAME': 'fp8_gemm_offset_kernel', + 'SCHEDULER_TYPE': 'SchedulerSelector', + 'INPUT_TYPE': 'GroupedWithOffsetSchedulerInput', + 'PROBLEM_OFFSETS': offsets, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, 'N': n, 'K': k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedWithOffset, + # Runtime arguments + 'SCALES': rhs_scales, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES': tensor_map_scales_a, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index, + 'OUT': out + } + + else: + num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( + n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True) + + # Extra checks for TMA store + if num_groups > 1 and n > block_m: + assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' + + print("is_swap_ab=True =========") + print("num_sms: ",num_sms) + print("block_m: ",block_m) + print("block_n: ",block_n) + print("num_stages: ",num_stages) + print("tma_multicast_config: ",tma_multicast_config) + print("smem_config: ",smem_config) + + block_k = 128 + num_tma_threads = 128 + num_math_threads_per_group = 128 + + tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups) + tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0) # no swizzle + tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle + + kwargs = { + # Templated arguments + 'KERNEL_NAME': 'fp8_gemm_offset_kernel_swapAB', + 'SCHEDULER_TYPE': 'SchedulerSelectorSwapAB', + 'INPUT_TYPE': 'GroupedWithOffsetSchedulerInputSwapAB', + 'PROBLEM_OFFSETS': offsets, + 'NUM_TMA_THREADS': num_tma_threads, + 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, + 'M': m, 'N': n, 'K': k, + 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, + 'NUM_GROUPS': num_groups, + 'NUM_STAGES': num_stages, + 'NUM_TMA_MULTICAST': tma_multicast_config[0], + 'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1], + 'GEMM_TYPE': GemmType.GroupedWithOffset, + # Runtime arguments + 'SCALES': rhs_scales, + 'NUM_SMS': num_sms, + 'SMEM_SIZE': smem_config[0], + 'TENSOR_MAP_A': tensor_map_a, + 'TENSOR_MAP_B': tensor_map_b, + 'TENSOR_MAP_SCALES': tensor_map_scales_b, + 'TENSOR_MAP_D': tensor_map_d, + 'STREAM': torch.cuda.current_stream().cuda_stream, + 'DEVICE_INDEX': out.device.index, + 'OUT': out + } + + # Generate, build and run the kernel + code = FP8GemmOffsetRuntime.generate(kwargs) + runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt_offset', code, FP8GemmOffsetRuntime, kwargs) + runtime(**kwargs) + diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index e65e85aa..584b3220 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -5,7 +5,7 @@ import cuda.bindings.driver as cbd from typing import Any, Dict, Tuple -from .utils import get_tma_aligned_size +from .utils import get_tma_aligned_size, ceil_div from ..jit.runtime import Runtime @@ -13,12 +13,15 @@ class GemmType(enum.Enum): Normal = 0 GroupedContiguous = 1 GroupedMasked = 2 + GroupedWithOffset = 3 + def __str__(self) -> str: return { 0: 'Normal', 1: 'GroupedContiguous', 2: 'GroupedMasked', + 3: 'GroupedWithOffset', }[self.value] @@ -133,6 +136,58 @@ def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor, cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) +def make_2d_tma_scales_a_offset_desc(gemm_type: GemmType, t: torch.Tensor, + max_m_padded_total: int, shape_k: int, + block_m: int, block_k: int, + global_stride_in_bytes: int = 0) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + max_m_padded_total, ceil_div(shape_k, block_k), max_m_padded_total, + block_m, 1, + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + + +def make_2d_tma_a_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_k: int, m_stride: int, + block_m: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride, + block_k, block_m) + + +def make_2d_tma_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_n: int, shape_k: int, n_stride: int, + block_n: int, block_k: int, + num_groups: int) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + shape_k, shape_n * (num_groups if gemm_type == GemmType.GroupedMasked else 1), n_stride, + block_k, block_n) + + +def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + shape_m: int, shape_n: int, m_stride: int, + block_m: int, block_n: int, + num_groups: int, + swizzle_mode: int) -> cbd.CUtensorMap: + # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` + # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required + return make_2d_tma_desc(t, + shape_n, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride, + min(block_n, shape_n), min(block_m, shape_m), + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + +def make_2d_tma_scales_b_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, + max_n_padded_total: int, shape_k: int, + block_n: int, block_k: int, + global_stride_in_bytes: int = 0) -> cbd.CUtensorMap: + return make_2d_tma_desc(t, + max_n_padded_total, ceil_div(shape_k, block_k), max_n_padded_total, + block_n, 1, + cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) + + class FP8GemmRuntime(Runtime): def __init__(self, path: str) -> None: super().__init__(path) @@ -316,3 +371,101 @@ def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: None, ) return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) + + +class FP8GemmOffsetRuntime(Runtime): + def __init__(self, path: str) -> None: + super().__init__(path) + + @staticmethod + def generate(kwargs: Dict[str, Any]) -> str: + code = f''' +#ifdef __CUDACC_RTC__ +#include +#else +#include +#include +#endif + +#include +#include + +#include + +using namespace deep_gemm; + +using SchedulerType = +typename {kwargs['SCHEDULER_TYPE']} ::type; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&{kwargs['KERNEL_NAME']}< + {kwargs['N']}, + {kwargs['K']}, + {kwargs['BLOCK_M']}, + {kwargs['BLOCK_N']}, + {kwargs['BLOCK_K']}, + {kwargs['NUM_GROUPS']}, + {kwargs['NUM_STAGES']}, + {kwargs['NUM_TMA_THREADS']}, + {kwargs['NUM_MATH_THREADS_PER_GROUP']}, + {kwargs['NUM_TMA_MULTICAST']}, + SchedulerType, + {kwargs['INPUT_TYPE']} + >); +}}; +''' + if int(os.getenv('DG_JIT_DEBUG', 0)): + print(f'Generated FP8 GEMM code:\n{code}') + return code + + # noinspection PyMethodOverriding + @staticmethod + def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult: + num_tma_threads = 128 + num_math_threads_per_group = 128 + + result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0] + assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}' + + attr_val = cbd.CUlaunchAttributeValue() + attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST'] + attr_val.clusterDim.y = 1 + attr_val.clusterDim.z = 1 + attr = cbd.CUlaunchAttribute() + attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + attr.value = attr_val + + config = cbd.CUlaunchConfig() + config.numAttrs = 1 + config.attrs = [attr] + config.gridDimX = kwargs['NUM_SMS'] + config.gridDimY = 1 + config.gridDimZ = 1 + config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M']) + config.blockDimY = 1 + config.blockDimZ = 1 + config.sharedMemBytes = kwargs['SMEM_SIZE'] + config.hStream = kwargs['STREAM'] + + arg_values = ( + kwargs['OUT'].data_ptr(), + kwargs['SCALES'].data_ptr(), + kwargs['PROBLEM_OFFSETS'].data_ptr(), + kwargs['TENSOR_MAP_A'], + kwargs['TENSOR_MAP_B'], + kwargs['TENSOR_MAP_SCALES'], + kwargs['TENSOR_MAP_D'], + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + None, + None, + None, + None, + ) + return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0) diff --git a/deep_gemm/jit_kernels/utils.py b/deep_gemm/jit_kernels/utils.py index c6da56b0..11a42bdf 100644 --- a/deep_gemm/jit_kernels/utils.py +++ b/deep_gemm/jit_kernels/utils.py @@ -107,3 +107,6 @@ def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: aligned_x[:, :m, :] = x aligned_x = aligned_x[:, :m, :] return aligned_x.squeeze(0) if remove_dim else aligned_x + +def compute_padded_offset(offset, idx_problem, alignment=32): + return (offset + idx_problem * (alignment - 1)) // alignment * alignment diff --git a/tests/test_core.py b/tests/test_core.py index 3b88539c..e152a9c6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -34,49 +34,6 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) -def construct(m: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = x @ y.t() - - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ - Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - alignment = get_m_alignment_for_contiguous_layout() - group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] - m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) - - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - m_indices = torch.empty(m, device='cuda', dtype=torch.int32) - out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) - - start = 0 - for i, group_m in enumerate(group_ms): - actual_end = start + group_m - aligned_end = start + ceil_div(group_m, alignment) * alignment - m_indices[start:actual_end] = i - m_indices[actual_end:aligned_end] = -1 - ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() - start = aligned_end - ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) - - assert m % 4 == 0, f'TMA alignment error: {m}' - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) - for i in range(num_groups): - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - return m, x_fp8, y_fp8, m_indices, out, ref_out - def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: @@ -98,120 +55,10 @@ def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: # Construct mask masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): - masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + masked_m[j] = int(expected_m_per_group * random.uniform(1, 1)) assert masked_m.amax().item() <= max_m return x_fp8, y_fp8, masked_m, out, ref_out - -def construct_wgrad(m: int, k: int, n: int) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10 - out = residual.clone() - ref_out = residual + (x.float() @ y.float().t()) - - x_fp8 = per_token_cast_to_fp8(x) - y_fp8 = per_token_cast_to_fp8(y) - - # NOTES: please do inplace add on the `out` later - return x_fp8, y_fp8, residual, out, ref_out - - -def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ - Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]: - num_groups, total_k = len(k_sizes), sum(k_sizes) - - x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16) - y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16) - out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) - ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) - - # Fill tensors with data and compute reference output - x_offset, y_offset = 0, 0 - for idx, k in enumerate(k_sizes): - x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - - x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten()) - y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten()) - ref_out[idx] = x_chunk.float() @ y_chunk.float().t() - - x_offset += m * k - y_offset += n * k - - x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) - y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) - - total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) - x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) - y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) - - # Cast to FP8 and prepare scale factors - x_offset, y_offset, scale_offset = 0, 0, 0 - for k in k_sizes: - x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k)) - y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k)) - - x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) - y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) - - num_scales = ceil_div(k, 128) - x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) - y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) - - x_offset += m * k - y_offset += n * k - scale_offset += num_scales - - return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes - - -def test_gemm() -> None: - print('Testing GEMM:') - for m in (64, 128, 4096): - for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - -def test_m_grouped_gemm_contiguous() -> None: - print('Testing grouped contiguous GEMM:') - - for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), - (32, 256, 7168, 4096), (32, 256, 2048, 7168)): - # NOTES: we should mask the unfilled part before calculating difference - m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - valid_m = (m_indices != -1).sum().item() - print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') - print() - - def test_m_grouped_gemm_masked() -> None: print('Testing grouped masked GEMM:') @@ -239,62 +86,87 @@ def test_func(): print() -def test_wgrad_gemm(): - print('Testing weight gradient GEMM:') +def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + alignment = 32 + group_ms = [int(expected_m_per_group * random.uniform(1, 1)) for _ in range(num_groups)] + m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) - for k in (4096, 8192): - for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)): - # Test correctness - x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) - deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + offsets = torch.empty(num_groups+1, device='cuda', dtype=torch.int32) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) - # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) - x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) + start = 0 + offsets[0] = 0 + for i, group_m in enumerate(group_ms): + aligned_end = start + ceil_div(group_m, alignment) * alignment + offsets[i+1] = aligned_end + ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() + start = aligned_end - # noinspection PyShadowingNames - def test_func(): - deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True) - print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') - print() + return m, x_fp8, y_fp8, offsets, out, ref_out -def test_k_grouped_wgrad_gemm(): - print('Testing grouped weight gradient GEMM:') - for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)): - for m, n in ((7168, 4096), (2048, 7168)): - # Vary k sizes around base_k - k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)] - k_sizes.append(base_k * num_groups - sum(k_sizes)) - - # Test correctness - x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) - deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) +def test_m_grouped_gemm_offset() -> None: + print('Testing grouped contiguous GEMM:') - for idx in range(num_groups): - diff = calc_diff(out[idx], ref_out[idx]) - assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}' + for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),): + # NOTES: we should mask the unfilled part before calculating difference + + ''' + x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group) + - # Construct new tensors to avoid L2 cache acceleration - x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) - total_k = sum(k_sizes) - - def test_func(): - deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) - - t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups - print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, ' - f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s') + for j in range(num_groups): + diff = calc_diff(out_mask[j, :masked_m_mask[j].item()], ref_out_mask[j, :masked_m_mask[j].item()]) + #assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m_mask[j]}, {num_groups=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group) + + # Test performance with fixed shapes + # noinspection PyUnboundLocalVariable + valid_m = masked_m_mask.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_masked: Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + + ''' + + m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) + + #deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group) + #diff = calc_diff(out_offset, ref_out_offset) + # assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + valid_m = m_offset + print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') print() + + + + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -304,9 +176,4 @@ def test_func(): print('Library path:') print(f' > {deep_gemm.__path__}\n') - test_gemm() - test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() - - test_wgrad_gemm() - test_k_grouped_wgrad_gemm() + test_m_grouped_gemm_offset() From 26a603f51815cb6dd13bdba3ddd5d59c949843c2 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 20 Jun 2025 06:53:24 +0000 Subject: [PATCH 2/5] fix some bug --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 12 ------------ deep_gemm/include/deep_gemm/scheduler.cuh | 12 ------------ deep_gemm/jit_kernels/gemm.py | 4 ++-- deep_gemm/jit_kernels/m_grouped_gemm.py | 2 +- tests/test_core.py | 5 +++-- 5 files changed, 6 insertions(+), 29 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index 8419f7d8..fbe05c7c 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -851,18 +851,6 @@ __global__ void __launch_bounds__(get_num_threads_per_sm(BLOCK_M) - 128) / 32; - write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, - scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, problem_input.ld_d); - } else { cute::tma_store_fence(); diff --git a/deep_gemm/include/deep_gemm/scheduler.cuh b/deep_gemm/include/deep_gemm/scheduler.cuh index 622be39b..dacf5f1a 100644 --- a/deep_gemm/include/deep_gemm/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/scheduler.cuh @@ -391,21 +391,9 @@ struct SchedulerSelector { static constexpr auto select_type() { - if constexpr (GT == GemmType::Normal) - return NormalScheduler(); - if constexpr (GT == GemmType::GroupedContiguous) - return GroupedContiguousScheduler(); - if constexpr (GT == GemmType::GroupedMasked) - return GroupedMaskedScheduler(); if constexpr (GT == GemmType::GroupedWithOffset) return GroupedWithOffsetScheduler(); - if constexpr (GT == GemmType::StridedBatched) - return StridedBatchedScheduler(); } using type = decltype(select_type()); diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 2a2cc316..459d1c7e 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -104,8 +104,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ()) else: block_ms = (get_m_alignment_for_contiguous_layout(), ) - block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) - + #block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) + block_ns = tuple(range(16, 129, 8)) # Avoid bank conflicts for FP32 output if is_fp32_out: block_ns = [x for x in block_ns if x % 16 == 8] diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 92384688..2a607677 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -260,7 +260,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] else: m_per_expert_threshold = 32 # H100 - if expected_m>= m_per_expert_threshold: + if expected_m> m_per_expert_threshold: num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( expected_m, n, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=False) diff --git a/tests/test_core.py b/tests/test_core.py index e152a9c6..eb3e51e4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -119,10 +119,9 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, def test_m_grouped_gemm_offset() -> None: print('Testing grouped contiguous GEMM:') - for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),): + for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),): # NOTES: we should mask the unfilled part before calculating difference - ''' x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group) @@ -160,6 +159,8 @@ def test_func(): print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + + ''' print() From ccd63bb234b63658c79c4c04d1c87bbd5213727f Mon Sep 17 00:00:00 2001 From: wangzhe_ant Date: Tue, 24 Jun 2025 17:52:28 +0800 Subject: [PATCH 3/5] fix tma_d_offset_desc_swapAB, update unitest --- deep_gemm/include/deep_gemm/fp8_gemm.cuh | 121 +++----- deep_gemm/jit_kernels/gemm.py | 19 +- deep_gemm/jit_kernels/m_grouped_gemm.py | 35 +-- deep_gemm/jit_kernels/runtime.py | 4 +- tests/test_core.py | 346 +++++++++++++++++++---- 5 files changed, 345 insertions(+), 180 deletions(-) diff --git a/deep_gemm/include/deep_gemm/fp8_gemm.cuh b/deep_gemm/include/deep_gemm/fp8_gemm.cuh index fbe05c7c..d53eaa0a 100644 --- a/deep_gemm/include/deep_gemm/fp8_gemm.cuh +++ b/deep_gemm/include/deep_gemm/fp8_gemm.cuh @@ -438,6 +438,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout, DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); #endif } + template static __device__ __forceinline__ void write_result_to_gmem(__nv_bfloat16* gmem_d_this_block, __nv_bfloat16 const* smem_d, uint32_t const m_offset, uint32_t const m_boundary, uint32_t const n_offset, @@ -638,18 +639,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm(&full_barrier), smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx), kNumTMAMulticast); - if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) - { - tma_copy(&tensor_map_scales_a, - reinterpret_cast(&full_barrier), smem_scales_a[s], - scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); - } - else - { - tma_copy(&tensor_map_scales_a, - reinterpret_cast(&full_barrier), smem_scales_a[s], m_block_idx * BLOCK_M, - scheduler.get_global_scales_a_idx(k_idx / BLOCK_K), kNumTMAMulticast); - } + tma_copy(&tensor_map_scales_a, + reinterpret_cast(&full_barrier), smem_scales_a[s], + scheduler.get_global_scales_a_idx(m_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); // Issue TMA B without broadcasting tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[s], k_idx, @@ -826,45 +818,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm scheduler.m_boundary; - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - if (!cross_boundary) - { - // Use TMA store to write back to global memory - if (threadIdx.x == 0) - { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - } - else - { - __nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N; - constexpr int NUM_WARPS - = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; - write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, - scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N); - } - } - else + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) { - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Use TMA store to write back to global memory if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_d, smem_d, n_block_idx * BLOCK_N, scheduler.get_global_m_idx(m_block_idx)); + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_global_idx); cute::tma_store_arrive(); cute::tma_store_wait<0>(); } } - + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, m_global_idx, + scheduler.m_boundary, n_block_idx * BLOCK_N, SHAPE_N, SHAPE_N); + } __syncwarp(); } } @@ -1050,18 +1025,9 @@ __global__ void __launch_bounds__(get_num_threads_per_sm(&full_barrier), smem_scales_b[s], - scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); - } - else - { - tma_copy(&tensor_map_scales_b, - reinterpret_cast(&full_barrier), smem_scales_b[s], n_block_idx * BLOCK_N, - scheduler.get_global_scales_b_idx(k_idx / BLOCK_K), kNumTMAMulticast); - } + tma_copy(&tensor_map_scales_b, + reinterpret_cast(&full_barrier), smem_scales_b[s], + scheduler.get_global_scales_b_idx(n_block_idx), k_idx / BLOCK_K, kNumTMAMulticast); full_barrier.arrive_and_expect_tx( SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE); @@ -1246,45 +1212,28 @@ __global__ void __launch_bounds__(get_num_threads_per_sm scheduler.n_boundary; - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); - if (!cross_boundary) - { - // Use TMA store to write back to global memory - if (threadIdx.x == 0) - { - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx); - cute::tma_store_arrive(); - cute::tma_store_wait<0>(); - } - } - else - { - __nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M; - constexpr int NUM_WARPS - = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; - write_result_to_gmem(gmem_d_this_block, smem_d, n_global_idx, - scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M); - } - } - else + auto n_global_idx = scheduler.get_global_n_idx(n_block_idx); + bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) { - cute::tma_store_fence(); - cutlass::arch::NamedBarrier(kNumMathThreads).sync(); // Use TMA store to write back to global memory if (threadIdx.x == 0) { - cute::SM90_TMA_STORE_2D::copy( - &tensor_map_d, smem_d, m_block_idx * BLOCK_M, scheduler.get_global_n_idx(n_block_idx)); + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, n_global_idx); cute::tma_store_arrive(); cute::tma_store_wait<0>(); } } - + else + { + __nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M; + constexpr int NUM_WARPS + = (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem(gmem_d_this_block, smem_d, n_global_idx, + scheduler.n_boundary, m_block_idx * BLOCK_M, SHAPE_M, SHAPE_M); + } __syncwarp(); } } diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 459d1c7e..64bcc76a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -67,13 +67,8 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k return smem_size, swizzle_mode, block_n_padding else: - # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n) - block_n_padding = get_block_n_padding_for_smem_d( - block_n) if swizzle_mode == 0 else 0 - # NOTES: `scales_b` in a total manner or per-stage manner - smem_d = block_m * (block_n + block_n_padding) * (4 if is_fp32_out else 2) + smem_d = block_m * block_n * (4 if is_fp32_out else 2) smem_a_per_stage = block_m * block_k smem_scales_a_per_stage = ceil_div(k, block_k) * 4; # weight scales smem_b_per_stage = block_n * block_k @@ -87,11 +82,12 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_size += num_stages * smem_scales_b smem_size += num_stages * smem_b_per_stage smem_size += num_stages * smem_scales_b_per_stage - smem_size += ceil_div(smem_scales_a_per_stage * (1 if block_k % block_n == 0 else 2), 8) * 8 + smem_size += ceil_div(smem_scales_a_per_stage, 8) * 8 smem_size += smem_barrier - - # Swizzle and padding are not compatible - assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1 + + # no swizzle, no block_n_padding + swizzle_mode = 0 + block_n_padding = 0 return smem_size, swizzle_mode, block_n_padding @@ -105,7 +101,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int, else: block_ms = (get_m_alignment_for_contiguous_layout(), ) #block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, )) - block_ns = tuple(range(16, 129, 8)) + block_ns = tuple(range(16, 129, 8)) + # Avoid bank conflicts for FP32 output if is_fp32_out: block_ns = [x for x in block_ns if x % 16 == 8] diff --git a/deep_gemm/jit_kernels/m_grouped_gemm.py b/deep_gemm/jit_kernels/m_grouped_gemm.py index 2a607677..94522db3 100644 --- a/deep_gemm/jit_kernels/m_grouped_gemm.py +++ b/deep_gemm/jit_kernels/m_grouped_gemm.py @@ -221,20 +221,12 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] num_groups, n, k_ = rhs.shape m_, n_ = out.shape - - print("expected_m: ",expected_m) - print("A shape: ",lhs.shape) - print("A scale shape: ",lhs_scales.shape) - print("B shape: ",rhs.shape) - print("B scale shape: ",rhs_scales.shape) - print("out shape: ",out.shape) - # Type and shape checks assert m == m_ and n == n_ and k == k_ max_shape_m_4_align = ceil_div(m, 4) * 4 # align 4 - max_shape_m_32_align_padded = compute_padded_offset(m, num_groups) + max_shape_m_32_align_padded = compute_padded_offset(max_shape_m_4_align, num_groups) assert expected_m > 0 and max_shape_m_4_align > 0 and n > 0 and k > 0 and num_groups > 0 @@ -244,12 +236,14 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128)) assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32 assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32 + assert offsets.dtype == torch.int64 assert out.dtype == torch.bfloat16 assert lhs.is_contiguous() and rhs.is_contiguous() assert out.is_contiguous() # LHS scales must be transposed for TMA load, but not for RHS scales - lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) + + #lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales) assert rhs_scales.is_contiguous() # Auto-tuning with compilation @@ -273,9 +267,9 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] num_tma_threads = 128 num_math_threads_per_group = 128 - tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, m, k, k, block_m, block_k, num_groups) + tensor_map_a = make_2d_tma_a_desc(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_m, block_k, num_groups) tensor_map_b = make_2d_tma_b_desc(GemmType.GroupedWithOffset, rhs, n, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, m, n, n, block_m, block_n, num_groups, 0) # none swizzle + tensor_map_d = make_2d_tma_d_desc(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # none swizzle tensor_map_scales_a = make_2d_tma_scales_a_offset_desc(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_m, block_k) # none swizzle @@ -287,7 +281,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] 'PROBLEM_OFFSETS': offsets, 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, + 'M': max_shape_m_4_align, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, @@ -310,26 +304,17 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] else: num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs( n, expected_m, k, num_groups, num_sms, is_grouped_contiguous = True, is_swap_ab=True) - # Extra checks for TMA store if num_groups > 1 and n > block_m: assert n % block_m == 0, f'For GroupedWithOffset grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})' - print("is_swap_ab=True =========") - print("num_sms: ",num_sms) - print("block_m: ",block_m) - print("block_n: ",block_n) - print("num_stages: ",num_stages) - print("tma_multicast_config: ",tma_multicast_config) - print("smem_config: ",smem_config) - block_k = 128 num_tma_threads = 128 num_math_threads_per_group = 128 tensor_map_a = make_2d_tma_a_offset_desc_swapAB(GemmType.GroupedWithOffset, rhs, n, k, k, block_m, block_k, num_groups) - tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, m, k, k, block_n, block_k, num_groups) - tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, n, m, m, block_m, block_n, num_groups, 0) # no swizzle + tensor_map_b = make_2d_tma_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs, max_shape_m_4_align, k, k, block_n, block_k, num_groups) + tensor_map_d = make_2d_tma_d_offset_desc_swapAB(GemmType.GroupedWithOffset, out, max_shape_m_4_align, n, n, block_m, block_n, num_groups, 0) # no swizzle tensor_map_scales_b = make_2d_tma_scales_b_offset_desc_swapAB(GemmType.GroupedWithOffset, lhs_scales, max_shape_m_32_align_padded, k, block_n, block_k) # no swizzle kwargs = { @@ -340,7 +325,7 @@ def m_grouped_gemm_fp8_fp8_bf16_nt_offset(lhs: Tuple[torch.Tensor, torch.Tensor] 'PROBLEM_OFFSETS': offsets, 'NUM_TMA_THREADS': num_tma_threads, 'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group, - 'M': m, 'N': n, 'K': k, + 'M': max_shape_m_4_align, 'N': n, 'K': k, 'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'NUM_GROUPS': num_groups, 'NUM_STAGES': num_stages, diff --git a/deep_gemm/jit_kernels/runtime.py b/deep_gemm/jit_kernels/runtime.py index 584b3220..bc37de2e 100644 --- a/deep_gemm/jit_kernels/runtime.py +++ b/deep_gemm/jit_kernels/runtime.py @@ -173,8 +173,8 @@ def make_2d_tma_d_offset_desc_swapAB(gemm_type: GemmType, t: torch.Tensor, # Swizzling requires the inner box dim to be less or equal than `kSwizzleDMode` # bytes, so `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required return make_2d_tma_desc(t, - shape_n, shape_m * (num_groups if gemm_type != GemmType.Normal else 1), m_stride, - min(block_n, shape_n), min(block_m, shape_m), + shape_n, shape_m * (num_groups if gemm_type == GemmType.GroupedMasked else 1), m_stride, + min(block_m, shape_n), min(block_n, shape_m), cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE) diff --git a/tests/test_core.py b/tests/test_core.py index eb3e51e4..336d5b40 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,6 +6,7 @@ import random import torch from typing import List, Tuple +import itertools import deep_gemm from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor @@ -34,6 +35,49 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) +def construct(m: int, k: int, n: int) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = x @ y.t() + + x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ + Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + alignment = get_m_alignment_for_contiguous_layout() + group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) + + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, group_m in enumerate(group_ms): + actual_end = start + group_m + aligned_end = start + ceil_div(group_m, alignment) * alignment + m_indices[start:actual_end] = i + m_indices[actual_end:aligned_end] = -1 + ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() + start = aligned_end + ref_out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_out), ref_out) + + assert m % 4 == 0, f'TMA alignment error: {m}' + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)) + for i in range(num_groups): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + return m, x_fp8, y_fp8, m_indices, out, ref_out + def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: @@ -55,46 +99,129 @@ def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: # Construct mask masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): - masked_m[j] = int(expected_m_per_group * random.uniform(1, 1)) + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) assert masked_m.amax().item() <= max_m return x_fp8, y_fp8, masked_m, out, ref_out -def test_m_grouped_gemm_masked() -> None: - print('Testing grouped masked GEMM:') - for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)): - for k, n in ((7168, 4096), (2048, 7168), ): - # Test correctness - for i in range(10): - x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) - for j in range(num_groups): - diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) - assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' +def construct_wgrad(m: int, k: int, n: int) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10 + out = residual.clone() + ref_out = residual + (x.float() @ y.float().t()) - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) + x_fp8 = per_token_cast_to_fp8(x) + y_fp8 = per_token_cast_to_fp8(y) - # Test performance with fixed shapes - # noinspection PyUnboundLocalVariable - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') - print() + # NOTES: please do inplace add on the `out` later + return x_fp8, y_fp8, residual, out, ref_out + + +def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \ + Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]: + num_groups, total_k = len(k_sizes), sum(k_sizes) + + x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16) + y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16) + out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) + ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float) + + # Fill tensors with data and compute reference output + x_offset, y_offset = 0, 0 + for idx, k in enumerate(k_sizes): + x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + + x_flat[x_offset:x_offset + m * k].copy_(x_chunk.flatten()) + y_flat[y_offset:y_offset + n * k].copy_(y_chunk.flatten()) + ref_out[idx] = x_chunk.float() @ y_chunk.float().t() + + x_offset += m * k + y_offset += n * k + + x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn) + y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn) + + total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes) + x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float) + y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float) + + # Cast to FP8 and prepare scale factors + x_offset, y_offset, scale_offset = 0, 0, 0 + for k in k_sizes: + x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_offset:x_offset + m * k].view(m, k)) + y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_offset:y_offset + n * k].view(n, k)) + + x_fp8_flat[x_offset:x_offset + m * k].copy_(x_fp8_chunk.flatten()) + y_fp8_flat[y_offset:y_offset + n * k].copy_(y_fp8_chunk.flatten()) + + num_scales = ceil_div(k, 128) + x_scales[scale_offset:scale_offset + num_scales].copy_(x_scale_chunk.T) + y_scales[scale_offset:scale_offset + num_scales].copy_(y_scale_chunk.T) + + x_offset += m * k + y_offset += n * k + scale_offset += num_scales + + return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes + + +def change_to_offset_layout( + ms: List[int], + x_fp8: torch.Tensor, + x_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_list = [] + x_scale_list = [] + shape_m_total = 0 + num_problems = len(ms) + m_acc = [0] + list(itertools.accumulate(ms)) + + # Need to keep the same as the one in cpp/include/tensorrt_llm/deep_gemm/scheduler.cuh + def compute_padded_offset(offset, idx_problem, alignment=32): + return (offset + idx_problem * (alignment - 1)) // alignment * alignment + + offset = 0 + for i in range(num_problems): + ms[i] + x_list.append(x_fp8[m_acc[i]:m_acc[i + 1]]) + offset_next = compute_padded_offset(m_acc[i + 1], i + 1) + size_padded = (offset_next - offset) - (m_acc[i + 1] - m_acc[i]) + x_scale_padded = torch.cat([ + x_scale[m_acc[i]:m_acc[i + 1]], + torch.zeros( + [size_padded, *x_scale.shape[1:]], + dtype=x_scale.dtype, + device=x_scale.device, + ), + ]) + x_scale_list.append(x_scale_padded) + offset = offset_next + + shape_m_total = m_acc[-1] + ret_x = torch.cat(x_list) + ret_x_scale = torch.cat(x_scale_list) + ret_x_scale = ret_x_scale.t().contiguous() + pad_target = compute_padded_offset(shape_m_total, num_problems) + pad_target -= ret_x_scale.shape[1] + ret_x_scale = torch.nn.functional.pad(ret_x_scale, (0, pad_target), + mode='constant', + value=0) + return ret_x, ret_x_scale def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \ Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: - alignment = 32 - group_ms = [int(expected_m_per_group * random.uniform(1, 1)) for _ in range(num_groups)] + alignment = 4 + group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + m = sum([ceil_div(x, alignment) * alignment for x in group_ms]) x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - offsets = torch.empty(num_groups+1, device='cuda', dtype=torch.int32) + y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + offsets = torch.empty(num_groups+1, device='cuda', dtype=torch.int64) out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) @@ -105,6 +232,7 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, offsets[i+1] = aligned_end ref_out[start:aligned_end] = x[start:aligned_end] @ y[i].t() start = aligned_end + group_ms[i] = ceil_div(group_m, alignment) * alignment assert m % 4 == 0, f'TMA alignment error: {m}' x_fp8 = per_token_cast_to_fp8(x) @@ -112,62 +240,162 @@ def construct_offset_grouped(num_groups: int, expected_m_per_group: int, k: int, for i in range(num_groups): y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - return m, x_fp8, y_fp8, offsets, out, ref_out + return group_ms, m, x_fp8, y_fp8, offsets.type(torch.int64), out, ref_out +def test_gemm() -> None: + print('Testing GEMM:') + for m in (64, 128, 4096): + for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]: + x_fp8, y_fp8, out, ref_out = construct(m, k, n) + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' -def test_m_grouped_gemm_offset() -> None: + # noinspection PyShadowingNames + def test_func(): + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_contiguous() -> None: print('Testing grouped contiguous GEMM:') - for num_groups, expected_m_per_group, k, n in ((9, 32, 7168, 4096),): + for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), (8, 4096, 2048, 7168), + (32, 256, 7168, 4096), (32, 256, 2048, 7168)): # NOTES: we should mask the unfilled part before calculating difference - - x_fp8_mask, y_fp8_mask, masked_m_mask, out_mask, ref_out_mask = construct_masked_grouped(num_groups, expected_m_per_group, expected_m_per_group, k, n) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group) - - - for j in range(num_groups): - diff = calc_diff(out_mask[j, :masked_m_mask[j].item()], ref_out_mask[j, :masked_m_mask[j].item()]) - #assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m_mask[j]}, {num_groups=}, {diff:.5f}' + m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) + out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8_mask, y_fp8_mask, out_mask, masked_m_mask, expected_m_per_group) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices) - # Test performance with fixed shapes - # noinspection PyUnboundLocalVariable - valid_m = masked_m_mask.sum().item() t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_masked: Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + valid_m = (m_indices != -1).sum().item() + print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + print() - ''' - m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) - - #deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group) - #diff = calc_diff(out_offset, ref_out_offset) - # assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' +def test_m_grouped_gemm_masked() -> None: + print('Testing grouped masked GEMM:') + + for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)): + for k, n in ((7168, 4096), (2048, 7168), ): + # Test correctness + for i in range(10): + x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) + for j in range(num_groups): + diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()]) + assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group) + + # Test performance with fixed shapes + # noinspection PyUnboundLocalVariable + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_wgrad_gemm(): + print('Testing weight gradient GEMM:') + + for k in (4096, 8192): + for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)): + # Test correctness + x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) + diff = calc_diff(out, ref_out) + assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}' + + # Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2) + x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out) + + t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True) + print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_wgrad_gemm(): + print('Testing grouped weight gradient GEMM:') + + for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)): + for m, n in ((7168, 4096), (2048, 7168)): + # Vary k sizes around base_k + k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)] + k_sizes.append(base_k * num_groups - sum(k_sizes)) + + # Test correctness + x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) + deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) + + for idx in range(num_groups): + diff = calc_diff(out[idx], ref_out[idx]) + assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}' + + # Construct new tensors to avoid L2 cache acceleration + x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes) + total_k = sum(k_sizes) + + def test_func(): + deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes) + + t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups + print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, ' + f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_offset() -> None: + print('Testing grouped contiguous GEMM:') + + for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096)): + # NOTES: we should mask the unfilled part before calculating difference + ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) + pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1]) + + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) + diff = calc_diff(out_offset, ref_out_offset) + assert diff < 0.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(x_fp8_offset, y_fp8_offset, offset, out_offset, expected_m_per_group) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) valid_m = m_offset + print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') - - ''' print() - - - - if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -177,4 +405,10 @@ def test_func(): print('Library path:') print(f' > {deep_gemm.__path__}\n') + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() test_m_grouped_gemm_offset() + + test_wgrad_gemm() + test_k_grouped_wgrad_gemm() From 7db1b0ef63975773d353885e58c493044ad66f46 Mon Sep 17 00:00:00 2001 From: wangzhe_ant Date: Tue, 24 Jun 2025 18:24:08 +0800 Subject: [PATCH 4/5] update unitest --- tests/test_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 336d5b40..3a1aca7f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -372,16 +372,16 @@ def test_func(): def test_m_grouped_gemm_offset() -> None: - print('Testing grouped contiguous GEMM:') + print('Testing grouped offset GEMM:') - for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096)): + for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096),(32, 32, 7168, 4096)): # NOTES: we should mask the unfilled part before calculating difference ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1]) deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) diff = calc_diff(out_offset, ref_out_offset) - assert diff < 0.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' + assert diff < 0.001.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' # noinspection PyShadowingNames def test_func(): @@ -390,7 +390,7 @@ def test_func(): t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) valid_m = m_offset - print(f' > m_grouped_gemm_fp8_fp8_bf16_nt_offset: Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') print() From e29e996a42626810c5af15ad6be36c0dd2d48cdb Mon Sep 17 00:00:00 2001 From: wangzhe_ant Date: Tue, 24 Jun 2025 18:53:19 +0800 Subject: [PATCH 5/5] update unitest --- tests/test_core.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 3a1aca7f..1a3565b1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -374,25 +374,26 @@ def test_func(): def test_m_grouped_gemm_offset() -> None: print('Testing grouped offset GEMM:') - for num_groups, expected_m_per_group, k, n in ((8, 32, 7168, 4096),(9, 64, 7168, 4096),(32, 32, 7168, 4096)): - # NOTES: we should mask the unfilled part before calculating difference - ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) - pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1]) - - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) - diff = calc_diff(out_offset, ref_out_offset) - assert diff < 0.001.1, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' + for num_groups, expected_m_per_group in ((2, 16), (4, 16), (2, 32), (9, 32), (2, 32), (4, 32), (32, 64)): + for k, n in ((7168, 4096),): + # NOTES: we should mask the unfilled part before calculating difference + ms, m_offset, x_fp8_offset, y_fp8_offset, offset, out_offset, ref_out_offset = construct_offset_grouped(num_groups, expected_m_per_group, k, n) + pad_x_fp8 = change_to_offset_layout(ms, x_fp8_offset[0], x_fp8_offset[1]) - # noinspection PyShadowingNames - def test_func(): deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) + diff = calc_diff(out_offset, ref_out_offset) + assert diff < 0.001, f'{m_offset=}, {k=}, {n=}, {diff:.5f}' - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - valid_m = m_offset + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_offset(pad_x_fp8, y_fp8_offset, offset, out_offset, expected_m_per_group) - print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' - f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' - f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + valid_m = m_offset + + print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | ' + f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, ' + f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s') print()