Generalize DeepSpeed linear and implement it for non-CUDA systems
oelayan7 committed Jan 8, 2025
1 parent c41b0c2 commit 485492d
Showing 9 changed files with 208 additions and 56 deletions.
9 changes: 7 additions & 2 deletions deepspeed/linear/config.py
@@ -6,14 +6,16 @@
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class LoRAConfig:
"""
Configuration settings for LoRAOptimizedLinear.
Attributes:
lora_r (int): LoRA attention dimension, also know as the rank. Defaults is 64.
lora_r (int): LoRA attention dimension, also known as the rank. Defaults is 64.
lora_alpha (float): LoRA scaling factor, default is 16.
base_weight_sharding (int): The degree to which the base weights are sharded,
should typically be set to the data-parallel world size to maximize the memory
@@ -42,8 +44,11 @@ class QuantizationConfig:
Attributes:
q_bits (int): The number of bits used for quantization. Default is 8.
mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
group_size (int): The size of the group used for quantization. Default is 512.
group_size (int): The number of elements used for quantization. Default is 512.
q_dtype (torch.dtype): The data type to quantize to. Default is uint8. (in CUDA, buffers are allocated as
uint8, but inside the kernels the quantization is done to fp8)
"""
q_bits: int = 8
mantissa_bits: int = 3
group_size: int = 512
q_dtype: torch.dtype = torch.uint8
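
As a usage sketch (not part of the diff), the updated config can be constructed with the new q_dtype field; the import path follows the test files further down in this commit:

    import torch
    from deepspeed.linear import QuantizationConfig

    # CUDA default shown above: buffers are allocated as uint8 even though the
    # kernels quantize to fp8 internally.
    quant_config = QuantizationConfig(q_bits=8, mantissa_bits=3, group_size=512, q_dtype=torch.uint8)
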
8 changes: 4 additions & 4 deletions deepspeed/linear/quantization.py
@@ -51,24 +51,24 @@ def __new__(
self.quantizer = quantizer
else:
# if FPQuantizerBuilder is not compatible in this env this init will fail
self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
self.quantizer = FP_Quantize(quantization_config=self.quantization_config)
self._ensure_quantized(self)
return self

def _ensure_quantized(self, tensor: torch.Tensor):
# If the tensor is on the accelerator and is not quantized, then quantize it in-place.
if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.uint8:
if get_accelerator().on_accelerator(tensor) and tensor.dtype != self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
tensor.data = self.quantizer.quantize(tensor.data,
q_bits=self.quantization_config.q_bits,
q_mantisa_bits=self.quantization_config.mantissa_bits)
assert tensor.dtype == torch.uint8
assert tensor.dtype == self.quantization_config.q_dtype

def dequantized(self) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.uint8:
if get_accelerator().on_accelerator(self.data) and self.data.dtype == self.quantization_config.q_dtype:
with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
return self.quantizer.dequantize(self.data,
q_bits=self.quantization_config.q_bits,
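
A minimal sketch (not part of the diff) of how a quantized parameter now takes its storage dtype from the config rather than a hard-coded torch.uint8; the QuantizedParameter and FPQuantizerBuilder import paths are assumptions:

    import torch
    from deepspeed.accelerator import get_accelerator
    from deepspeed.linear import QuantizationConfig
    from deepspeed.linear.quantization import QuantizedParameter  # assumed import path
    from deepspeed.ops.op_builder import FPQuantizerBuilder       # assumed import path

    quant_config = QuantizationConfig()
    quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()  # uint8 on CUDA, fp8 on HPU

    weight = torch.rand(64, 64, dtype=torch.bfloat16)
    qp = QuantizedParameter(weight, quantization_config=quant_config)
    qp = qp.to(get_accelerator().current_device_name())  # quantized in-place to quant_config.q_dtype
    restored = qp.dequantized()                           # dense floating-point view of the weights
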
61 changes: 42 additions & 19 deletions deepspeed/ops/fp_quantizer/quantize.py
@@ -16,7 +16,7 @@

class Quantizer(ABC):
"""
Abstract Quantizer class that implmenents quantize/dequantize methods.
Abstract Quantizer class that implements quantize/dequantize methods.
Arguments:
group_size (int, optional): number of values or elements that are grouped
@@ -42,12 +42,18 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non

class FP_Quantize(Quantizer):

def __init__(self, group_size=512) -> None:
def __init__(self, quantization_config) -> None:
global fp_quant_module
super().__init__(group_size=group_size)
super().__init__(group_size=quantization_config.group_size)
if fp_quant_module is None:
fp_quant_module = FPQuantizerBuilder().load()
self.is_python_impl = getattr(fp_quant_module, "PYTHON_IMPL", False)
self.q_config = quantization_config

self.orig_dtype = None
self.num_groups = None
self.input_q = None
self.scale = None

def quantize(self,
input,
@@ -73,15 +79,27 @@ def quantize(self,
else:
assert (0), \
f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!"
self.num_groups = input.numel() // self.group_size
self.input_q = torch.ones(self.num_groups,
int(self.group_size * q_bits) // 8 + 4,
dtype=torch.uint8,
device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits)

# Adding (group_size - 1) is for padding
self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size
# group_size should be the minimal number between the defined group size and number of elements in tensor.
group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8
# CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group
if not self.is_python_impl:
group_size += 4
# CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel.
self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device)
# CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done
# because they are of different types.
self.scale = torch.ones(self.num_groups, 1, device=input.device)
out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits,
q_mantisa_bits)
if return_meta_tensor:
data, self.scale = out.split(self.group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
if not self.is_python_impl:
data, self.scale = out.split(group_size, dim=-1)
data = data.contiguous().reshape(input.shape)
else:
data = out.contiguous().reshape(input.shape)
self.scale = self.scale.contiguous()
del self.input_q
del out
@@ -93,9 +111,9 @@ def quantize(self,

def to(self, *args, **kwargs):
# Intermediate tensors may need to be moved to different devices
if hasattr(self, 'input_q'):
if hasattr(self, 'input_q') and self.input_q is not None:
self.input_q = self.input_q.to(*args, **kwargs)
if hasattr(self, 'scale'):
if hasattr(self, 'scale') and self.scale is not None:
self.scale = self.scale.to(*args, **kwargs)

def get_scales(self):
@@ -118,11 +136,16 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and not self.is_python_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1)
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()
elif scale is not None and self.is_python_impl:
group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8
input_q = input_q.reshape(-1, group_size)

fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out

def selective_dequantize(self,
@@ -151,11 +174,11 @@ def selective_dequantize(self,
assert (0), \
f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!"

if scale is not None:
if scale is not None and not self.is_python_impl:
assert input_q.numel() == fp_out.numel(), \
f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!'
input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous()
input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous()

fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits,
fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits,
q_bits - q_mantisa_bits - 1)
return fp_out
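
The buffer-sizing arithmetic introduced above can be illustrated with hypothetical numbers (a sketch, not part of the commit):

    numel, group_size_cfg, q_bits = 1000, 512, 8

    num_groups = (numel + group_size_cfg - 1) // group_size_cfg    # 2: the last group is padded
    group_bytes = int(min(group_size_cfg, numel) * q_bits) // 8    # 512 bytes of quantized data per group
    cuda_group_bytes = group_bytes + 4                             # CUDA kernels append one fp32 scale per group
    # The Python (non-CUDA) path instead keeps the scales in a separate (num_groups, 1) fp32 tensor,
    # since the quantized buffer is allocated directly in an fp8 dtype and cannot hold fp32 scales.
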
17 changes: 17 additions & 0 deletions op_builder/fp_quantizer.py
@@ -9,6 +9,7 @@
pkg_version = None

from .builder import CUDAOpBuilder, installed_cuda_version
import torch


class FPQuantizerBuilder(CUDAOpBuilder):
@@ -98,3 +99,19 @@ def extra_ldflags(self):

def include_paths(self):
return ['csrc/fp_quantizer/includes', 'csrc/includes']

@staticmethod
def get_default_quant_dtype():
return torch.uint8

@staticmethod
def get_quant_range(q_bits=None):
if q_bits == 8:
return 480
elif q_bits == 6:
return 28.
elif q_bits == 12:
return 510.
else:
assert (0), \
"Please specify the right quantization range for the selected precision!"
86 changes: 86 additions & 0 deletions op_builder/hpu/fp_quantizer.py
@@ -0,0 +1,86 @@
# Copyright (c) 2024 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
try:
# is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
# if successful this also means we're doing a local install and not JIT compile path
from op_builder import __deepspeed__ # noqa: F401 # type: ignore
from op_builder.builder import OpBuilder
except ImportError:
from deepspeed.ops.op_builder.builder import OpBuilder


class FPQuantizerBuilder(OpBuilder):
BUILD_VAR = "DS_BUILD_FP_QUANTIZER"
NAME = "fp_quantizer"

def __init__(self, name=None):
name = self.NAME if name is None else name
super().__init__(name=name)

def absolute_name(self):
return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'

def sources(self):
return []

def load(self, verbose=True):
return FPQuantizer

@staticmethod
def get_default_quant_dtype():
return torch.float8_e4m3fn

@staticmethod
def get_quant_range(q_bits=None):
import habana_frameworks.torch.utils.experimental as htexp
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
dtype = torch.float8_e4m3fnuz
else:
dtype = torch.float8_e4m3fn
return torch.finfo(dtype).max


class FPQuantizer:
PYTHON_IMPL = True

@classmethod
def selective_dequantize(cls, val_q, scales, indexes, group_size, q_mantisa_bits, q_exponent_bits):
assert False, "Selective dequantize isn't implemented for HPU!"

@classmethod
def dequantize(cls, fp_out, input_q, scale, group_size, q_mantisa_bits, q_exponent_bits):
orig_shape = fp_out.shape
orig_dtype = fp_out.dtype
dequant_out = torch.ops.hpu.cast_from_fp8(input_q, (1.0 / scale), orig_dtype).view(orig_shape)
fp_out.copy_(dequant_out)
return fp_out

@classmethod
def quantize(cls, out, val, scale, group_size, stochastic_rounding, q_bits, q_mantisa_bits):
assert q_bits == 8, "Quantize on HPU only supports quantization to FP8"
assert q_mantisa_bits == 3, "Quantize on HPU only supports q_mantissa_bits = 3"
assert out.dtype.is_floating_point, "Quantization on HPU is only to float dtypes"

num_groups, group_size = out.shape

# Reshape the tensor
val_reshaped = val.view(num_groups, group_size).float()
# Calculate the scale
max_vals = val_reshaped.abs().max(dim=1, keepdim=True)[0]
q_range = torch.finfo(out.dtype).max
tmp_scale = q_range / max_vals
scale.copy_(tmp_scale)
# Copy quantized
quant, _ = torch.ops.hpu.cast_to_fp8_v2(val_reshaped, scale, stochastic_rounding, dtype=out.dtype)
out.copy_(quant)

return out

@classmethod
def get_scales(cls, out, num_groups):
return out
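
The per-group scale computation above can be mimicked in plain PyTorch (an illustration only; the HPU cast_to_fp8_v2/cast_from_fp8 ops are replaced by an explicit scale-and-cast, and float8_e4m3fn stands in for the device-specific fp8 dtype):

    import torch

    num_groups, group_size = 4, 128
    val = torch.randn(num_groups * group_size, dtype=torch.bfloat16)

    val_reshaped = val.view(num_groups, group_size).float()
    max_vals = val_reshaped.abs().max(dim=1, keepdim=True)[0]
    q_range = torch.finfo(torch.float8_e4m3fn).max      # Gaudi2 would use float8_e4m3fnuz instead
    scale = q_range / max_vals                          # one scale per group
    quant = (val_reshaped * scale).to(torch.float8_e4m3fn)
    dequant = quant.float() * (1.0 / scale)             # mirrors the cast_from_fp8 path in dequantize()
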
9 changes: 2 additions & 7 deletions tests/unit/linear/test_linear.py
@@ -46,7 +46,6 @@ class TestLoRALinear(DistributedTest):

def test(self, base_weight_sharding):
rank = dist.get_rank()
lora_config = None
quantization_config = None

input_features = 64 # Number of input features
@@ -77,15 +76,13 @@ class TestQuantLinear(DistributedTest):
world_size = 2

def test(self, q_bits):
rank = dist.get_rank()
lora_config = None

input_features = 64 # Number of input features
output_features = 64 # Number of output features
batch_size = 5 # Number of samples in a batch

lora_config = None
quantization_config = QuantizationConfig(q_bits=q_bits)
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()

linear_layer = OptimizedLinear(input_dim=input_features,
output_dim=output_features,
@@ -106,15 +103,13 @@ class TestOptimizedLinear(DistributedTest):
world_size = 2

def test(self, base_weight_sharding, q_bits):
rank = dist.get_rank()
lora_config = None

input_features = 64 # Number of input features
output_features = 64 # Number of output features
batch_size = 5 # Number of samples in a batch

lora_config = LoRAConfig(lora_r=16, lora_alpha=16, base_weight_sharding=base_weight_sharding)
quantization_config = QuantizationConfig(q_bits=q_bits)
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()

linear_layer = OptimizedLinear(input_dim=input_features,
output_dim=output_features,
6 changes: 4 additions & 2 deletions tests/unit/linear/test_quant_param.py
@@ -38,11 +38,13 @@ def test_requires_grad(self):
def test_move_to_accelerator(self):
device = get_accelerator().current_device()
data = torch.rand(5, 5, device='cpu', dtype=torch.bfloat16)
qp = QuantizedParameter(data)
quantization_config = QuantizationConfig()
quantization_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
qp = QuantizedParameter(data, quantization_config=quantization_config)
assert qp.device == torch.device('cpu')
qp = qp.to(get_accelerator().current_device_name())
assert qp.device == torch.device(device)
assert qp.dtype == torch.uint8
assert qp.dtype == quantization_config.q_dtype

def test_hf_clone(self):
device = get_accelerator().current_device_name()
7 changes: 6 additions & 1 deletion tests/unit/ops/fp_quantizer/test_fp8_gemm.py
@@ -15,6 +15,7 @@
from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8

from deepspeed import get_accelerator
from deepspeed.linear import QuantizationConfig


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@@ -25,7 +26,11 @@
def test_fp_quant(dtype, q_bits, M):
device_name = get_accelerator().device_name()
quantization_group_size = 128
fpq = FP_Quantize(group_size=quantization_group_size)

quant_config = QuantizationConfig()
quant_config.q_dtype = FPQuantizerBuilder.get_default_quant_dtype()
quant_config.group_size = quantization_group_size
fpq = FP_Quantize(quantization_config=quant_config)

N = 8192
H = 4096
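
For reference, a minimal round-trip through the reconfigured quantizer might look like the sketch below (the return value of quantize with return_meta_tensor=True is an assumption inferred from the quantize.py hunks above):

    x = torch.rand(M, H, dtype=dtype, device=device_name)
    x_q, scale = fpq.quantize(x, q_bits=q_bits, return_meta_tensor=True)
    x_dq = fpq.dequantize(x_q, q_bits=q_bits, scale=scale)
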