diff --git a/deepspeed/linear/config.py b/deepspeed/linear/config.py index 2632ce7de9c4..1459704a32c5 100644 --- a/deepspeed/linear/config.py +++ b/deepspeed/linear/config.py @@ -6,6 +6,8 @@ from dataclasses import dataclass, field from typing import List +import torch + @dataclass class LoRAConfig: @@ -13,7 +15,7 @@ class LoRAConfig: Configuration settings for LoRAOptimizedLinear. Attributes: - lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64. + lora_r (int): LoRA attention dimension, also known as the rank. Defaults is 64. lora_alpha (float): LoRA scaling factor, default is 16. base_weight_sharding (int): The degree to which the base weights are sharded, should typically be set to the data-parallel world size to maximize the memory @@ -42,8 +44,11 @@ class QuantizationConfig: Attributes: q_bits (int): The number of bits used for quantization. Default is 8. mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3. - group_size (int): The size of the group used for quantization. Default is 512. + group_size (int): The number of elements used for quantization. Default is 512. + q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as + uint8, but inside the kernels the quantization is done to fp8) """ q_bits: int = 8 mantissa_bits: int = 3 group_size: int = 512 + q_dtype: torch.dtype = torch.uint8 diff --git a/deepspeed/linear/quantization.py b/deepspeed/linear/quantization.py index 70fabea845ba..2023601be281 100644 --- a/deepspeed/linear/quantization.py +++ b/deepspeed/linear/quantization.py @@ -51,24 +51,24 @@ def __new__( self.quantizer = quantizer else: # if FPQuantizerBuilder is not compatible in this env this init will fail - self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size) + self.quantizer = FP_Quantize(quantization_config=self.quantization_config) self._ensure_quantized(self) return self def _ensure_quantized(self, tensor: torch.Tensor): # If the tensor is on the accelerator and is not quantized, then quantize it in-place. - if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8: + if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype: with get_accelerator().stream(get_accelerator().current_stream(tensor.device)): tensor.data = self.quantizer.quantize(tensor.data, q_bits=self.quantization_config.q_bits, q_mantisa_bits=self.quantization_config.mantissa_bits) - assert tensor.dtype == torch.uint8 + assert tensor.dtype == self.quantization_config.q_dtype def dequantized(self) -> torch.Tensor: """ Return a tensor containing the dequantized weights of this parameter. """ - if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8: + if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype: with get_accelerator().stream(get_accelerator().current_stream(self.data.device)): return self.quantizer.dequantize(self.data, q_bits=self.quantization_config.q_bits, diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index edd4ef57302c..1586f220907e 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -16,7 +16,7 @@ class Quantizer(ABC): """ - Abstract Quantizer class that implmenents quantize/dequantize methods. + Abstract Quantizer class that implements quantize/dequantize methods. Arguments: group_size (int, optional): number of values or elements that are grouped @@ -42,12 +42,18 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non class FP_Quantize(Quantizer): - def __init__(self, group_size=512) -> None: + def __init__(self, quantization_config) -> None: global fp_quant_module - super().__init__(group_size=group_size) + super().__init__(group_size=quantization_config.group_size) if fp_quant_module is None: fp_quant_module = FPQuantizerBuilder().load() + self.is_python_impl = getattr(fp_quant_module, "PYTHON_IMPL", False) + self.q_config = quantization_config + self.orig_dtype = None + self.num_groups = None + self.input_q = None + self.scale = None def quantize(self, input, @@ -73,15 +79,27 @@ def quantize(self, else: assert (0), \ f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" - self.num_groups = input.numel() // self.group_size - self.input_q = torch.ones(self.num_groups, - int(self.group_size * q_bits) // 8 + 4, - dtype=torch.uint8, - device=input.device) - out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) + + # Adding (group_size - 1) is for padding + self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size + # group_size should be the minimal number between the defined group size and number of elements in tensor. + group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8 + # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group + if not self.is_python_impl: + group_size += 4 + # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel. + self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device) + # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done + # because they are of different types. + self.scale = torch.ones(self.num_groups, 1, device=input.device) + out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits, + q_mantisa_bits) if return_meta_tensor: - data, self.scale = out.split(self.group_size, dim=-1) - data = data.contiguous().reshape(input.shape) + if not self.is_python_impl: + data, self.scale = out.split(group_size, dim=-1) + data = data.contiguous().reshape(input.shape) + else: + data = out.contiguous().reshape(input.shape) self.scale = self.scale.contiguous() del self.input_q del out @@ -93,9 +111,9 @@ def quantize(self, def to(self, *args, **kwargs): # Intermediate tensors may need to be moved to different devices - if hasattr(self, 'input_q'): + if hasattr(self, 'input_q') and self.input_q is not None: self.input_q = self.input_q.to(*args, **kwargs) - if hasattr(self, 'scale'): + if hasattr(self, 'scale') and self.scale is not None: self.scale = self.scale.to(*args, **kwargs) def get_scales(self): @@ -118,11 +136,16 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non assert (0), \ f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" - if scale is not None: + if scale is not None and not self.is_python_impl: assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' - input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() - fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) + input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous() + elif scale is not None and self.is_python_impl: + group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8 + input_q = input_q.reshape(-1, group_size) + + fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits, + q_bits - q_mantisa_bits - 1) return fp_out def selective_dequantize(self, @@ -151,11 +174,11 @@ def selective_dequantize(self, assert (0), \ f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" - if scale is not None: + if scale is not None and not self.is_python_impl: assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' - input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous() - fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index 75ee54c09bf6..daa41a8148f5 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -98,3 +98,20 @@ def extra_ldflags(self): def include_paths(self): return ['csrc/fp_quantizer/includes', 'csrc/includes'] + + @staticmethod + def get_default_quant_dtype(): + import torch + return torch.uint8 + + @staticmethod + def get_quant_range(q_bits=None): + if q_bits == 8: + return 480 + elif q_bits == 6: + return 28. + elif q_bits == 12: + return 510. + else: + assert (0), \ + "Please specify the right quantization range for the selected precision!" diff --git a/op_builder/hpu/fp_quantizer.py b/op_builder/hpu/fp_quantizer.py new file mode 100644 index 000000000000..b00cb0cc43cd --- /dev/null +++ b/op_builder/hpu/fp_quantizer.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Habana Labs, Ltd. an Intel Company +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class FPQuantizerBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_FP_QUANTIZER" + NAME = "fp_quantizer" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=name) + + def absolute_name(self): + return f'deepspeed.ops.fp_quantizer.{self.NAME}_op' + + def sources(self): + return [] + + def load(self, verbose=True): + return FPQuantizer + + @staticmethod + def get_default_quant_dtype(): + return torch.float8_e4m3fn + + @staticmethod + def get_quant_range(q_bits=None): + import habana_frameworks.torch.utils.experimental as htexp + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + dtype = torch.float8_e4m3fnuz + else: + dtype = torch.float8_e4m3fn + return torch.finfo(dtype).max + + +class FPQuantizer: + PYTHON_IMPL = True + + @classmethod + def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits): + assert False, "Selective dequantize isn't implemented for HPU!" + + @classmethod + def dequantize(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits): + orig_shape = fp_out.shape + orig_dtype = fp_out.dtype + dequant_out = torch.ops.hpu.cast_from_fp8(input_q, (1.0 / scale), orig_dtype).view(orig_shape) + fp_out.copy_(dequant_out) + return fp_out + + @classmethod + def quantize(cls, out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits): + assert q_bits == 8, "Quantize on HPU only supports quantization to FP8" + assert q_mantisa_bits == 3, "Quantize on HPU only supports q_mantissa_bits = 3" + assert out.dtype.is_floating_point, "Quantization on HPU is only to float dtypes" + + num_groups, group_size = out.shape + + # Reshape the tensor + val_reshaped = val.view(num_groups, group_size).float() + # Calculate the scale + max_vals = val_reshaped.abs().max(dim=1, keepdim=True)[0] + q_range = torch.finfo(out.dtype).max + tmp_scale = q_range / max_vals + scale.copy_(tmp_scale) + # Copy quantized + quant, _ = torch.ops.hpu.cast_to_fp8_v2(val_reshaped, scale, stochastic_rounding, dtype=out.dtype) + out.copy_(quant) + + return out + + @classmethod + def get_scales(cls, out, num_groups): + return out diff --git a/tests/unit/linear/test_linear.py b/tests/unit/linear/test_linear.py index ccd26b4cd726..2058791dba4a 100644 --- a/tests/unit/linear/test_linear.py +++ b/tests/unit/linear/test_linear.py @@ -46,7 +46,6 @@ class TestLoRALinear(DistributedTest): def test(self, base_weight_sharding): rank = dist.get_rank() - lora_config = None quantization_config = None input_features = 64 # Number of input features @@ -77,15 +76,13 @@ class TestQuantLinear(DistributedTest): world_size = 2 def test(self, q_bits): - rank = dist.get_rank() - lora_config = None - input_features = 64 # Number of input features output_features = 64 # Number of output features batch_size = 5 # Number of samples in a batch lora_config = None quantization_config = QuantizationConfig(q_bits=q_bits) + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() linear_layer = OptimizedLinear(input_dim=input_features, output_dim=output_features, @@ -106,15 +103,13 @@ class TestOptimizedLinear(DistributedTest): world_size = 2 def test(self, base_weight_sharding, q_bits): - rank = dist.get_rank() - lora_config = None - input_features = 64 # Number of input features output_features = 64 # Number of output features batch_size = 5 # Number of samples in a batch lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding) quantization_config = QuantizationConfig(q_bits=q_bits) + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() linear_layer = OptimizedLinear(input_dim=input_features, output_dim=output_features, diff --git a/tests/unit/linear/test_quant_param.py b/tests/unit/linear/test_quant_param.py index 84a9f766ef74..283d81b4bf36 100644 --- a/tests/unit/linear/test_quant_param.py +++ b/tests/unit/linear/test_quant_param.py @@ -38,11 +38,13 @@ def test_requires_grad(self): def test_move_to_accelerator(self): device = get_accelerator().current_device() data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16) - qp = QuantizedParameter(data) + quantization_config = QuantizationConfig() + quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + qp = QuantizedParameter(data, quantization_config=quantization_config) assert qp.device == torch.device('cpu') qp = qp.to(get_accelerator().current_device_name()) assert qp.device == torch.device(device) - assert qp.dtype == torch.uint8 + assert qp.dtype == quantization_config.q_dtype def test_hf_clone(self): device = get_accelerator().current_device_name() diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py index a4cf579f5943..ee7c5bc2d7f1 100644 --- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -15,6 +15,7 @@ from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 from deepspeed import get_accelerator +from deepspeed.linear import QuantizationConfig @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @@ -25,7 +26,11 @@ def test_fp_quant(dtype, q_bits, M): device_name = get_accelerator().device_name() quantization_group_size = 128 - fpq = FP_Quantize(group_size=quantization_group_size) + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = quantization_group_size + fpq = FP_Quantize(quantization_config=quant_config) N = 8192 H = 4096 diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index bed8bd7e3bcc..1550b6ec7eb2 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -5,10 +5,13 @@ import pytest import torch +from deepspeed.linear import QuantizationConfig + import deepspeed from deepspeed.ops.fp_quantizer import FP_Quantize from deepspeed.ops.op_builder import FPQuantizerBuilder +from deepspeed.accelerator import get_accelerator if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]: pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True) @@ -24,37 +27,39 @@ def qtorch_quantize(input, exp_bits=4, man_bits=3, rounding="nearest", group_siz input = input.view(-1, last_dim) q_bits = exp_bits + man_bits + 1 + q_range = FPQuantizerBuilder.get_quant_range(q_bits) input_to_float = input.float() - if q_bits == 8: - q_range = 480. - elif q_bits == 6: - q_range = 28. - elif q_bits == 12: - q_range = 510. - else: - assert (0), \ - "Please specify the right quantization range for the selected precision!" input_max = input_to_float.abs().amax(dim=-1, keepdim=True) + return ((float_quantize(input_to_float / input_max * q_range, exp_bits, man_bits, rounding=rounding) * \ input_max / q_range).to(ori_dt)).reshape(ori_shape) @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) def test_fp_quant_meta(dtype): + device_name = get_accelerator().device_name() group_size = 128 q_bits = 8 exp_bits = 4 man_bits = 3 - fpq = FP_Quantize(group_size=group_size) + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = group_size + fpq = FP_Quantize(quantization_config=quant_config) + for i in range(10): - x = torch.rand(4, 1024, dtype=dtype, device='cuda') + x = torch.rand(4, 1024, dtype=dtype) - ds_x = x.clone() + ds_x = x.clone().to(device_name) x_quantized, meta_tensor = fpq.quantize(ds_x, q_bits=q_bits, return_meta_tensor=True) x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits, scale=meta_tensor) - qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + qtorch_out = qtorch_quantize(x, + exp_bits=exp_bits, + man_bits=man_bits, + group_size=group_size, + quant_config=quant_config) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() ds_error = (x_dequantized - x).abs().sum() / x.numel() @@ -68,12 +73,18 @@ def test_fp_quant_selective(dtype): exp_bits = 4 man_bits = 3 - fpq = FP_Quantize(group_size=group_size) - indexes = torch.zeros(2, dtype=torch.int32, device='cuda') + device_name = get_accelerator().device_name() + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = group_size + fpq = FP_Quantize(quantization_config=quant_config) + + indexes = torch.zeros(2, dtype=torch.int32, device=device_name) indexes[0] = 1 indexes[1] = 3 for i in range(10): - x = torch.rand(4, 1024, dtype=dtype, device='cuda') + x = torch.rand(4, 1024, dtype=dtype, device=device_name) x = x.reshape(4, 1, x.shape[-1]) ds_x = x.clone() @@ -93,13 +104,17 @@ def test_fp_quant_selective(dtype): @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8, 6, 12], ids=["qbits8", "qbits6", "qbits12"]) def test_fp_quant(dtype, q_bits): - group_size = 128 - fpq = FP_Quantize(group_size=group_size) + device_name = get_accelerator().device_name() + + quant_config = QuantizationConfig() + quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype() + quant_config.group_size = 128 + fpq = FP_Quantize(quantization_config=quant_config) for i in range(10): - x = torch.rand(4, 1024, dtype=dtype, device='cuda') + x = torch.rand(4, 1024, dtype=dtype) - ds_x = x.clone() + ds_x = x.clone().to(device_name) x_quantized = fpq.quantize(ds_x, q_bits=q_bits) x_dequantized = fpq.dequantize(x_quantized, q_bits=q_bits) @@ -115,7 +130,11 @@ def test_fp_quant(dtype, q_bits): else: raise ValueError(f"unknown {q_bits=}") - qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) + qtorch_out = qtorch_quantize(x, + exp_bits=exp_bits, + man_bits=man_bits, + group_size=quant_config.group_size, + quant_config=quant_config) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() ds_error = (x_dequantized - x).abs().sum() / x.numel()