From c131985d0c2515e20de3fa0ad743be4b16923124 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Tue, 15 Jul 2025 04:27:20 -0400 Subject: [PATCH 1/7] [AMD][ROCm] Improve support of AMD The patch delivers several fixes for building issues for CUDA part of DeepSpeed library. Percentage of passed unit tests improved(tested on RDNA hardware, gfx110x and gfx12x) Before: collected 5298 items / 15 skipped 2773 failed, 862 passed, 1665 skipped, 13 errors After: collected 5851 items / 11 skipped 4187 failed, 1373 passed, 292 skipped, 10 errors Signed-off-by: Artem Kuzmitckii --- .../evoformer_attn/gemm_kernel_utils.h | 2 +- csrc/fp_quantizer/fp_quantize.cpp | 1 + csrc/fp_quantizer/fp_quantize.cu | 3 +- .../includes/{context.h => fp_context.h} | 0 csrc/includes/reduction_utils.h | 36 +++++++++++++------ .../include/utils_paralleldequant.cuh | 2 +- .../cuda_linear/linear_kernels_cuda.cu | 3 +- .../cutlass_ops/mixed_gemm/mixed_gemm.cu | 1 + .../kernels/cutlass_ops/moe_gemm/moe_gemm.cu | 1 + 9 files changed, 35 insertions(+), 14 deletions(-) rename csrc/fp_quantizer/includes/{context.h => fp_context.h} (100%) diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h index c102234a4dfb..24e5d1e997cc 100644 --- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -233,7 +233,7 @@ struct call_conditional { CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { - return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); + return (int32_t)__shfl_sync(static_cast(0xffffffff), (unsigned)value, 0); } template diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp index 1a887b50e1a3..e273f442d34a 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -6,6 +6,7 @@ #include "fp_quantize.h" #include +#include #include #include diff --git a/csrc/fp_quantizer/fp_quantize.cu b/csrc/fp_quantizer/fp_quantize.cu index 66ea7392e011..1ec80f543c5f 100644 --- a/csrc/fp_quantizer/fp_quantize.cu +++ b/csrc/fp_quantizer/fp_quantize.cu @@ -4,7 +4,7 @@ // DeepSpeed Team #include -#include "context.h" +#include "fp_context.h" #include "fp_quantize.h" #include "memory_access_utils.h" #include "reduction_utils.h" @@ -14,6 +14,7 @@ #include #include +#include #ifdef BF16_AVAILABLE #include diff --git a/csrc/fp_quantizer/includes/context.h b/csrc/fp_quantizer/includes/fp_context.h similarity index 100% rename from csrc/fp_quantizer/includes/context.h rename to csrc/fp_quantizer/includes/fp_context.h diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index eb9afb66a894..b736e511cb85 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -526,12 +526,28 @@ here (fold is C++17 only and I don't think helps and recursion feels like huge overkill that harms readability) that would be wonderful. */ +template +DS_D_INLINE T shfl_xor_helper(cg::thread_block_tile& warp, const T& value, int i) +{ + return warp.shfl_xor(value, i); +} + +#if defined(__HIP_PLATFORM_AMD__) +template <> +DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile& warp, + const __half& value, + int i) +{ + return __half(warp.shfl_xor(float(value), i)); +} +#endif + template DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); } } @@ -540,8 +556,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], warp.shfl_xor(data[0], i)); - data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); } } @@ -550,9 +566,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], warp.shfl_xor(data[0], i)); - data[1] = element(data[1], warp.shfl_xor(data[1], i)); - data[2] = element(data[2], warp.shfl_xor(data[2], i)); + data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); + data[2] = element(data[2], shfl_xor_helper(warp, data[0], i)); } } @@ -566,10 +582,10 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], warp.shfl_xor(data[0], i)); - data[1] = element(data[1], warp.shfl_xor(data[1], i)); - data[2] = element(data[2], warp.shfl_xor(data[2], i)); - data[3] = element(data[3], warp.shfl_xor(data[3], i)); + data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); + data[2] = element(data[2], shfl_xor_helper(warp, data[0], i)); + data[3] = element(data[3], shfl_xor_helper(warp, data[0], i)); } } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index 11603fcc576c..6178cc116b7f 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -120,7 +120,7 @@ __device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, #pragma unroll for (int i = 0; i < 4; i++) { // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); - Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + Scales[i] = __shfl_sync(static_cast(0xffffffff), tmpReg, i, 4); } } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu index ea0203c42f84..2b7feb588373 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu @@ -45,7 +45,8 @@ static void Kernel_Ex(cudaStream_t stream, static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, + auto kernel = QUANT_GEMM_Kernel; + cudaFuncSetAttribute(reinterpret_cast(kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu index 7c522203bb48..25682ca91bd3 100644 --- a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu @@ -4,6 +4,7 @@ // DeepSpeed Team #include +#include #include "mixed_gemm.h" #include "mixed_gemm_api.h" #include "weight_variant.h" diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu index d1cafc9fff4c..ac875d7a73f1 100644 --- a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu @@ -4,6 +4,7 @@ // DeepSpeed Team #include +#include #include "moe_gemm.h" #include "moe_gemm_api.h" #include "weight_variant.h" From 4490ea5fb0da2854c14d2d96a1f3bb136cc05a09 Mon Sep 17 00:00:00 2001 From: "artem.kuzmitckii@amd.com" Date: Fri, 25 Jul 2025 15:12:45 +0000 Subject: [PATCH 2/7] [AMD][ROCm] Fixes review comments Signed-off-by: Artem Kuzmitckii --- csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h | 4 ++++ csrc/fp_quantizer/fp_quantize.cpp | 5 ++++- csrc/fp_quantizer/fp_quantize.cu | 2 ++ csrc/includes/reduction_utils.h | 3 ++- .../core_ops/cuda_linear/include/utils_paralleldequant.cuh | 4 ++++ .../v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu | 1 - .../inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu | 1 - 7 files changed, 16 insertions(+), 4 deletions(-) diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h index 24e5d1e997cc..b1b51d22a133 100644 --- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -233,7 +233,11 @@ struct call_conditional { CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { +#if defined(__HIP_PLATFORM_AMD__) return (int32_t)__shfl_sync(static_cast(0xffffffff), (unsigned)value, 0); +#else + return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); +#endif } template diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp index e273f442d34a..bf13ceb5743c 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -6,10 +6,13 @@ #include "fp_quantize.h" #include -#include #include #include +#if defined(__HIP_PLATFORM_AMD__) +#include +#endif + #define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ if (val.options().dtype() == torch::T_TYPE) { \ launch_quantization((C_TYPE*)val.data_ptr(), \ diff --git a/csrc/fp_quantizer/fp_quantize.cu b/csrc/fp_quantizer/fp_quantize.cu index 1ec80f543c5f..bc720a0e47f7 100644 --- a/csrc/fp_quantizer/fp_quantize.cu +++ b/csrc/fp_quantizer/fp_quantize.cu @@ -14,7 +14,9 @@ #include #include +#if defined(__HIP_PLATFORM_AMD__) #include +#endif #ifdef BF16_AVAILABLE #include diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index b736e511cb85..1aa8a408f88d 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -538,7 +538,8 @@ DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile& const __half& value, int i) { - return __half(warp.shfl_xor(float(value), i)); + float fvalue = __half2float(value); + return __half(warp.shfl_xor(fvalue, i)); } #endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index 6178cc116b7f..610480297e50 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -120,7 +120,11 @@ __device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, #pragma unroll for (int i = 0; i < 4; i++) { // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); +#if defined(__HIP_PLATFORM_AMD__) Scales[i] = __shfl_sync(static_cast(0xffffffff), tmpReg, i, 4); +#else + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); +#endif } } diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu index 25682ca91bd3..7c522203bb48 100644 --- a/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu +++ b/deepspeed/inference/v2/kernels/cutlass_ops/mixed_gemm/mixed_gemm.cu @@ -4,7 +4,6 @@ // DeepSpeed Team #include -#include #include "mixed_gemm.h" #include "mixed_gemm_api.h" #include "weight_variant.h" diff --git a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu index ac875d7a73f1..d1cafc9fff4c 100644 --- a/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu +++ b/deepspeed/inference/v2/kernels/cutlass_ops/moe_gemm/moe_gemm.cu @@ -4,7 +4,6 @@ // DeepSpeed Team #include -#include #include "moe_gemm.h" #include "moe_gemm_api.h" #include "weight_variant.h" From 77a7e069991e2f04455fa9db4a15ff67beb63c52 Mon Sep 17 00:00:00 2001 From: "artem.kuzmitckii@amd.com" Date: Sun, 3 Aug 2025 13:43:15 +0000 Subject: [PATCH 3/7] [AMD][ROCm] Fixes review comments part 2 Signed-off-by: Artem Kuzmitckii --- .../evoformer_attn/gemm_kernel_utils.h | 4 ---- csrc/includes/reduction_utils.h | 12 ++++++------ .../cuda_linear/include/utils_paralleldequant.cuh | 4 ---- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h index b1b51d22a133..c102234a4dfb 100644 --- a/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h +++ b/csrc/deepspeed4science/evoformer_attn/gemm_kernel_utils.h @@ -233,11 +233,7 @@ struct call_conditional { CUTLASS_DEVICE int32_t warp_uniform(int32_t value) { -#if defined(__HIP_PLATFORM_AMD__) - return (int32_t)__shfl_sync(static_cast(0xffffffff), (unsigned)value, 0); -#else return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0); -#endif } template diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 1aa8a408f88d..1e60ad4dec39 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -558,7 +558,7 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); } } @@ -568,8 +568,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); - data[2] = element(data[2], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); + data[2] = element(data[2], shfl_xor_helper(warp, data[2], i)); } } @@ -584,9 +584,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[0], i)); - data[2] = element(data[2], shfl_xor_helper(warp, data[0], i)); - data[3] = element(data[3], shfl_xor_helper(warp, data[0], i)); + data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); + data[2] = element(data[2], shfl_xor_helper(warp, data[2], i)); + data[3] = element(data[3], shfl_xor_helper(warp, data[3], i)); } } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index 610480297e50..11603fcc576c 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -120,11 +120,7 @@ __device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, #pragma unroll for (int i = 0; i < 4; i++) { // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); -#if defined(__HIP_PLATFORM_AMD__) - Scales[i] = __shfl_sync(static_cast(0xffffffff), tmpReg, i, 4); -#else Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); -#endif } } From 0946828b9a767ec9cc21a3381f8c8438bedd6e30 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Mon, 18 Aug 2025 13:19:36 -0400 Subject: [PATCH 4/7] [AMD][ROCm] Enable BF16 and fixes review's comment Signed-off-by: Artem Kuzmitckii --- .../{fp_quantize.cpp => fp_quantize_api.cu} | 3 + csrc/includes/conversion_utils.h | 69 ++++++++++++++++ csrc/includes/reduction_utils.h | 81 ++++++++++++------- .../csrc/{pt_binding.cpp => pt_binding.cu} | 0 op_builder/fp_quantizer.py | 4 +- op_builder/transformer_inference.py | 4 +- 6 files changed, 131 insertions(+), 30 deletions(-) rename csrc/fp_quantizer/{fp_quantize.cpp => fp_quantize_api.cu} (99%) rename csrc/transformer/inference/csrc/{pt_binding.cpp => pt_binding.cu} (100%) diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize_api.cu similarity index 99% rename from csrc/fp_quantizer/fp_quantize.cpp rename to csrc/fp_quantizer/fp_quantize_api.cu index bf13ceb5743c..9e373b1b7f5d 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize_api.cu @@ -11,6 +11,9 @@ #if defined(__HIP_PLATFORM_AMD__) #include +#if BF16_AVAILABLE +#include +#endif #endif #define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h index 3a90a3e91ddf..99b0363bc27b 100644 --- a/csrc/includes/conversion_utils.h +++ b/csrc/includes/conversion_utils.h @@ -59,6 +59,7 @@ DS_D_INLINE __half to(__half val) { return val; } + #ifdef BF16_AVAILABLE template <> DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) @@ -363,42 +364,74 @@ DS_D_INLINE __nv_bfloat16 to(float val) template <> DS_D_INLINE __nv_bfloat16 to(int64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ll2double_rn(val)); +#else return __ll2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __short2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(int8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__int2float_rn(val)); +#else return __int2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint64_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __double2bfloat16(__ull2double_rn(val)); +#else return __ull2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint32_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint16_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __ushort2bfloat16_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat16 to(uint8_t val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2bfloat16(__uint2float_rn(val)); +#else return __uint2bfloat16_rn(val); +#endif } #endif @@ -412,7 +445,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val) template <> DS_D_INLINE __nv_bfloat162 to(float val) { +#ifdef __HIP_PLATFORM_AMD__ + return __bfloat162bfloat162(__float2bfloat16(val)); +#else return __float2bfloat162_rn(val); +#endif } template <> DS_D_INLINE __nv_bfloat162 to(__half2 val) @@ -444,7 +481,11 @@ DS_D_INLINE int64_t to(__half val) template <> DS_D_INLINE int64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ll_rn(__bfloat162float(val)); +#else return __bfloat162ll_rn(val); +#endif } #endif @@ -471,7 +512,11 @@ DS_D_INLINE int32_t to(__half val) template <> DS_D_INLINE int32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -498,7 +543,11 @@ DS_D_INLINE int16_t to(__half val) template <> DS_D_INLINE int16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -525,7 +574,11 @@ DS_D_INLINE int8_t to(__half val) template <> DS_D_INLINE int8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2int_rn(__bfloat162float(val)); +#else return __bfloat162int_rn(val); +#endif } #endif @@ -552,7 +605,11 @@ DS_D_INLINE uint64_t to(__half val) template <> DS_D_INLINE uint64_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2ull_rn(__bfloat162float(val)); +#else return __bfloat162ull_rn(val); +#endif } #endif @@ -579,7 +636,11 @@ DS_D_INLINE uint32_t to(__half val) template <> DS_D_INLINE uint32_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -606,7 +667,11 @@ DS_D_INLINE uint16_t to(__half val) template <> DS_D_INLINE uint16_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif @@ -633,7 +698,11 @@ DS_D_INLINE uint8_t to(__half val) template <> DS_D_INLINE uint8_t to(__nv_bfloat16 val) { +#ifdef __HIP_PLATFORM_AMD__ + return __float2uint_rn(__bfloat162float(val)); +#else return __bfloat162uint_rn(val); +#endif } #endif diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 1e60ad4dec39..834f7d2c4b5b 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -9,6 +9,10 @@ #include "ds_kernel_utils.h" #include "memory_access_utils.h" +#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__) +#include +#endif + namespace cg = cooperative_groups; namespace reduce { @@ -374,7 +378,11 @@ DS_D_INLINE __half init() template <> DS_D_INLINE __nv_bfloat16 init() { +#ifdef __HIP_PLATFORM_AMD__ + constexpr __hip_bfloat16_raw neg_inf = {0xFF80}; +#else constexpr __nv_bfloat16_raw neg_inf = {0xFF80}; +#endif return __nv_bfloat16(neg_inf); } #endif @@ -526,29 +534,12 @@ here (fold is C++17 only and I don't think helps and recursion feels like huge overkill that harms readability) that would be wonderful. */ -template -DS_D_INLINE T shfl_xor_helper(cg::thread_block_tile& warp, const T& value, int i) -{ - return warp.shfl_xor(value, i); -} - -#if defined(__HIP_PLATFORM_AMD__) -template <> -DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile& warp, - const __half& value, - int i) -{ - float fvalue = __half2float(value); - return __half(warp.shfl_xor(fvalue, i)); -} -#endif - template DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); + data[0] = element(data[0], warp.shfl_xor(data[0], i)); } } @@ -557,8 +548,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); } } @@ -567,9 +558,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); - data[2] = element(data[2], shfl_xor_helper(warp, data[2], i)); + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); } } @@ -583,13 +574,39 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) { #pragma unroll for (int i = 1; i < reduce_width; i *= 2) { - data[0] = element(data[0], shfl_xor_helper(warp, data[0], i)); - data[1] = element(data[1], shfl_xor_helper(warp, data[1], i)); - data[2] = element(data[2], shfl_xor_helper(warp, data[2], i)); - data[3] = element(data[3], shfl_xor_helper(warp, data[3], i)); + data[0] = element(data[0], warp.shfl_xor(data[0], i)); + data[1] = element(data[1], warp.shfl_xor(data[1], i)); + data[2] = element(data[2], warp.shfl_xor(data[2], i)); + data[3] = element(data[3], warp.shfl_xor(data[3], i)); } } +#if defined(__HIP_PLATFORM_AMD__) +template +DS_D_INLINE void _warp_with_type_conversion( + cg::thread_block_tile& warp_arg, + T* data) +{ + constexpr int elems = sizeof...(Ops); + if constexpr ( + !(std::is_integral::value || std::is_floating_point::value) + ) { + float temp_data[elems]; +#pragma unroll + for (int i = 0; i < elems; i++) { + temp_data[i] = conversion::to(data[i]); + } + _warp(warp_arg, temp_data); +#pragma unroll + for (int i = 0; i < elems; i++) { + data[i] = conversion::to(temp_data[i]); + } + } else { + _warp(warp_arg, data); + } +} +#endif // defined(__HIP_PLATFORM_AMD__) + /* Implementation for primary block reduction that serves both `block` and `partitioned_block`. @@ -617,7 +634,11 @@ DS_D_INLINE void _block(cg::thread_block& tb, #endif // Always perform warp-scope reduction +#ifdef __HIP_PLATFORM_AMD__ + _warp_with_type_conversion(warp_arg, data); +#else _warp(warp_arg, data); +#endif // If max_warps == 1 let's skip the runtime check if (total_warps != 1) { @@ -641,8 +662,12 @@ DS_D_INLINE void _block(cg::thread_block& tb, } else { init(data); } - +#ifdef __HIP_PLATFORM_AMD__ + _warp_with_type_conversion(warp_arg, data); +#else _warp(warp_arg, data); +#endif + #pragma unroll for (int i = 0; i < elems; i++) { diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cu similarity index 100% rename from csrc/transformer/inference/csrc/pt_binding.cpp rename to csrc/transformer/inference/csrc/pt_binding.cu diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index 2b962ac2c1fe..b9202eeac177 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -18,6 +18,8 @@ class FPQuantizerBuilder(CUDAOpBuilder): def __init__(self, name=None): name = self.NAME if name is None else name super().__init__(name=name) + if self.is_rocm_pytorch(): + self.enable_bf16 = True def absolute_name(self): return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' @@ -90,7 +92,7 @@ def filter_ccs(self, ccs): def sources(self): return [ "csrc/fp_quantizer/fp_quantize.cu", - "csrc/fp_quantizer/fp_quantize.cpp", + "csrc/fp_quantizer/fp_quantize_api.cu", ] def extra_ldflags(self): diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 642aed56a192..1ea4da92b5f2 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -13,6 +13,8 @@ class InferenceBuilder(CUDAOpBuilder): def __init__(self, name=None): name = self.NAME if name is None else name super().__init__(name=name) + if self.is_rocm_pytorch(): + self.enable_bf16 = True def absolute_name(self): return f'deepspeed.ops.transformer.inference.{self.NAME}_op' @@ -55,7 +57,7 @@ def filter_ccs(self, ccs): def sources(self): return [ - 'csrc/transformer/inference/csrc/pt_binding.cpp', + 'csrc/transformer/inference/csrc/pt_binding.cu', 'csrc/transformer/inference/csrc/gelu.cu', 'csrc/transformer/inference/csrc/relu.cu', 'csrc/transformer/inference/csrc/layer_norm.cu', From a23815a6c0fbe02535747654fc78b8354c4e2076 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Thu, 21 Aug 2025 10:30:58 +0000 Subject: [PATCH 5/7] [AMD][ROCm] Fix format Signed-off-by: Artem Kuzmitckii --- csrc/includes/reduction_utils.h | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h index 834f7d2c4b5b..68ec106975b6 100644 --- a/csrc/includes/reduction_utils.h +++ b/csrc/includes/reduction_utils.h @@ -583,29 +583,21 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data) #if defined(__HIP_PLATFORM_AMD__) template -DS_D_INLINE void _warp_with_type_conversion( - cg::thread_block_tile& warp_arg, - T* data) +DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile& warp_arg, T* data) { constexpr int elems = sizeof...(Ops); - if constexpr ( - !(std::is_integral::value || std::is_floating_point::value) - ) { + if constexpr (!(std::is_integral::value || std::is_floating_point::value)) { float temp_data[elems]; #pragma unroll - for (int i = 0; i < elems; i++) { - temp_data[i] = conversion::to(data[i]); - } + for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to(data[i]); } _warp(warp_arg, temp_data); #pragma unroll - for (int i = 0; i < elems; i++) { - data[i] = conversion::to(temp_data[i]); - } + for (int i = 0; i < elems; i++) { data[i] = conversion::to(temp_data[i]); } } else { _warp(warp_arg, data); } } -#endif // defined(__HIP_PLATFORM_AMD__) +#endif // defined(__HIP_PLATFORM_AMD__) /* Implementation for primary block reduction that serves both `block` and @@ -668,7 +660,6 @@ DS_D_INLINE void _block(cg::thread_block& tb, _warp(warp_arg, data); #endif - #pragma unroll for (int i = 0; i < elems; i++) { mem_access::store_shared(reduce_buffer + elems * warp_arg.thread_rank() + i, From 4904d948f15c35d02602a7beef5cd1392b7ce1e0 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Mon, 13 Oct 2025 14:24:02 +0000 Subject: [PATCH 6/7] Fix BF16 support for AMD Signed-off-by: Artem Kuzmitckii --- .../{fp_quantize_api.cu => fp_quantize.cpp} | 7 ------- csrc/fp_quantizer/includes/fp_quantize.h | 12 ++++++++++-- csrc/includes/ds_kernel_utils.h | 10 ++++++++-- .../csrc/{pt_binding.cu => pt_binding.cpp} | 0 .../core_ops/cuda_linear/linear_kernels_cuda.cu | 4 ++-- op_builder/builder.py | 1 + op_builder/fp_quantizer.py | 4 +--- op_builder/transformer_inference.py | 15 ++++++++++++--- tests/unit/ops/fp_quantizer/test_fp_quant.py | 4 ++-- 9 files changed, 36 insertions(+), 21 deletions(-) rename csrc/fp_quantizer/{fp_quantize_api.cu => fp_quantize.cpp} (97%) rename csrc/transformer/inference/csrc/{pt_binding.cu => pt_binding.cpp} (100%) diff --git a/csrc/fp_quantizer/fp_quantize_api.cu b/csrc/fp_quantizer/fp_quantize.cpp similarity index 97% rename from csrc/fp_quantizer/fp_quantize_api.cu rename to csrc/fp_quantizer/fp_quantize.cpp index 9e373b1b7f5d..1a887b50e1a3 100644 --- a/csrc/fp_quantizer/fp_quantize_api.cu +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -9,13 +9,6 @@ #include #include -#if defined(__HIP_PLATFORM_AMD__) -#include -#if BF16_AVAILABLE -#include -#endif -#endif - #define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ if (val.options().dtype() == torch::T_TYPE) { \ launch_quantization((C_TYPE*)val.data_ptr(), \ diff --git a/csrc/fp_quantizer/includes/fp_quantize.h b/csrc/fp_quantizer/includes/fp_quantize.h index 60c75541f603..a15b8ddf5a22 100644 --- a/csrc/fp_quantizer/includes/fp_quantize.h +++ b/csrc/fp_quantizer/includes/fp_quantize.h @@ -9,10 +9,18 @@ #include #include - -#ifdef BF16_AVAILABLE +// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to +// after hipifying), because this header is pulled into .cpp translation units +// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration +// instead, see code block below +#if defined(BF16_AVAILABLE) +#if !defined(__HIP_PLATFORM_AMD__) #include +#else +struct __hip_bfloat16; +#endif #endif + #include #include diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h index f8b16ee6a315..cb8b0b28484e 100644 --- a/csrc/includes/ds_kernel_utils.h +++ b/csrc/includes/ds_kernel_utils.h @@ -13,7 +13,11 @@ used throughout the codebase. #include #include -#ifdef BF16_AVAILABLE +// Note: BF16 support on AMD but we have to exclude here cuda_bf16.h (which turn to +// after hipifying), because this header is pulled into .cpp translation units +// that are compiled by a host-only compiler, which triggers build errors. Added forward declaration +// instead, see code block below +#if defined(BF16_AVAILABLE) && !defined(__HIP_PLATFORM_AMD__) #include #endif @@ -21,7 +25,9 @@ used throughout the codebase. #define DS_D_INLINE __device__ __forceinline__ #ifdef __HIP_PLATFORM_AMD__ - +#if BF16_AVAILABLE +struct __hip_bfloat16; +#endif // constexpr variant of warpSize for templating constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE; #define HALF_PRECISION_AVAILABLE = 1 diff --git a/csrc/transformer/inference/csrc/pt_binding.cu b/csrc/transformer/inference/csrc/pt_binding.cpp similarity index 100% rename from csrc/transformer/inference/csrc/pt_binding.cu rename to csrc/transformer/inference/csrc/pt_binding.cpp diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu index 2b7feb588373..74112236b6df 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu @@ -45,10 +45,10 @@ static void Kernel_Ex(cudaStream_t stream, static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); - auto kernel = QUANT_GEMM_Kernel; - cudaFuncSetAttribute(reinterpret_cast(kernel), + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); diff --git a/op_builder/builder.py b/op_builder/builder.py index 13853df92a8b..926dd2fd9cc9 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -778,6 +778,7 @@ def nvcc_args(self): '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR ] + self.enable_bf16 = True else: try: nvcc_threads = int(os.getenv("DS_NVCC_THREADS", "")) diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index 5afc49bfcd6e..5ccc35ac2b1f 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -18,8 +18,6 @@ class FPQuantizerBuilder(CUDAOpBuilder): def __init__(self, name=None): name = self.NAME if name is None else name super().__init__(name=name) - if self.is_rocm_pytorch(): - self.enable_bf16 = True def absolute_name(self): return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' @@ -92,7 +90,7 @@ def filter_ccs(self, ccs): def sources(self): return [ "csrc/fp_quantizer/fp_quantize.cu", - "csrc/fp_quantizer/fp_quantize_api.cu", + "csrc/fp_quantizer/fp_quantize.cpp", ] def extra_ldflags(self): diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 1ea4da92b5f2..3afa74dc31c2 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -13,8 +13,6 @@ class InferenceBuilder(CUDAOpBuilder): def __init__(self, name=None): name = self.NAME if name is None else name super().__init__(name=name) - if self.is_rocm_pytorch(): - self.enable_bf16 = True def absolute_name(self): return f'deepspeed.ops.transformer.inference.{self.NAME}_op' @@ -57,7 +55,7 @@ def filter_ccs(self, ccs): def sources(self): return [ - 'csrc/transformer/inference/csrc/pt_binding.cu', + 'csrc/transformer/inference/csrc/pt_binding.cpp', 'csrc/transformer/inference/csrc/gelu.cu', 'csrc/transformer/inference/csrc/relu.cu', 'csrc/transformer/inference/csrc/layer_norm.cu', @@ -77,3 +75,14 @@ def extra_ldflags(self): def include_paths(self): return ['csrc/transformer/inference/includes', 'csrc/includes'] + + def nvcc_args(self): + args = super().nvcc_args() + """BF16 is supported on AMD, but including `cuda_bf16.h` (`` after hipification) + in host-only translation units (*.cpp files) fails because GPU-specific builtins are pulled in with the BF16 type. + This cannot be avoided via forward declarations for this transformer_inference extension, + since `pt_binding.cpp` code explicitly requires the BF16 header, so disable it for now. + """ + if self.is_rocm_pytorch(): + self.enable_bf16 = False + return args diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index e9baf016310e..0655b0ce26a3 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -57,7 +57,7 @@ def test_fp_quant_meta(dtype): qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() - ds_error = (x_dequantized - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" @@ -129,6 +129,6 @@ def test_fp_quant(dtype, q_bits): qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=quant_config.group_size) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() - ds_error = (x_dequantized - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" From 4a1d7b76e3ff6ffd1d091cc43f47cb8f98600fe9 Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Mon, 13 Oct 2025 14:57:19 +0000 Subject: [PATCH 7/7] Remove unnecessary changes Signed-off-by: Artem Kuzmitckii --- csrc/fp_quantizer/fp_quantize.cu | 3 --- csrc/includes/conversion_utils.h | 1 - .../v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu | 1 - 3 files changed, 5 deletions(-) diff --git a/csrc/fp_quantizer/fp_quantize.cu b/csrc/fp_quantizer/fp_quantize.cu index bc720a0e47f7..42a1b63e424b 100644 --- a/csrc/fp_quantizer/fp_quantize.cu +++ b/csrc/fp_quantizer/fp_quantize.cu @@ -14,9 +14,6 @@ #include #include -#if defined(__HIP_PLATFORM_AMD__) -#include -#endif #ifdef BF16_AVAILABLE #include diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h index 99b0363bc27b..d6d8f11e0854 100644 --- a/csrc/includes/conversion_utils.h +++ b/csrc/includes/conversion_utils.h @@ -59,7 +59,6 @@ DS_D_INLINE __half to(__half val) { return val; } - #ifdef BF16_AVAILABLE template <> DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu index 74112236b6df..ea0203c42f84 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu @@ -48,7 +48,6 @@ static void Kernel_Ex(cudaStream_t stream, cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); - size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1);