[Torch][WeightCompression] Add Scale Estimation data-aware support #3179

Open
wants to merge 45 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
09126af add torch sample (kshpv, Jan 3, 2025)
67cef71 upd sample (kshpv, Jan 7, 2025)
94d2850 fix reducers (kshpv, Jan 7, 2025)
db42165 align SE with GPTQ (kshpv, Jan 7, 2025)
f96788a add tests (kshpv, Jan 8, 2025)
b1d4c47 backend method - get_filter_fn_for_statistics (kshpv, Jan 8, 2025)
51ccdd6 fixes (kshpv, Jan 8, 2025)
e6a9191 Merge remote-tracking branch 'remote/develop' into scale_est_torch (kshpv, Jan 8, 2025)
e37ef52 sample (kshpv, Jan 8, 2025)
d1843ad add tinyllama_data_aware, tinyllama_scale_estimation_per_channel for … (kshpv, Jan 8, 2025)
cd79e80 fix precommit (kshpv, Jan 8, 2025)
bc0731c Merge remote-tracking branch 'remote/develop' into scale_est_torch (kshpv, Jan 8, 2025)
df6b43b minor (kshpv, Jan 8, 2025)
368054a refactor test (kshpv, Jan 8, 2025)
e2a6f46 add WA for dataset (kshpv, Jan 8, 2025)
dbf2b1d fix (kshpv, Jan 8, 2025)
702f8b1 dtype (kshpv, Jan 8, 2025)
24e39c2 polishing (kshpv, Jan 8, 2025)
e97078b updates for torch (kshpv, Jan 15, 2025)
035a668 add functions (kshpv, Jan 15, 2025)
58b9924 upd metrics (kshpv, Jan 15, 2025)
be3694b rm ov flag (kshpv, Jan 15, 2025)
9345e2f rm example (kshpv, Jan 15, 2025)
6e7d981 rm comments (kshpv, Jan 15, 2025)
683cfd4 fix tests (kshpv, Jan 16, 2025)
1a33369 reimplement compress/decompress (kshpv, Jan 16, 2025)
dcf88a5 rm fx (kshpv, Jan 16, 2025)
e48a44b add wc template (kshpv, Jan 17, 2025)
63e8c0a polishing (kshpv, Jan 17, 2025)
b2fef75 comment (kshpv, Jan 17, 2025)
32788a4 comments (kshpv, Jan 17, 2025)
9d0acdb rollback no_grad (kshpv, Jan 20, 2025)
37a41ac add torch.no_grad() (kshpv, Jan 20, 2025)
a305fac start of cuda in conformance (kshpv, Jan 20, 2025)
ddee495 add scale estimation test (kshpv, Jan 21, 2025)
f89ae9d upd year (kshpv, Jan 21, 2025)
026a0ed add tinyllama_scale_estimation_group_size_64 (kshpv, Jan 21, 2025)
e3f12c2 torch.no_grad -> torch.inference_mode (kshpv, Jan 21, 2025)
a347a25 upd reference (kshpv, Jan 21, 2025)
601f2e4 #test: upd int4 weight locator for torch (kshpv, Jan 22, 2025)
32bc0e5 upd licence year (kshpv, Jan 22, 2025)
f8d6451 Merge remote-tracking branch 'remote/develop' into scale_est_torch (kshpv, Jan 23, 2025)
7557fb5 merge (kshpv, Jan 23, 2025)
568809c rebase (kshpv, Jan 23, 2025)
8c7efd6 add test on scale estimation (kshpv, Jan 24, 2025)
tests/cross_fw/test_templates/template_test_weights_compression.py
@@ -16,11 +16,14 @@
import numpy as np
import pytest

import nncf.tensor.functions as fns
from nncf import CompressWeightsMode
from nncf import SensitivityMetric
from nncf.data.dataset import Dataset
from nncf.quantization import compress_weights
from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
from nncf.tensor import Tensor
from nncf.tensor import TensorDataType

TModel = TypeVar("TModel")
@@ -39,13 +42,11 @@ class TemplateWeightCompression(ABC):
@staticmethod
@abstractmethod
def cast_to(x: TTensor, dtype: TensorDataType) -> TTensor:
-    pass
+    """Casts a backend tensor to backend tensor with specified dtype."""

@abstractmethod
def get_matmul_model() -> TModel:
"""
Returns a backend model for test_data_based_criterion.
"""
"""Returns a backend model for test_data_based_criterion."""

@pytest.mark.parametrize(
("mode", "ref_act_score", "ref_score"),
@@ -80,13 +81,11 @@ def test_data_based_criterion(self, mode, ref_score, ref_act_score, mocker):

@abstractmethod
def get_sequential_matmul_model() -> TModel:
"""
Returns a backend model for test_mixed_precision.
"""
"""Returns a backend model for test_mixed_precision."""

@abstractmethod
def to_tensor(x: TTensor) -> TTensor:
-    pass
+    """Returns a backend tensor."""

@abstractmethod
def check_weights(model: TModel, ref_ids: List[int]) -> None:
@@ -128,3 +127,39 @@ def test_mixed_precision(self, mode, all_layers, ratio, ref_ids):
dataset=dataset,
)
self.check_weights(compressed_model, ref_ids)

@staticmethod
@abstractmethod
def get_model_for_test_scale_estimation():
"""
Returns a backend model for test_scale_estimation.
"""

@staticmethod
@abstractmethod
def get_scale_estimation_ref():
"""
Returns the reference output of calculate_quantization_params of ScaleEstimation.
"""

def test_scale_estimation(self, mocker):
calc_q_params_spy = mocker.spy(ScaleEstimation, "calculate_quantization_params")
model = self.get_model_for_test_scale_estimation()

# prepare dataset with one input tensor
input = np.arange(0, 32 * 32, dtype=np.float32).reshape(1, 32, 32)
input[0, 15] *= 100 # make one channel relatively higher.
Contributor: I would expect this to be reflected in the references, but I don't see anything outstanding. Maybe the value should be higher.

Collaborator (Author): I think this test intends to check the difference between the torch and OV backends; it does not aim to check the algorithm's correctness. However, I agree that your proposal is good. I can add a new test that checks the error after quantization and demonstrates that the outlier channel has the lowest error compared to the others.

Contributor: Are you going to add this new test in the follow-up PR?
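A sketch of what that follow-up test could look like, using the same template hooks; `extract_weight` and `extract_decompressed_weight` are hypothetical helpers, not existing NNCF test API:

```python
def test_outlier_channel_has_lowest_error(self):
    # Hypothetical follow-up test; `extract_weight` / `extract_decompressed_weight`
    # are assumed helpers, not part of the current template.
    model = self.get_model_for_test_scale_estimation()
    original_weight = self.extract_weight(model)

    input = np.arange(0, 32 * 32, dtype=np.float32).reshape(1, 32, 32)
    input[0, 15] *= 100  # same outlier channel as in test_scale_estimation
    dataset = Dataset([self.to_tensor(input)])

    compressed = compress_weights(
        model,
        mode=CompressWeightsMode.INT4_ASYM,
        ratio=1.0,
        group_size=32,
        scale_estimation=True,
        all_layers=True,
        dataset=dataset,
    )
    decompressed_weight = self.extract_decompressed_weight(compressed)

    # Mean squared reconstruction error per input channel of the weight; the
    # boosted activation channel should be quantized most accurately.
    error = fns.mean((original_weight - decompressed_weight) ** 2, axis=0)
    assert error[15] == fns.min(error)
```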

input = self.to_tensor(input)
dataset = Dataset([input])

_ = compress_weights(
model,
mode=CompressWeightsMode.INT4_ASYM,
ratio=1.0,
group_size=32,
scale_estimation=True,
all_layers=True,
dataset=dataset,
)
reference = self.get_scale_estimation_ref()
assert fns.allclose(Tensor(reference), calc_q_params_spy.spy_return[0])
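For reference, `mocker.spy` (from pytest-mock) wraps the real method and records its return value in `spy_return`; `spy_return[0]` above is presumably the estimated scale tensor, given the assertion against the scale reference. A standalone toy illustration of the pattern (not NNCF code):

```python
class Toy:
    @staticmethod
    def calculate(x):
        return x * 2, x + 1  # e.g. (scale, zero_point)

def test_spy_return(mocker):
    spy = mocker.spy(Toy, "calculate")
    Toy.calculate(3)  # the real method still runs
    assert spy.spy_return == (6, 4)
    assert spy.spy_return[0] == 6  # first returned value, as indexed above
```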
15 changes: 15 additions & 0 deletions tests/openvino/native/models.py
@@ -1185,3 +1185,18 @@ def _create_ov_model(self):

model = ov.Model([sin_result, cos_result], [position_ids])
return model


class MLP(OVReferenceModel):
Collaborator: To be honest, it is not an MLP; there is only one layer.

Collaborator (Author): Done.

def _create_ov_model(self):
input_node = opset.parameter([1, 32, 32], name="Input")

weights_data = np.arange(0, 32 * 32, dtype=np.float32).reshape(32, 32)
weights_node = opset.constant(weights_data, dtype=np.float32, name="Weights")

matmul_node = opset.matmul(input_node, weights_node, transpose_a=False, transpose_b=True, name="MatMul")

result_node = opset.result(matmul_node, name="Result")

model = ov.Model([result_node], [input_node], name="MLP_Model")
return model
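As a quick sanity check of what this single-layer graph computes, a numpy sketch mirroring the shapes above (not part of the diff):

```python
import numpy as np

x = np.arange(0, 32 * 32, dtype=np.float32).reshape(1, 32, 32)
w = np.arange(0, 32 * 32, dtype=np.float32).reshape(32, 32)
y = x @ w.T  # matmul with transpose_b=True, as in the OV graph
assert y.shape == (1, 32, 32)
```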
44 changes: 44 additions & 0 deletions tests/openvino/native/quantization/test_weights_compression.py
@@ -50,6 +50,7 @@
from tests.cross_fw.test_templates.template_test_weights_compression import ACTIVATION
from tests.cross_fw.test_templates.template_test_weights_compression import TemplateWeightCompression
from tests.openvino.native.common import get_actual_reference_for_current_openvino
from tests.openvino.native.models import MLP
from tests.openvino.native.models import AWQActMatmulModel
from tests.openvino.native.models import AWQMatmulModel
from tests.openvino.native.models import GatherAndMatmulShareData
@@ -1546,3 +1547,46 @@ def check_weights(model: ov.Model, ref_ids: List[int]) -> None:
names = {op.get_friendly_name() for op in model.get_ordered_ops() if op.get_element_type() == ov.Type.i4}
ref_nf4_nodes = {f"weights_{i}" for i in ref_ids}
assert ref_nf4_nodes == names

@staticmethod
def get_model_for_test_scale_estimation():
return MLP().ov_model

@staticmethod
def get_scale_estimation_ref():
return np.array(
[
[[2.0666666]],
[[3.7624273]],
[[5.884783]],
[[8.03606]],
[[10.136832]],
[[12.291862]],
[[14.34415]],
[[16.449669]],
[[18.608639]],
[[20.802698]],
[[22.9477]],
[[25.083504]],
[[27.152409]],
[[29.141987]],
[[31.171442]],
[[33.044716]],
[[35.178047]],
[[37.31138]],
[[39.444714]],
[[41.578045]],
[[43.71138]],
[[45.844715]],
[[47.978046]],
[[50.11138]],
[[52.244713]],
[[54.378044]],
[[56.511383]],
[[58.644714]],
[[60.77805]],
[[62.91138]],
[[65.044716]],
[[67.17805]],
]
)
38 changes: 38 additions & 0 deletions tests/torch/fx/test_weights_compression_backends.py
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nncf.quantization.algorithms.weight_compression.mixed_precision import HAWQCriterion
from nncf.quantization.algorithms.weight_compression.mixed_precision import MaxVarianceCriterion
from nncf.quantization.algorithms.weight_compression.mixed_precision import MeanMaxCriterion
from nncf.quantization.algorithms.weight_compression.mixed_precision import MeanVarianceCriterion
from nncf.quantization.algorithms.weight_compression.torch_backend import PTMixedPrecisionAlgoBackend
from tests.cross_fw.test_templates.test_weights_compression_backends import TemplateTestMixedPrecisionAlgoBackend


class TestPTMixedPrecisionAlgoBackend(TemplateTestMixedPrecisionAlgoBackend):
def get_hawq_with_backend(self, subset_size):
hawq = HAWQCriterion(None, None, subset_size=subset_size)
hawq._backend_entity = PTMixedPrecisionAlgoBackend()
return hawq

def get_mean_variance_with_backend(self, subset_size: int):
mean_variance = MeanVarianceCriterion(None, None, subset_size=subset_size)
mean_variance._backend_entity = PTMixedPrecisionAlgoBackend()
return mean_variance

def get_max_variance_with_backend(self, subset_size: int):
max_variance = MaxVarianceCriterion(None, None, subset_size=subset_size)
max_variance._backend_entity = PTMixedPrecisionAlgoBackend()
return max_variance

def get_mean_max_with_backend(self, subset_size: int):
mean_max_variance = MeanMaxCriterion(None, None, subset_size=subset_size)
mean_max_variance._backend_entity = PTMixedPrecisionAlgoBackend()
return mean_max_variance
110 changes: 76 additions & 34 deletions tests/torch/ptq/test_weights_compression.py
@@ -44,15 +44,45 @@
UNSUPPORTED_MODES = (CompressWeightsMode.NF4, CompressWeightsMode.E2M1)


-class MatMulModel(torch.nn.Module):
+class SequentialMatmulModel(nn.Module):
def __init__(self):
super(SequentialMatmulModel, self).__init__()
self.main_values = [10000, 1000, 1, 10, 10000]
self.layers = nn.ModuleList()

for _, main_value in enumerate(self.main_values):
weights_data = torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
weights_data[-1, -1] = main_value
weight_tensor = torch.tensor(weights_data)
layer = nn.Linear(4, 4, bias=False)
layer.weight = nn.Parameter(weight_tensor.t())
self.layers.append(layer)

def forward(self, x):
for layer in self.layers:
x = layer(x)
return x


class MatMulModel(torch.nn.Module):
def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
super().__init__()
-    self.w = torch.nn.Parameter(torch.ones(size=(256, 256), dtype=torch.float32))
+    self.w = torch.nn.Parameter(weight)

def forward(self, input):
return input @ self.w


class LinearModel(torch.nn.Module):
def __init__(self, weight: torch.Tensor = torch.ones(size=(256, 256), dtype=torch.float32)):
super().__init__()
self.linear = torch.nn.Linear(weight.shape[0], weight.shape[1], False)
self.linear.weight = torch.nn.Parameter(weight)

def forward(self, input):
return self.linear(input)


class FunctionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -326,41 +356,10 @@ def test_pack_int4():
assert torch.all(unpacked_w == w_int8)


-class IdentityMatmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.w = torch.nn.Parameter(
-            torch.eye(3, dtype=torch.float32) * 255,
-        )
-
-    def forward(self, input):
-        return input @ self.w
-
-
-class SequentialMatmulModel(nn.Module):
-    def __init__(self):
-        super(SequentialMatmulModel, self).__init__()
-        self.main_values = [10000, 1000, 1, 10, 10000]
-        self.layers = nn.ModuleList()
-
-        for _, main_value in enumerate(self.main_values):
-            weights_data = torch.arange(0, 16, dtype=torch.float32).reshape(4, 4)
-            weights_data[-1, -1] = main_value
-            weight_tensor = torch.tensor(weights_data)
-            layer = nn.Linear(4, 4, bias=False)
-            layer.weight = nn.Parameter(weight_tensor.t())
-            self.layers.append(layer)
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x


class TestPTTemplateWeightCompression(TemplateWeightCompression):
@staticmethod
def get_matmul_model() -> torch.nn.Module:
-    return IdentityMatmul()
+    return MatMulModel(255 * torch.eye(3, dtype=torch.float32))

@staticmethod
def get_sequential_matmul_model() -> torch.nn.Module:
@@ -381,3 +380,46 @@ def check_weights(model: torch.nn.Module, ref_ids: List[int]) -> None:
assert torch.numel(op.weight) == 8 # workaround to detect uint4 weights
else:
assert torch.numel(op.weight) == 16
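The `numel == 8` workaround reflects that int4 weights are stored packed, two 4-bit values per byte; a minimal sketch of that packing arithmetic (a simplified assumption about the layout, not NNCF's actual packing code):

```python
import torch

w_uint4 = torch.arange(16, dtype=torch.uint8)  # 16 values, each fits in 4 bits
packed = (w_uint4[::2] << 4) | (w_uint4[1::2] & 0x0F)
assert packed.numel() == 8  # a 4x4 uint4 weight therefore reports numel() == 8
```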

@staticmethod
def get_model_for_test_scale_estimation():
return LinearModel(torch.arange(0, 32 * 32, dtype=torch.float32).reshape(32, 32))

@staticmethod
def get_scale_estimation_ref():
Collaborator: Is it possible to move the reference to a JSON file, like in this test: `ref_stats_path = get_actual_reference_for_current_openvino(`?

Collaborator (Author): Reduce the number of scales. Should look much better.

return torch.tensor(
[
[[2.0666666]],
[[3.7624271]],
[[5.8847833]],
[[8.0360603]],
[[10.1368332]],
[[12.2918606]],
[[14.3441496]],
[[16.4496689]],
[[18.6086369]],
[[20.8027000]],
[[22.9477024]],
[[25.0835018]],
[[27.1524105]],
[[29.1419849]],
[[31.1714401]],
[[33.0447121]],
[[35.1780472]],
[[37.3113823]],
[[39.4447136]],
[[41.5780487]],
[[43.7113838]],
[[45.8447189]],
[[47.9780464]],
[[50.1113815]],
[[52.2447128]],
[[54.3780441]],
[[56.5113831]],
[[58.6447144]],
[[60.7780533]],
[[62.9113808]],
[[65.0447083]],
[[67.1780548]],
]
)
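On the reviewer's suggestion to keep the references in a JSON file: a hedged sketch of such a loader (the path layout and key name are assumptions, unrelated to the existing `get_actual_reference_for_current_openvino` machinery):

```python
import json
from pathlib import Path

import torch

def load_scale_estimation_ref(path: Path) -> torch.Tensor:
    # Expected file contents, e.g.: {"scales": [[[2.0666666]], [[3.7624271]], ...]}
    with path.open() as f:
        return torch.tensor(json.load(f)["scales"], dtype=torch.float32)
```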