Fix issue with reuse layers in torch #1217

Open · wants to merge 4 commits into base: main

@@ -258,12 +258,10 @@ def fetch_hessian(self,
f"{hessian_scores_request.target_nodes}.")

# Replace node in reused target nodes with a representing node from the 'reuse group'.
for n in hessian_scores_request.target_nodes:
if n.reuse_group:
rep_node = self._get_representing_of_reuse_group(n)
hessian_scores_request.target_nodes.remove(n)
if rep_node not in hessian_scores_request.target_nodes:
hessian_scores_request.target_nodes.append(rep_node)
hessian_scores_request.target_nodes = [
self._get_representing_of_reuse_group(node) if node.reuse else node
for node in hessian_scores_request.target_nodes
]

# Ensure the saved info has the required number of approximations
self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size)
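For reference, the replaced loop mutated hessian_scores_request.target_nodes while iterating over it (remove/append inside the for loop), which is the classic skip-an-element pitfall that the new list comprehension avoids. A minimal standalone sketch of the pitfall, with illustrative values:

    # Removing from a list while iterating over it skips the element that
    # shifts into the removed item's slot.
    nodes = ["a", "b", "c"]
    for n in nodes:
        nodes.remove(n)
    print(nodes)  # ['b'] -- 'b' was never visited, so it was never removed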
@@ -222,6 +222,7 @@ def __init__(self,
self.return_float_outputs = return_float_outputs
self.wrapper = wrapper
self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
self.reuse_groups = {}
self._add_modules()

# todo: Move to parent class BaseModelBuilder
@@ -279,7 +280,18 @@ def _add_modules(self):
Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel
"""
for node in self.node_sort:
node_op = self.wrap(node)
if node.reuse:
# If the node is reused, retrieve the original module
if node.reuse_group not in self.reuse_groups:
raise ValueError(f"Reuse group {node.reuse_group} not found for node {node.name}")
node_op = self.reuse_groups[node.reuse_group]
else:
# If it's not reused, create a new module
node_op = self.wrap(node)
if node.reuse_group:
# Store the module for future reuse
self.reuse_groups[node.reuse_group] = node_op

if isinstance(node, FunctionalNode):
# for functional layers
setattr(self, node.name, node_op)
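As background, returning the stored module for reused nodes leans on standard PyTorch behaviour: calling the same nn.Module instance more than once shares one set of parameters. A minimal sketch, independent of MCT:

    import torch

    conv = torch.nn.Conv2d(3, 3, kernel_size=1)
    x = torch.randn(1, 3, 8, 8)
    out = conv(conv(x))  # two calls, one set of parameters
    print(len(list(conv.parameters())))  # 2: a single weight and a single bias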
model_compression_toolkit/core/pytorch/reader/graph_builders.py (18 additions, 0 deletions)
@@ -178,6 +178,9 @@ def nodes_builder(model: GraphModule,
consts_dict = {}
used_consts = set()

# Dictionary to track seen targets and their corresponding nodes to mark reused nodes
seen_targets = {}

# Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary.
model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy)

@@ -237,6 +240,19 @@
# Extract input and output shapes of the node.
input_shape, output_shape = _extract_input_and_output_shapes(node)

# Check if this node's target has been seen before
reuse = False
reuse_group = None
# We mark nodes as reused only if there are multiple nodes in the graph with the same
# 'target' and the target has some weights.
if node.target in seen_targets and len(weights) > 0:
reuse = True
reuse_group = str(node.target)
# Update the 'base/main' node with the same reuse group as all other nodes in its group.
fx_node_2_graph_node[seen_targets[node.target]].reuse_group = reuse_group
else:
seen_targets[node.target] = node

# Initiate graph nodes.
if node.op in [CALL_METHOD, CALL_FUNCTION]:
graph_node_type = FunctionalNode
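The reuse detection above keys on node.target: in torch.fx, call_module nodes that invoke the same submodule all carry the same target string. A small illustrative sketch outside of MCT:

    import torch
    from torch import fx

    class TwoCalls(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, kernel_size=1)

        def forward(self, x):
            return self.conv(self.conv(x))

    gm = fx.symbolic_trace(TwoCalls())
    targets = [n.target for n in gm.graph.nodes if n.op == "call_module"]
    print(targets)  # ['conv', 'conv'] -- the repeated target is what marks reuse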
@@ -291,6 +307,8 @@
weights=weights,
layer_class=node_type,
has_activation=node_has_activation,
reuse=reuse,
reuse_group=reuse_group,
**kwargs)

# Generate graph inputs list.
@@ -3,6 +3,7 @@

import numpy as np
import torch
from model_compression_toolkit.core.pytorch.utils import get_working_device
from torch.nn import Conv2d

import model_compression_toolkit as mct
@@ -105,21 +106,17 @@ def test_adding_holders_after_reuse(self):
last_module = list(gptq_model.named_modules())[-1][1]
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 4 activation quantization holders where generated
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer))
for name, module in gptq_model.named_modules():
if isinstance(module, PytorchQuantizationWrapper):
self.assertTrue(len(module.weights_quantizers) > 0)
# Test that two holders are getting inputs from reused conv2d (the layer that is wrapped)

# FIXME there is no reuse support and the test doesn't test what it says it tests. It doesn't even look
# at correct layers. After moving to trainable quantizer the test makes even less sense since now fx traces
# all quantization operations instead of fake_quant layer.
# fx_model = symbolic_trace(gptq_model)
# self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2])
# self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5])
self.assertEqual([p.data_ptr() for p in gptq_model.conv.parameters()],
[p.data_ptr() for p in gptq_model.conv_1.parameters()],
f"Shared parameters between reused layers should have identical memory addresses")

def _get_gptq_model(self, input_shape, in_model):
pytorch_impl = GPTQPytorchImplemantation()
@@ -136,10 +133,10 @@ def _get_gptq_model(self, input_shape, in_model):
graph = set_bit_widths(mixed_precision_enable=False,
graph=graph)
trainer = PytorchGPTQTrainer(graph,
graph,
mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False),
pytorch_impl,
DEFAULT_PYTORCH_INFO,
representative_dataset)
graph,
mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False),
pytorch_impl,
DEFAULT_PYTORCH_INFO,
representative_dataset)
gptq_model, _ = trainer.build_gptq_model()
return gptq_model
@@ -113,8 +113,7 @@ def __init__(self):

def forward(self, inp):
x = self.conv1(inp)
x1 = self.bn1(x)
x1 = self.relu(x1)
x1 = self.relu(x)
x_split = torch.split(x1, split_size_or_sections=4, dim=-1)
x1 = self.conv1(x_split[0])
x2 = x_split[1]
@@ -13,6 +13,8 @@
# limitations under the License.
# ==============================================================================
import torch

from mct_quantizers import PytorchQuantizationWrapper
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest

"""
@@ -25,31 +27,12 @@ def __init__(self):
super(ReuseLayerNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
self.identity = torch.nn.Identity()

def forward(self, x, y):
x = self.conv1(x)
x = self.identity(x)
def forward(self, x):
x = self.conv1(x)
x = self.identity(x)
x = self.conv1(x)
x = self.identity(x)
x = self.conv2(x)
x = self.identity(x)
x = self.conv2(x)
x = self.identity(x)
x = self.conv2(x)
x = self.identity(x)
y = self.conv2(y)
y = self.identity(y)
y = self.conv2(y)
y = self.identity(y)
y = self.conv1(y)
y = self.identity(y)
y = self.conv1(y)
y = self.identity(y)
y = self.conv1(y)
return x - y, y - x
x = self.conv1(x)
return x


class ReuseLayerNetTest(BasePytorchTest):
@@ -61,7 +44,43 @@ def __init__(self, unit_test):
super().__init__(unit_test)

def create_inputs_shape(self):
return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]]
return [[self.val_batch_size, 3, 32, 32]]

def create_feature_network(self, input_shape):
return ReuseLayerNet()
model = ReuseLayerNet()
return model

def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):

quant_model = quantized_models['all_4bit']

#########################################################################################

# Verify that the shared parameters have identical memory addresses
Collaborator review comment: consider adding another shared conv and validate it is different from the first one.

self.unit_test.assertEqual([p.data_ptr() for p in quant_model.conv1.parameters()],
[p.data_ptr() for p in quant_model.conv1_1.parameters()],
f"Shared parameters between reused layers should have identical memory addresses")

#########################################################################################

# Verify that 'conv1' is called twice (thus reused) and 'conv2' is called once
Collaborator review comment: you can use defaultdict.

layer_calls = {}
def hook_fn(module, input, output):
layer_name = [name for name, layer in quant_model.named_modules() if layer is module][0]
if layer_name not in layer_calls:
layer_calls[layer_name] = 0
layer_calls[layer_name] += 1

# Register hooks
hooks = []
for name, module in quant_model.named_modules():
if isinstance(module, PytorchQuantizationWrapper):
hooks.append(module.register_forward_hook(hook_fn))
_ = quant_model(input_x)
for hook in hooks:
hook.remove()

self.unit_test.assertEqual(layer_calls['conv1'], 2, "conv1 should be called twice")
self.unit_test.assertEqual(layer_calls['conv2'], 1, "conv2 should be called once")
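As the reviewer notes, the manual counting above could be trimmed with collections.defaultdict; a minimal sketch of that variant:

    # Same call counting, but defaultdict removes the explicit zero-initialization.
    from collections import defaultdict

    layer_calls = defaultdict(int)

    def hook_fn(module, input, output):
        layer_name = [name for name, layer in quant_model.named_modules() if layer is module][0]
        layer_calls[layer_name] += 1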

#########################################################################################