From 28744fd8c256bb16b8c12ba735666412c9d86d94 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Wed, 26 Jun 2024 19:51:36 +0000
Subject: [PATCH] apply g_idx

---
 .../modifiers/quantization/gptq/pytorch.py  |  2 +-
 .../quantization/gptq/utils/gptq_wrapper.py | 74 ++++++++++++-------
 .../modifiers/utils/layer_compressor.py     |  6 +-
 3 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/src/sparseml/modifiers/quantization/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py
index 66898688f12..e9e3f715625 100644
--- a/src/sparseml/modifiers/quantization/gptq/pytorch.py
+++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -156,7 +156,7 @@ def apply_compression(
             layer_compressor.pre_compress()
             _LOGGER.info(f"Calibrating {layer_compressor.name}...")
             run_calibration_forward(self.model, dataloader, mask_padding=True)
-            layer_compressor.compress(self.actorder)
+            layer_compressor.compress()
             layer_compressor.post_compress()
             layer_compressor.revert_layer_wrappers()
             torch.cuda.empty_cache()
diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 1745b7c802b..c2b72ffa487 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -83,7 +83,6 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
 
     def fasterprune(
         self,
-        actorder: bool = False,
         blocksize: int = 128,
         percdamp: float = 0.01,
     ):
@@ -121,13 +120,6 @@ def fasterprune(
         self.H[dead, dead] = 1
         W[:, dead] = 0
 
-        # Or read from self.layer.quantization_scheme
-        if actorder:
-            perm = torch.argsort(torch.diag(self.H), descending=True)
-            W = W[:, perm]
-            self.H = self.H[perm][:, perm]
-            invperm = torch.argsort(perm)
-
         Losses = torch.zeros(self.rows, device=self.dev)
 
         damp = percdamp * torch.mean(torch.diag(self.H))
@@ -138,6 +130,9 @@ def fasterprune(
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
+        actorder = False
+        invperm = None
+
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -171,7 +166,15 @@ def fasterprune(
 
                 elif hasattr(self.layer, "quantization_scheme"):
                     quant_scheme = self.layer.quantization_scheme
+                    actorder = quant_scheme.weights.actorder
                     if quant_scheme.weights is not None:
+
+                        if actorder:
+                            perm = torch.argsort(torch.diag(self.H), descending=True)
+                            W = W[:, perm]
+                            self.H = self.H[perm][:, perm]
+                            invperm = torch.argsort(perm)
+
                         scale = self.layer.weight_scale
                         zero_point = self.layer.weight_zero_point
 
@@ -180,23 +183,15 @@ def fasterprune(
                             group_size = self.layer.weight.shape[1]
 
                         if actorder:
-                            g_idx = torch.tensor(
-                                [perm[j] // group_size for j in range(self.columns)],
-                                dtype=torch.int32,
-                                device=invperm.device
-                            )
-
+                            indices = torch.arange(self.columns, device=invperm.device)
+                            g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
                             g_idx = g_idx[invperm]
-                            self.layer.weight_g_idx = Parameter(
-                                g_idx,
-                                requires_grad=False,
-                            )
+                            self.layer.weight_g_idx.data = g_idx
                         else:
-                            g_idx = torch.tensor(
-                                [j // group_size for j in range(self.columns)],
-                                dtype=torch.int32,
-                                device=W.device
+                            indices = torch.arange(
+                                self.columns, device=W.device, dtype=torch.int32
                             )
+                            g_idx = indices // group_size
 
                         from compressed_tensors.quantization import QuantizationStrategy
                         from compressed_tensors.quantization.lifecycle.forward import (
@@ -217,7 +212,6 @@ def fasterprune(
                                 q,
                                 scale[:, 0],
                                 zero_point[:, 0],
-                                # g_idx,
                                 quant_scheme.weights,
                             )
                         else:  # strategy == QuantizationStrategy.GROUP
@@ -230,11 +224,16 @@
                             # ends up being a channelwise application
                             altered_qargs = copy(quant_scheme.weights)
                             altered_qargs.strategy = QuantizationStrategy.CHANNEL
-
-                            # # apply g_idx
-                            # if g_idx is not None:
-                            #     scale = scale[g_idx]
-                            #     zero_point = zero_point[g_idx]
+
+                            # apply g_idx
+                            if g_idx is not None:
+                                # scale and zero_point are already stored per group
+                                # (one column per group); take the first index of each
+                                indices_to_extract = torch.arange(
+                                    0, g_idx.shape[0], group_size
+                                )
+                                scale = scale[:, g_idx[indices_to_extract]]
+                                zero_point = zero_point[:, g_idx[indices_to_extract]]
 
                             q = fake_quantize(
                                 q,
@@ -284,3 +283,22 @@ def free(self):
         """
         delattr(self, "H")
         super().free()
+
+
+"""
+(Pdb) scale.shape
+torch.Size([4096, 32])
+(Pdb) self.layer.shape
+*** AttributeError: 'Linear' object has no attribute 'shape'
+(Pdb) self.layer.weight.shape
+torch.Size([4096, 4096])
+
+
+
+(Pdb) scale.shape
+torch.Size([11008, 32])
+(Pdb) self.layer.weight.shape
+torch.Size([11008, 4096])
+
+
+"""
diff --git a/src/sparseml/modifiers/utils/layer_compressor.py b/src/sparseml/modifiers/utils/layer_compressor.py
index 2d7fdf53e00..eb0b51cf269 100644
--- a/src/sparseml/modifiers/utils/layer_compressor.py
+++ b/src/sparseml/modifiers/utils/layer_compressor.py
@@ -131,10 +131,10 @@ def revert_layer_wrappers(self):
             module_wrapper.free()
         self.modules = None
 
-    def compress(self, actorder: bool = False):
+    def compress(self):
         """
         Apply compression to each wrapped submodule in the layer
-
+
         :param: actorder: flag to apply activation reordering
         """
 
@@ -143,7 +143,7 @@ def prune(module):
             if isinstance(module, self.module_compressor_class):
                 full_name = self._get_full_submodule_name(module.name)
                 _LOGGER.info(f"Compressing {full_name}...")
-                module.fasterprune(actorder=actorder, **self.args)
+                module.fasterprune(**self.args)
 
         self.layer.apply(prune)
 
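
Reviewer note (not part of the patch): a minimal, self-contained sketch of the bookkeeping the diff relies on, using made-up toy sizes. It assumes the contiguous mapping from the non-actorder `else:` branch (`g_idx = indices // group_size`) and mirrors the GROUP-strategy gather via `indices_to_extract`; the shape comments follow the Pdb notes above, where `scale` has shape `[out_features, num_groups]`. Variable names and sizes here are illustrative only.

# Illustrative sketch only; not part of the commit. Toy sizes stand in for the
# shapes in the Pdb notes (e.g. out_features=4096, columns=4096, group_size=128,
# num_groups=32).
import torch

out_features, columns, group_size = 4, 8, 2
num_groups = columns // group_size

# Group-wise quantization parameters: one scale / zero-point column per group,
# i.e. scale.shape == [out_features, num_groups].
scale = torch.rand(out_features, num_groups)
zero_point = torch.zeros(out_features, num_groups, dtype=torch.int32)

# Contiguous mapping from the `else:` branch of the diff: input column j
# belongs to quantization group j // group_size.
indices = torch.arange(columns, dtype=torch.int32)
g_idx = indices // group_size          # [0, 0, 1, 1, 2, 2, 3, 3]

# GROUP-strategy gather from the diff: take the group id of the first column
# in each block of group_size columns, then pick the matching scale and
# zero-point columns for that block.
indices_to_extract = torch.arange(0, g_idx.shape[0], group_size)
scale_per_block = scale[:, g_idx[indices_to_extract]]
zero_point_per_block = zero_point[:, g_idx[indices_to_extract]]

assert scale_per_block.shape == (out_features, num_groups)
print("g_idx:", g_idx.tolist())
print("gathered groups:", g_idx[indices_to_extract].tolist())   # [0, 1, 2, 3]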