|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2011-2024 NVIDIA CORPORATION & AFFILIATES. All rights |
| 3 | + * reserved. SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement |
| 4 | + * |
| 5 | + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual |
| 6 | + * property and proprietary rights in and to this material, related |
| 7 | + * documentation and any modifications thereto. Any use, reproduction, |
| 8 | + * disclosure or distribution of this material and related documentation |
| 9 | + * without an express license agreement from NVIDIA CORPORATION or |
| 10 | + * its affiliates is strictly prohibited. |
| 11 | + */ |
| 12 | + |
| 13 | +#include <fmha/numeric_types.h> |
| 14 | +#include <fmha/utils.h> |
| 15 | +#include <stdint.h> |
| 16 | + |
| 17 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 18 | + |
// Requantization kernel: reads int32 accumulators, scales them in fp32, and
// emits round-to-nearest-even, saturated int8 results, 4 elements per thread
// per step via vectorized accesses.
//
// NOTE(review): only the first n - n % 4 elements are converted — a non-
// multiple-of-4 tail is silently skipped. src is read as int4 (16B) and dst
// written as uint32_t (4B), so both must be suitably aligned — confirm at
// call sites.
__global__ void convert_int32_to_int8_kernel(void* dst, void const* src, size_t n, float scale) {
  // Grid-stride loop over groups of 4 elements.
  size_t const stride = (size_t)gridDim.x * blockDim.x;
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n / 4; idx += stride) {
    // One 16-byte vector load brings in 4 int32 values.
    int4 const v = reinterpret_cast<int4 const*>(src)[idx];

    // Dequantize to fp32.
    float fx = scale * static_cast<float>(v.x);
    float fy = scale * static_cast<float>(v.y);
    float fz = scale * static_cast<float>(v.z);
    float fw = scale * static_cast<float>(v.w);

    // Requantize: cvt.rni.sat.s8.f32 rounds to nearest-even and saturates to
    // the int8 range in a single PTX instruction.
    uint32_t qx, qy, qz, qw;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(qx) : "f"(fx));
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(qy) : "f"(fy));
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(qz) : "f"(fz));
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;\n" : "=r"(qw) : "f"(fw));

    // Pack the 4 low bytes into one 32-bit word and store it.
    char4 packed;
    packed.x = reinterpret_cast<int8_t const&>(qx);
    packed.y = reinterpret_cast<int8_t const&>(qy);
    packed.z = reinterpret_cast<int8_t const&>(qz);
    packed.w = reinterpret_cast<int8_t const&>(qw);
    reinterpret_cast<uint32_t*>(dst)[idx] = reinterpret_cast<uint32_t const&>(packed);
  }
}
| 55 | + |
| 56 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 57 | + |
// Host launcher: converts an s x b x h x d tensor of int32 to scaled int8 on
// the default stream. The element count is widened to size_t before the
// multiply to avoid 32-bit overflow for large tensors.
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d,
                                  float scale) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  size_t const total = (size_t)s * b * h * d;
  convert_int32_to_int8_kernel<<<kGridSize, kBlockSize>>>(dst, src, total, scale);
}
| 63 | + |
| 64 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 65 | + |
// Packs the 4 fp32 lanes of `f` into 4 elements of type T, returned as the
// unsigned-integer type whose size equals 4 * sizeof(T) (uint2 for 16-bit T,
// uint32_t for 8-bit T). Only declared here — defined solely through the
// explicit specializations below; using an unspecialized T is a link error.
template <typename T>
__device__ inline typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type pack_float4(
    float4 const& f);
