
Commit a98187c

[Kernel] Make static FP8 scaling more robust (vllm-project#4570)
Previously, FP8 static scaling only worked if the scales overestimated the maxima of all activation tensors encountered during computation. However, this will not always be the case, even if the scales were calibrated very carefully. For example, with the activation scales in my checkpoint https://huggingface.co/pcmoritz/Mixtral-8x7B-v0.1-fp8-act-scale (which was calibrated on https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), I'm getting the following mostly random performance on MMLU:

| Groups            |Version|Filter|n-shot|Metric|Value |   |Stderr|
|-------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu               |N/A    |none  |     0|acc   |0.2295|±  |0.0035|
| - humanities      |N/A    |none  |     5|acc   |0.2421|±  |0.0062|
| - other           |N/A    |none  |     5|acc   |0.2398|±  |0.0076|
| - social_sciences |N/A    |none  |     5|acc   |0.2171|±  |0.0074|
| - stem            |N/A    |none  |     5|acc   |0.2125|±  |0.0073|

With the fix in this PR, where the scaled activations are clamped to [-std::numeric_limits<c10::Float8_e4m3fn>::max(), std::numeric_limits<c10::Float8_e4m3fn>::max()] to make sure there are no NaNs, the performance is:

| Groups            |Version|Filter|n-shot|Metric|Value |   |Stderr|
|-------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu               |N/A    |none  |     0|acc   |0.7008|±  |0.0036|
| - humanities      |N/A    |none  |     5|acc   |0.6453|±  |0.0065|
| - other           |N/A    |none  |     5|acc   |0.7692|±  |0.0072|
| - social_sciences |N/A    |none  |     5|acc   |0.8083|±  |0.0070|
| - stem            |N/A    |none  |     5|acc   |0.6115|±  |0.0083|

This is not perfect yet, but it is getting very close to the FP16 / dynamic activation scale performance.
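To make the failure mode concrete, here is a minimal host-side sketch (not part of this commit; the scale and activation values are made up for illustration) of how an underestimated static scale pushes a scaled activation past the FP8 E4M3 range, and how the clamp added in this PR saturates it instead. 448 is the value of std::numeric_limits<c10::Float8_e4m3fn>::max().

```cpp
// Minimal host-side sketch with hypothetical numbers; real scales come from
// the calibration checkpoint linked above.
#include <algorithm>
#include <cstdio>

int main() {
  const float FP8_E4M3_MAX = 448.0f;
  const float calibrated_max = 12.0f;                 // abs-max seen during calibration (hypothetical)
  const float scale = calibrated_max / FP8_E4M3_MAX;  // static per-tensor scale

  const float activation = 14.5f;           // runtime activation slightly above the calibrated max
  const float scaled = activation / scale;  // ~541.3, outside the representable E4M3 range

  // Old behavior: casting an out-of-range value to float8_e4m3fn (which has
  // no Inf encoding) produces NaN, which then poisons the forward pass.
  // New behavior: clamp first, so the conversion saturates to +/-448 instead.
  const float clamped = std::max(-FP8_E4M3_MAX, std::min(scaled, FP8_E4M3_MAX));

  std::printf("scaled = %.1f, clamped = %.1f\n", scaled, clamped);
  return 0;
}
```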
1 parent bd99d22 commit a98187c

File tree

1 file changed: +10 -1 lines changed


csrc/quantization/fp8/fp8_cuda_kernels.cu

Lines changed: 10 additions & 1 deletion
@@ -17,6 +17,15 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
+#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
+
+template<typename scalar_t>
+__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) {
+  float x = static_cast<float>(val) / scale;
+  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  return static_cast<c10::Float8_e4m3fn>(r);
+}
+
 // Compute the absolute maximum m of the input tensor and store
 // m / float8_e4m3::max() in *scale. Each thread block performs a
 // reduction tree and the memory in scale is atomically updated.
@@ -67,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel(
     int64_t num_elems) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   while (i < num_elems) {
-    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
+    out[i] = scaled_fp8_conversion(input[i], *scale);
     i += blockDim.x * gridDim.x;
   }
 }
