apply g_idx
horheynm committed Jun 26, 2024
1 parent c6b5b28 commit 28744fd
Showing 3 changed files with 50 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -156,7 +156,7 @@ def apply_compression(
layer_compressor.pre_compress()
_LOGGER.info(f"Calibrating {layer_compressor.name}...")
run_calibration_forward(self.model, dataloader, mask_padding=True)
layer_compressor.compress(self.actorder)
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()
torch.cuda.empty_cache()
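The hunk above drops `self.actorder` from the `compress()` call; per the wrapper hunks below, the flag is instead read from the layer's quantization scheme. A minimal sketch of that lookup, assuming a `quantization_scheme` attribute with a `weights.actorder` field (names taken from the diff; the helper itself is not part of the commit):

```python
def resolve_actorder(layer) -> bool:
    # Hypothetical helper mirroring the wrapper's new behavior: actorder is
    # discovered on the layer rather than passed down through compress().
    scheme = getattr(layer, "quantization_scheme", None)
    if scheme is not None and scheme.weights is not None:
        return bool(getattr(scheme.weights, "actorder", False))
    return False
```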
74 changes: 46 additions & 28 deletions src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -83,7 +83,6 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor):

def fasterprune(
self,
actorder: bool = False,
blocksize: int = 128,
percdamp: float = 0.01,
):
@@ -121,13 +120,6 @@ def fasterprune(
self.H[dead, dead] = 1
W[:, dead] = 0

# Or read from self.layer.quantization_scheme
if actorder:
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
self.H = self.H[perm][:, perm]
invperm = torch.argsort(perm)

Losses = torch.zeros(self.rows, device=self.dev)

damp = percdamp * torch.mean(torch.diag(self.H))
@@ -138,6 +130,9 @@ def fasterprune(
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H

actorder = False
invperm = None

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
@@ -171,7 +166,15 @@

elif hasattr(self.layer, "quantization_scheme"):
quant_scheme = self.layer.quantization_scheme
actorder = quant_scheme.weights.actorder
if quant_scheme.weights is not None:

if actorder:
perm = torch.argsort(torch.diag(self.H), descending=True)
W = W[:, perm]
self.H = self.H[perm][:, perm]
invperm = torch.argsort(perm)

scale = self.layer.weight_scale
zero_point = self.layer.weight_zero_point

@@ -180,23 +183,15 @@
group_size = self.layer.weight.shape[1]

if actorder:
g_idx = torch.tensor(
[perm[j] // group_size for j in range(self.columns)],
dtype=torch.int32,
device=invperm.device
)

indices = torch.arange(self.columns, device=invperm.device)
g_idx = (perm[indices] // group_size).to(dtype=torch.int32)
g_idx = g_idx[invperm]
self.layer.weight_g_idx = Parameter(
g_idx,
requires_grad=False,
)
self.layer.weight_g_idx.data = g_idx
else:
g_idx = torch.tensor(
[j // group_size for j in range(self.columns)],
dtype=torch.int32,
device=W.device
indices = torch.arange(
self.columns, device=W.device, dtype=torch.int32
)
g_idx = indices // group_size

from compressed_tensors.quantization import QuantizationStrategy
from compressed_tensors.quantization.lifecycle.forward import (
@@ -217,7 +212,6 @@
q,
scale[:, 0],
zero_point[:, 0],
# g_idx,
quant_scheme.weights,
)
else: # strategy == QuantizationStrategy.GROUP
@@ -230,11 +224,16 @@
# ends up being a channelwise application
altered_qargs = copy(quant_scheme.weights)
altered_qargs.strategy = QuantizationStrategy.CHANNEL

# # apply g_idx
# if g_idx is not None:
# scale = scale[g_idx]
# zero_point = zero_point[g_idx]

# apply g_idx
if g_idx is not None:
# scale and zp are already grouped by group_size (one column per group)
# extract the first index of each group_size block
indices_to_extract = torch.arange(
0, g_idx.shape[0], group_size
)
scale = scale[:, g_idx[indices_to_extract]]
zero_point = zero_point[:, g_idx[indices_to_extract]]

q = fake_quantize(
q,
@@ -284,3 +283,22 @@ def free(self):
"""
delattr(self, "H")
super().free()


"""
(Pdb) scale.shape
torch.Size([4096, 32])
(Pdb) self.layer.shape
*** AttributeError: 'Linear' object has no attribute 'shape'
(Pdb) self.layer.weight.shape
torch.Size([4096, 4096])
(Pdb) scale.shape
torch.Size([11008, 32])
(Pdb) self.layer.weight.shape
torch.Size([11008, 4096])
"""
6 changes: 3 additions & 3 deletions src/sparseml/modifiers/utils/layer_compressor.py
@@ -131,10 +131,10 @@ def revert_layer_wrappers(self):
module_wrapper.free()
self.modules = None

def compress(self, actorder: bool = False):
def compress(self):
"""
Apply compression to each wrapped submodule in the layer
:param: actorder: flag to apply activation reordering
"""

@@ -143,7 +143,7 @@ def prune(module):
if isinstance(module, self.module_compressor_class):
full_name = self._get_full_submodule_name(module.name)
_LOGGER.info(f"Compressing {full_name}...")
module.fasterprune(actorder=actorder, **self.args)
module.fasterprune(**self.args)

self.layer.apply(prune)

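The layer_compressor.py change narrows `compress()` to a no-argument call; per-wrapper keyword arguments still flow through `self.args` into `fasterprune()`. A toy, self-contained illustration of that delegation (class names are illustrative stand-ins, not the real `LayerCompressor`):

```python
class ToyWrapper:
    # stands in for the GPTQ module wrapper; only the signature matters here
    def fasterprune(self, blocksize: int = 128, percdamp: float = 0.01):
        print(f"fasterprune(blocksize={blocksize}, percdamp={percdamp})")


class ToyLayerCompressor:
    def __init__(self, args: dict):
        self.args = args                  # e.g. GPTQ kwargs from the modifier config
        self.modules = [ToyWrapper()]

    def compress(self):
        # no actorder parameter: each wrapper resolves it from its own layer
        for module in self.modules:
            module.fasterprune(**self.args)


ToyLayerCompressor({"blocksize": 128, "percdamp": 0.01}).compress()
```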
