From 9749362da094e4f96b9996127deab3053f5d1f63 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Thu, 9 May 2024 21:23:51 +0000
Subject: [PATCH 1/4] test

---
 .../quantization/gptq/utils/gptq_wrapper.py | 20 ++++++++++++++++++++
 .../obcq/test_consecutive_runs.py           |  2 +-
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 215560b230b..f3ccc1aa8cf 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -92,6 +92,7 @@ def fasterprune(
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
         W = self.layer.weight.data.clone()
+        from sparseml.pytorch.utils.helpers import tensor_sparsity
 
         if isinstance(self.layer, nn.Conv2d):
             W = W.flatten(1)
@@ -115,6 +116,14 @@ def fasterprune(
             self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
+        mask = torch.where(
+            W == 0,
+            torch.tensor(1, dtype=torch.bool),
+            torch.tensor(0, dtype=torch.bool),
+        )
+        sparsity = tensor_sparsity(W)
+
+        # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -126,11 +135,22 @@
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
 
+            tmp = (
+                (~mask[:, i1:i2])
+                * W1**2
+                / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+            )
+            thresh = torch.sort(tmp.flatten())[0][
+                int(tmp.numel() * sparsity)
+            ]
+            mask1 = tmp <= thresh
+
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
 
                 q = w.clone()
+                q[mask1[:, i]] = 0
 
                 if hasattr(self.layer, "weight_fake_quant"):
                     scale = self.layer.weight_fake_quant.scale

diff --git a/tests/sparseml/transformers/obcq/test_consecutive_runs.py b/tests/sparseml/transformers/obcq/test_consecutive_runs.py
index 04b78ec82b8..7bcfc8b7efe 100644
--- a/tests/sparseml/transformers/obcq/test_consecutive_runs.py
+++ b/tests/sparseml/transformers/obcq/test_consecutive_runs.py
@@ -114,7 +114,7 @@ def setUp(self):
         self.output_second = Path(self.output) / "test_2"
 
     def test_consecutive_runs_small(self):
-        self._test_consecutive_runs(tolerance=1e-1)
+        self._test_consecutive_runs(tolerance=1e-3)
 
 
 @requires_gpu

From 77ad1a27f8cd045ff763ab7b453a8a7882c25b06 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 13 May 2024 13:40:48 +0000
Subject: [PATCH 2/4] Preserve weight sparsity if greater than threshold

---
 .../quantization/gptq/utils/gptq_wrapper.py | 36 ++++++++++++++++++++----------------
 src/sparseml/modifiers/utils/__init__.py    |  4 ++++
 src/sparseml/modifiers/utils/constants.py   | 18 ++++++++++++++++++
 3 files changed, 42 insertions(+), 16 deletions(-)
 create mode 100644 src/sparseml/modifiers/utils/constants.py

diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index f3ccc1aa8cf..3dce40cecc5 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -14,6 +14,7 @@
 
 import time
 
+from sparseml.modifiers.utils import SPARSITY_THRESHOLD
 from sparseml.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
 
 
@@ -116,13 +117,16 @@ def fasterprune(
             self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
-        mask = torch.where(
-            W == 0,
-            torch.tensor(1, dtype=torch.bool),
-            torch.tensor(0, dtype=torch.bool),
-        )
         sparsity = tensor_sparsity(W)
-
+        mask = (
+            torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
+            if sparsity >= SPARSITY_THRESHOLD
+            else None
+        )
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -135,22 +139,22 @@
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
 
-            tmp = (
-                (~mask[:, i1:i2])
-                * W1**2
-                / (torch.diag(Hinv1).reshape((1, -1))) ** 2
-            )
-            thresh = torch.sort(tmp.flatten())[0][
-                int(tmp.numel() * sparsity)
-            ]
-            mask1 = tmp <= thresh
+            if sparsity >= SPARSITY_THRESHOLD:
+                tmp = (
+                    (~mask[:, i1:i2])
+                    * W1**2
+                    / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                )
+                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
+                mask1 = tmp <= thresh
 
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
 
                 q = w.clone()
-                q[mask1[:, i]] = 0
+                if sparsity >= SPARSITY_THRESHOLD:
+                    q[mask1[:, i]] = 0
 
                 if hasattr(self.layer, "weight_fake_quant"):
                     scale = self.layer.weight_fake_quant.scale

diff --git a/src/sparseml/modifiers/utils/__init__.py b/src/sparseml/modifiers/utils/__init__.py
index 0c44f887a47..39d1132f697 100644
--- a/src/sparseml/modifiers/utils/__init__.py
+++ b/src/sparseml/modifiers/utils/__init__.py
@@ -11,3 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# flake8: noqa
+
+from .constants import *

diff --git a/src/sparseml/modifiers/utils/constants.py b/src/sparseml/modifiers/utils/constants.py
new file mode 100644
index 00000000000..3801c2e9ea9
--- /dev/null
+++ b/src/sparseml/modifiers/utils/constants.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__all__ = ["SPARSITY_THRESHOLD"]
+
+SPARSITY_THRESHOLD: float = 0.05

From 40facd9b77602b5854c678874f6694bd03ef0f8f Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 13 May 2024 13:56:37 +0000
Subject: [PATCH 3/4] Add argument to preserve sparsity mask in SPARSEGPT

---
 src/sparseml/modifiers/obcq/base.py           |  4 ++
 src/sparseml/modifiers/obcq/pytorch.py        |  1 +
 .../modifiers/obcq/utils/sgpt_wrapper.py      | 41 ++++++++++++++++++++++++++++++++++++++---
 3 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index c1534618302..74920d0d697 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -54,6 +54,9 @@ class SparseGPTModifier(Modifier):
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
+    :param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
+        when applying SparseGPT; this becomes useful when starting from a
+        previously pruned model, defaults to False.
     """
 
     sparsity: Union[float, List[float]] = 0.0
@@ -68,6 +71,7 @@ class SparseGPTModifier(Modifier):
     prunem_: Optional[int] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    preserve_sparsity_mask: bool = False
 
     def on_initialize_structure(self, state: State, **kwargs):
         """

diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index 4825eed1a92..ec9dfd90d23 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -203,6 +203,7 @@ def _compression_arguments(self, sparsity):
             "prunem": self.prunem_,
             "blocksize": self.block_size,
             "percdamp": self.dampening_frac,
+            "preserve_sparsity_mask": self.preserve_sparsity_mask,
         }
 
     def _compression_class(self):

diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index d8a95f18853..0842dd8aab8 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -84,6 +84,7 @@ def fasterprune(
         prunem: int = 0,
         blocksize: int = 128,
         percdamp: float = 0.01,
+        preserve_sparsity_mask: bool = False,
     ):
         """
         Run pruning and quantization(if applicable) on the layer up to the target
@@ -95,6 +96,7 @@ def fasterprune(
         :param blocksize: Number of columns to compress in one pass
         :param percdamp: Amount of dampening to apply to H, as a fraction of the
             diagonal norm
+        :param preserve_sparsity_mask: Extend or ignore the base sparsity mask
         """
         final_shape = self.layer.weight.shape
         final_dtype = self.layer.weight.dtype
@@ -123,6 +125,13 @@ def fasterprune(
         Hinv = self.H
 
         mask = None
+        if preserve_sparsity_mask:
+            # compute existing sparsity mask
+            mask = torch.where(
+                W == 0,
+                torch.tensor(1, dtype=torch.bool),
+                torch.tensor(0, dtype=torch.bool),
+            )
 
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
@@ -138,12 +147,32 @@ def fasterprune(
             if prunen == 0:
                 if mask is not None:
                     mask1 = mask[:, i1:i2]
+                    if int(W1.numel() * sparsity) > mask1.sum():
+                        # target sparsity is higher than base sparsity, extend mask1
+                        tmp = (
+                            (~mask[:, i1:i2])
+                            * W1**2
+                            / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                        )
+                        thresh = torch.sort(tmp.flatten())[0][
+                            int(tmp.numel() * sparsity)
+                        ]
+                        mask1 = tmp <= thresh
+                    else:
+                        raise ValueError(
+                            "The target sparsity is lower than the sparsity "
+                            "of the base model. Please retry "
+                            "after turning preserve_sparsity_mask=False"
+                        )
                 else:
                     tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                     thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                     mask1 = tmp <= thresh
             else:
-                mask1 = torch.zeros_like(W1) == 1
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    mask1 = torch.zeros_like(W1) == 1
 
             for i in range(count):
                 w = W1[:, i]
@@ -151,7 +180,8 @@ def fasterprune(
                 if prunen != 0 and i % prunem == 0:
                     tmp = (
-                        W1[:, i : (i + prunem)] ** 2
+                        (~mask[:, i : (i + prunem)])
+                        * W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
                     mask1.scatter_(
                         1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                     )
@@ -174,7 +204,12 @@ def fasterprune(
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2
 
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_sparsity_mask:
+                # respect the sparsity of other groups
+                # really not needed, but kept for explicitness
+                W[:, i2:] -= (~mask[:, i2:]) * Err1.matmul(Hinv[i1:i2, i2:])
+            else:
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())

From 2a77468590f750af656b01a0e2d097cc2dc85be9 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 13 May 2024 21:12:10 +0000
Subject: [PATCH 4/4] fix case when mask is none

---
 src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
index 0842dd8aab8..0079071bd0e 100644
--- a/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
+++ b/src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -180,10 +180,13 @@ def fasterprune(
 
                 if prunen != 0 and i % prunem == 0:
                     tmp = (
-                        (~mask[:, i : (i + prunem)])
-                        * W1[:, i : (i + prunem)] ** 2
+                        W1[:, i : (i + prunem)] ** 2
                         / (torch.diag(Hinv1)[i : (i + prunem)].reshape((1, -1))) ** 2
                     )
+
+                    if mask is not None:
+                        tmp = tmp * (~mask[:, i : (i + prunem)])
+
                     mask1.scatter_(
                         1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True
                     )
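
A note on the pruning logic above: the saliency rule that patches 1-3 use to extend an
existing sparsity mask can be illustrated in isolation. The sketch below is a minimal,
self-contained approximation in plain PyTorch, not the SparseML wrapper itself; the
function name extend_sparsity_mask and the random hinv_diag stand-in are illustrative
only. It scores only the not-yet-pruned weights with W**2 / diag(Hinv)**2, picks the
threshold at the target sparsity, and keeps every weight that is already zero pruned.

import torch


def extend_sparsity_mask(W1, hinv_diag, base_mask, target_sparsity):
    """Return a bool mask at target_sparsity that never un-prunes base_mask entries."""
    if int(W1.numel() * target_sparsity) <= base_mask.sum():
        # target sparsity does not exceed the base sparsity, nothing to extend
        return base_mask
    # saliency of the not-yet-pruned weights; already-pruned entries score 0
    scores = (~base_mask) * W1**2 / hinv_diag.reshape((1, -1)) ** 2
    thresh = torch.sort(scores.flatten())[0][int(scores.numel() * target_sparsity)]
    # zero-scoring (already pruned) entries always fall below the threshold,
    # so the base mask is preserved, mirroring the mask1 = tmp <= thresh logic above
    return scores <= thresh


# toy usage: a block that is already roughly 30% sparse, extended to 50% sparsity
W1 = torch.randn(8, 16)
W1[torch.rand_like(W1) < 0.3] = 0.0
base_mask = W1 == 0
hinv_diag = torch.rand(16) + 0.1  # stand-in for torch.diag(Hinv1) in the real wrappers
new_mask = extend_sparsity_mask(W1, hinv_diag, base_mask, target_sparsity=0.5)
assert bool(new_mask[base_mask].all())  # existing zeros stay masked

In the GPTQ wrapper (patches 1 and 2) this preservation is applied only when the measured
weight sparsity is at least SPARSITY_THRESHOLD (0.05); in the SparseGPT wrapper (patch 3)
it is gated by the explicit preserve_sparsity_mask argument instead.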
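
For the preserve_sparsity_mask option itself, a hedged usage sketch follows. Only the
parameter names and the import path come from this patch series; constructing the
modifier directly with keyword arguments like this is an assumption about the SparseML
API rather than something shown in the diff.

from sparseml.modifiers.obcq.base import SparseGPTModifier

modifier = SparseGPTModifier(
    sparsity=0.7,                 # target sparsity for this run
    block_size=128,
    dampening_frac=0.01,
    preserve_sparsity_mask=True,  # keep the zeros already present in the checkpoint
)

Per the ValueError added in patch 3, requesting a target sparsity below the sparsity
already present in the model while preserve_sparsity_mask=True is rejected; in that case
the flag should be left at its default of False.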