diff --git a/model_compression_toolkit/core/common/hessian/hessian_info_service.py b/model_compression_toolkit/core/common/hessian/hessian_info_service.py index 9f4f416d1..ba5e2c143 100644 --- a/model_compression_toolkit/core/common/hessian/hessian_info_service.py +++ b/model_compression_toolkit/core/common/hessian/hessian_info_service.py @@ -258,12 +258,10 @@ def fetch_hessian(self, f"{hessian_scores_request.target_nodes}.") # Replace node in reused target nodes with a representing node from the 'reuse group'. - for n in hessian_scores_request.target_nodes: - if n.reuse_group: - rep_node = self._get_representing_of_reuse_group(n) - hessian_scores_request.target_nodes.remove(n) - if rep_node not in hessian_scores_request.target_nodes: - hessian_scores_request.target_nodes.append(rep_node) + hessian_scores_request.target_nodes = [ + self._get_representing_of_reuse_group(node) if node.reuse else node + for node in hessian_scores_request.target_nodes + ] # Ensure the saved info has the required number of approximations self._populate_saved_info_to_size(hessian_scores_request, required_size, batch_size) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index af3ed0a6f..a74ea13dc 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -222,6 +222,7 @@ def __init__(self, self.return_float_outputs = return_float_outputs self.wrapper = wrapper self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn + self.reuse_groups = {} self._add_modules() # todo: Move to parent class BaseModelBuilder @@ -279,7 +280,18 @@ def _add_modules(self): Build and add the modules and functional nodes from node_sort list as attributes to PytorchModel """ for node in self.node_sort: - node_op = self.wrap(node) + if node.reuse: + # If the node is reused, retrieve the original module + if node.reuse_group not in self.reuse_groups: + raise ValueError(f"Reuse group {node.reuse_group} not found for node {node.name}") + node_op = self.reuse_groups[node.reuse_group] + else: + # If it's not reused, create a new module + node_op = self.wrap(node) + if node.reuse_group: + # Store the module for future reuse + self.reuse_groups[node.reuse_group] = node_op + if isinstance(node, FunctionalNode): # for functional layers setattr(self, node.name, node_op) diff --git a/model_compression_toolkit/core/pytorch/reader/graph_builders.py b/model_compression_toolkit/core/pytorch/reader/graph_builders.py index 78b3b3400..915518c3e 100644 --- a/model_compression_toolkit/core/pytorch/reader/graph_builders.py +++ b/model_compression_toolkit/core/pytorch/reader/graph_builders.py @@ -178,6 +178,9 @@ def nodes_builder(model: GraphModule, consts_dict = {} used_consts = set() + # Dictionary to track seen targets and their corresponding nodes to mark reused nodes + seen_targets = {} + # Init parameters & buffers dictionary of the entire model. We later extract the constants values from this dictionary. model_parameters_and_buffers = _extract_parameters_and_buffers(model, to_numpy) @@ -237,6 +240,19 @@ def nodes_builder(model: GraphModule, # Extract input and output shapes of the node. 
         input_shape, output_shape = _extract_input_and_output_shapes(node)
 
+        # Check if this node's target has been seen before
+        reuse = False
+        reuse_group = None
+        # We mark a node as reused only if there are multiple nodes in the graph with the
+        # same 'target' and the node has weights.
+        if node.target in seen_targets and len(weights) > 0:
+            reuse = True
+            reuse_group = str(node.target)
+            # Mark the 'base/main' node with the same reuse group as all other nodes in its group.
+            fx_node_2_graph_node[seen_targets[node.target]].reuse_group = reuse_group
+        else:
+            seen_targets[node.target] = node
+
         # Initiate graph nodes.
         if node.op in [CALL_METHOD, CALL_FUNCTION]:
             graph_node_type = FunctionalNode
@@ -291,6 +307,8 @@ def nodes_builder(model: GraphModule,
                                weights=weights,
                                layer_class=node_type,
                                has_activation=node_has_activation,
+                               reuse=reuse,
+                               reuse_group=reuse_group,
                                **kwargs)
 
     # Generate graph inputs list.
diff --git a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py
index 53033e18d..f6cbd578f 100644
--- a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py
+++ b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
+from model_compression_toolkit.core.pytorch.utils import get_working_device
 from torch.nn import Conv2d
 
 import model_compression_toolkit as mct
@@ -105,21 +106,17 @@ def test_adding_holders_after_reuse(self):
         last_module = list(gptq_model.named_modules())[-1][1]
         # the last module should be an activation quantization holder
         self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
-        # check that 4 activation quantization holders where generated
+        # check that 3 activation quantization holders were generated
         self.assertTrue(len(activation_quantization_holders_in_model) == 3)
         for a in activation_quantization_holders_in_model:
             self.assertTrue(isinstance(a.activation_holder_quantizer, STESymmetricActivationTrainableQuantizer))
         for name, module in gptq_model.named_modules():
             if isinstance(module, PytorchQuantizationWrapper):
                 self.assertTrue(len(module.weights_quantizers) > 0)
-        # Test that two holders are getting inputs from reused conv2d (the layer that is wrapped)
-        # FIXME there is no reuse support and the test doesn't test what it says it tests. It doesn't even look
-        #  at correct layers. After moving to trainable quantizer the test makes even less sense since now fx traces
-        #  all quantization operations instead of fake_quant layer.
-        # fx_model = symbolic_trace(gptq_model)
-        # self.assertTrue(list(fx_model.graph.nodes)[3].all_input_nodes[0] == list(fx_model.graph.nodes)[2])
-        # self.assertTrue(list(fx_model.graph.nodes)[6].all_input_nodes[0] == list(fx_model.graph.nodes)[5])
+        self.assertEqual([p.data_ptr() for p in gptq_model.conv.parameters()],
+                         [p.data_ptr() for p in gptq_model.conv_1.parameters()],
+                         "Shared parameters between reused layers should have identical memory addresses")
 
     def _get_gptq_model(self, input_shape, in_model):
         pytorch_impl = GPTQPytorchImplemantation()
@@ -136,10 +133,10 @@ def _get_gptq_model(self, input_shape, in_model):
         graph = set_bit_widths(mixed_precision_enable=False,
                                graph=graph)
         trainer = PytorchGPTQTrainer(graph,
-                                      graph,
-                                      mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False),
-                                      pytorch_impl,
-                                      DEFAULT_PYTORCH_INFO,
-                                      representative_dataset)
+                                     graph,
+                                     mct.gptq.get_pytorch_gptq_config(1, use_hessian_based_weights=False),
+                                     pytorch_impl,
+                                     DEFAULT_PYTORCH_INFO,
+                                     representative_dataset)
         gptq_model, _ = trainer.build_gptq_model()
         return gptq_model
\ No newline at end of file
diff --git a/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
index 87130b1f9..e16279c65 100644
--- a/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
+++ b/tests/pytorch_tests/function_tests/test_hessian_info_calculator.py
@@ -113,8 +113,7 @@ def __init__(self):
 
     def forward(self, inp):
         x = self.conv1(inp)
-        x1 = self.bn1(x)
-        x1 = self.relu(x1)
+        x1 = self.relu(x)
         x_split = torch.split(x1, split_size_or_sections=4, dim=-1)
         x1 = self.conv1(x_split[0])
         x2 = x_split[1]
diff --git a/tests/pytorch_tests/model_tests/feature_models/reuse_layer_net_test.py b/tests/pytorch_tests/model_tests/feature_models/reuse_layer_net_test.py
index c37317ebf..e2f6d3214 100644
--- a/tests/pytorch_tests/model_tests/feature_models/reuse_layer_net_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/reuse_layer_net_test.py
@@ -13,6 +13,8 @@
 #  limitations under the License.
 # ==============================================================================
 import torch
+
+from mct_quantizers import PytorchQuantizationWrapper
 from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest
 
 """
@@ -25,31 +27,12 @@ def __init__(self):
         super(ReuseLayerNet, self).__init__()
         self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
         self.conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
-        self.identity = torch.nn.Identity()
 
-    def forward(self, x, y):
-        x = self.conv1(x)
-        x = self.identity(x)
+    def forward(self, x):
         x = self.conv1(x)
-        x = self.identity(x)
-        x = self.conv1(x)
-        x = self.identity(x)
-        x = self.conv2(x)
-        x = self.identity(x)
         x = self.conv2(x)
-        x = self.identity(x)
-        x = self.conv2(x)
-        x = self.identity(x)
-        y = self.conv2(y)
-        y = self.identity(y)
-        y = self.conv2(y)
-        y = self.identity(y)
-        y = self.conv1(y)
-        y = self.identity(y)
-        y = self.conv1(y)
-        y = self.identity(y)
-        y = self.conv1(y)
-        return x - y, y - x
+        x = self.conv1(x)
+        return x
 
 
 class ReuseLayerNetTest(BasePytorchTest):
@@ -61,7 +44,43 @@ def __init__(self, unit_test):
         super().__init__(unit_test)
 
     def create_inputs_shape(self):
-        return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]]
+        return [[self.val_batch_size, 3, 32, 32]]
 
     def create_feature_network(self, input_shape):
-        return ReuseLayerNet()
\ No newline at end of file
+        model = ReuseLayerNet()
+        return model
+
+    def compare(self, quantized_models, float_model, input_x=None, quantization_info=None):
+
+        quant_model = quantized_models['all_4bit']
+
+        #########################################################################################
+
+        # Verify that the shared parameters have identical memory addresses
+        self.unit_test.assertEqual([p.data_ptr() for p in quant_model.conv1.parameters()],
+                                   [p.data_ptr() for p in quant_model.conv1_1.parameters()],
+                                   "Shared parameters between reused layers should have identical memory addresses")
+
+        #########################################################################################
+
+        # Verify that 'conv1' is called twice (thus reused) and 'conv2' is called once
+        layer_calls = {}
+        def hook_fn(module, input, output):
+            layer_name = [name for name, layer in quant_model.named_modules() if layer is module][0]
+            layer_calls[layer_name] = layer_calls.get(layer_name, 0) + 1
+
+        # Register hooks on every quantization wrapper, run a forward pass, then remove the hooks
+        hooks = []
+        for name, module in quant_model.named_modules():
+            if isinstance(module, PytorchQuantizationWrapper):
+                hooks.append(module.register_forward_hook(hook_fn))
+        _ = quant_model(input_x)
+        for hook in hooks:
+            hook.remove()
+
+        self.unit_test.assertEqual(layer_calls['conv1'], 2, "conv1 should be called twice")
+        self.unit_test.assertEqual(layer_calls['conv2'], 1, "conv2 should be called once")
+
+        #########################################################################################
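
A note for reviewers on the detection added in graph_builders.py: it relies on torch.fx emitting a separate call_module node per invocation while all invocations of the same module instance share one 'target' string, which is exactly what the new seen_targets dictionary keys on. A minimal standalone sketch of that behavior (illustrative only, not part of this diff; TwoCallNet is a made-up module):

import torch
from torch.fx import symbolic_trace

class TwoCallNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=1)

    def forward(self, x):
        # The same module instance is invoked twice, so fx records two
        # call_module nodes ('conv' and 'conv_1') sharing the target 'conv'.
        return self.conv(self.conv(x))

traced = symbolic_trace(TwoCallNet())
seen_targets = {}
for node in traced.graph.nodes:
    if node.op == 'call_module':
        if node.target in seen_targets:
            print(f"{node.name} reuses target {node.target!r} "
                  f"(first seen as {seen_targets[node.target].name})")
        else:
            seen_targets[node.target] = node
# Prints: conv_1 reuses target 'conv' (first seen as conv)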
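
On the rebuild side, the reuse_groups cache added to PyTorchModelBuilder._add_modules means the first node of a group creates the module and every later node fetches the cached instance, so two attribute names end up pointing at a single module with shared parameter storage; this is the property the updated tests assert via data_ptr(). A rough sketch of the caching pattern under those assumptions (names are illustrative, not MCT API):

import torch

class RebuiltModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        reuse_groups = {}
        # (attribute name, reuse group) pairs; 'conv1' and 'conv1_1' share a group.
        for name, group in [('conv1', 'conv1'), ('conv1_1', 'conv1'), ('conv2', None)]:
            if group in reuse_groups:
                module = reuse_groups[group]  # reused node: fetch the original module
            else:
                module = torch.nn.Conv2d(3, 3, kernel_size=1)  # first occurrence: create it
                if group is not None:
                    reuse_groups[group] = module  # store for future reuse
            setattr(self, name, module)

m = RebuiltModel()
# Reused attributes alias one module, so parameter storage addresses match...
assert [p.data_ptr() for p in m.conv1.parameters()] == [p.data_ptr() for p in m.conv1_1.parameters()]
# ...while an independent layer owns separate storage.
assert [p.data_ptr() for p in m.conv1.parameters()] != [p.data_ptr() for p in m.conv2.parameters()]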