| 69 | + |
| 70 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 71 | + |
// fp16 specialization: 4 floats -> 4 half values packed into a uint2.
template <>
__device__ inline uint2 pack_float4<fmha::fp16_t>(float4 const& f) {
  uint2 const packed = fmha::float4_to_half4(f.x, f.y, f.z, f.w);
  return packed;
}
| 76 | + |
| 77 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 78 | + |
// bf16 specialization: 4 floats -> 4 bfloat16 values packed into a uint2.
template <>
__device__ inline uint2 pack_float4<fmha::bf16_t>(float4 const& f) {
  uint2 const packed = fmha::float4_to_16bit_x4<fmha::bf16_t>(f.x, f.y, f.z, f.w);
  return packed;
}
| 83 | + |
| 84 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 85 | + |
// e4m3 specialization: 4 floats -> 4 fp8 (e4m3) values packed into a uint32_t.
template <>
__device__ inline uint32_t pack_float4<fmha::e4m3_t>(float4 const& f) {
  uint32_t const packed = fmha::float4_to_e4m3x4(f.x, f.y, f.z, f.w);
  return packed;
}
| 90 | + |
| 91 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
// e5m2 specialization: 4 floats -> 4 fp8 (e5m2) values packed into a uint32_t.
template <>
__device__ inline uint32_t pack_float4<fmha::e5m2_t>(float4 const& f) {
  uint32_t const packed = fmha::float4_to_e5m2x4(f.x, f.y, f.z, f.w);
  return packed;
}
| 96 | + |
| 97 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 98 | + |
// Generic fp32 -> T conversion kernel: loads 4 floats per iteration, applies a
// uniform scale, packs them with pack_float4<T>, and stores one packed word.
//
// NOTE(review): only the first n - n % 4 elements are converted — a non-
// multiple-of-4 tail is silently skipped. src is read as float4 (16B), so it
// must be 16-byte aligned.
template <typename T>
__global__ void convert_fp32_to_T_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
  // Output word type: 4 elements of T viewed as one unsigned integer.
  using Dst = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

  // Grid-stride loop over groups of 4 floats.
  size_t const stride = (size_t)gridDim.x * blockDim.x;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n / 4; i += stride) {
    // One vectorized load of 4 fp32 values, then scale each lane.
    float4 vals = reinterpret_cast<float4 const*>(src)[i];
    vals.x *= scale;
    vals.y *= scale;
    vals.z *= scale;
    vals.w *= scale;

    // Convert + pack the 4 scaled lanes into a single Dst word and store it.
    auto packed = pack_float4<T>(vals);
    reinterpret_cast<Dst*>(dst)[i] = reinterpret_cast<Dst const&>(packed);
  }
}
| 122 | + |
// Generic T -> fp32 conversion kernel: loads one packed word of 4 T elements,
// widens each to fp32, applies a uniform scale, and stores 4 floats at once.
//
// NOTE(review): only the first n - n % 4 elements are converted — a non-
// multiple-of-4 tail is silently skipped. dst is written as float4 (16B), so
// it must be 16-byte aligned.
template <typename T>
__global__ void convert_T_to_fp32_kernel(void* dst, void const* src, size_t n, float scale = 1.f) {
  // Input word type: 4 elements of T viewed as one unsigned integer.
  using Src = typename fmha::Uint_from_size_in_bytes<sizeof(T) * 4>::Type;

  // Union to reinterpret the packed word as its 4 individual T elements.
  union {
    Src raw;
    T elt[4];
  } in;

  // Grid-stride loop over groups of 4 elements.
  size_t const stride = (size_t)gridDim.x * blockDim.x;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n / 4; i += stride) {
    in.raw = reinterpret_cast<Src const*>(src)[i];

    // Widen each lane to fp32 and apply the scale.
    float4 vals;
    vals.x = scale * float(in.elt[0]);
    vals.y = scale * float(in.elt[1]);
    vals.z = scale * float(in.elt[2]);
    vals.w = scale * float(in.elt[3]);

    // One vectorized 16-byte store of the 4 converted values.
    reinterpret_cast<float4*>(dst)[i] = vals;
  }
}
| 150 | + |
| 151 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 152 | + |
// Host launcher: converts an s x b x h x d tensor of fp32 to fp16 on the
// default stream. The scale is fixed at 1.f — no scaling is exposed for the
// FP16/FP32 path.
void run_conversion_fp32_to_fp16(void* dst, void const* src, int s, int b, int h, int d) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  size_t const total = (size_t)s * b * h * d;
  convert_fp32_to_T_kernel<fmha::fp16_t><<<kGridSize, kBlockSize>>>(dst, src, total, 1.f);
}
| 158 | + |
| 159 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 160 | + |
// Host launcher: converts an s x b x h x d tensor of fp32 to bf16 on the
// default stream. The scale is fixed at 1.f — no scaling is exposed for the
// BF16/FP32 path.
void run_conversion_fp32_to_bf16(void* dst, void const* src, int s, int b, int h, int d) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  size_t const total = (size_t)s * b * h * d;
  convert_fp32_to_T_kernel<fmha::bf16_t><<<kGridSize, kBlockSize>>>(dst, src, total, 1.f);
}
| 166 | + |
| 167 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 168 | + |
// Host launcher: converts n fp32 values to e4m3, scaling each by scale_o.
void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  convert_fp32_to_T_kernel<fmha::e4m3_t><<<kGridSize, kBlockSize>>>(dst, src, n, scale_o);
}
| 172 | + |
| 173 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 174 | + |
// Host launcher: converts n e4m3 values to fp32, scaling each by scale_o.
void run_conversion_e4m3_to_fp32(void* dst, void const* src, size_t n, float scale_o) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  convert_T_to_fp32_kernel<fmha::e4m3_t><<<kGridSize, kBlockSize>>>(dst, src, n, scale_o);
}
| 178 | + |
| 179 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 180 | + |
// Host launcher: converts an s x b x h x d tensor of fp32 to e4m3, scaling
// each value by scale_o. Delegates to the size_t-based overload above.
//
// Fix: the element count is now accumulated in size_t — the original computed
// s * b * h * d entirely in int, which overflows before the implicit widening
// when the product exceeds INT_MAX. This matches every other shape-based
// launcher in this file, which already casts the first factor to size_t.
void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h, int d,
                                 float scale_o) {
  size_t const n = (size_t)s * b * h * d;
  run_conversion_fp32_to_e4m3(dst, src, n, scale_o);
}
| 185 | + |
| 186 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 187 | + |
// Host launcher: converts n fp32 values to e5m2, scaling each by scale_o.
void run_conversion_fp32_to_e5m2(void* dst, void const* src, size_t n, float scale_o) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  convert_fp32_to_T_kernel<fmha::e5m2_t><<<kGridSize, kBlockSize>>>(dst, src, n, scale_o);
}
| 191 | + |
| 192 | +//////////////////////////////////////////////////////////////////////////////////////////////////// |
| 193 | + |
// Host launcher: converts n e5m2 values to fp32, scaling each by scale_o.
void run_conversion_e5m2_to_fp32(void* dst, void const* src, size_t n, float scale_o) {
  constexpr int kGridSize = 512;
  constexpr int kBlockSize = 256;
  convert_T_to_fp32_kernel<fmha::e5m2_t><<<kGridSize, kBlockSize>>>(dst, src, n, scale_o);
}
0 commit comments