From cbbc513bd3f634f3be8d5cee83c8afb94cc9f985 Mon Sep 17 00:00:00 2001
From: VainF <2218880241@qq.com>
Date: Wed, 23 Aug 2023 00:31:51 +0800
Subject: [PATCH] Bugfix in Global Pruning

---
 examples/hf_transformers/prune_vit.py          |   2 +-
 examples/timm_models/timm_beit.py              |   1 -
 examples/timm_models/timm_vit.py               |   1 -
 .../torchvision_global_pruning.py              |   8 +-
 .../torchvision_models/torchvision_pruning.py  |   8 +-
 .../torchvision_pruning_test.py                | 323 ++++++++++++++++++
 examples/yolov8/readme.md                      |   2 +-
 tests/test_pruner.py                           |   5 +-
 torch_pruning/dependency.py                    |  14 +-
 torch_pruning/ops.py                           |   5 +-
 torch_pruning/pruner/algorithms/metapruner.py  | 174 ++++++----
 torch_pruning/pruner/function.py               |  15 +-
 12 files changed, 469 insertions(+), 89 deletions(-)
 create mode 100644 examples/torchvision_models/torchvision_pruning_test.py

diff --git a/examples/hf_transformers/prune_vit.py b/examples/hf_transformers/prune_vit.py
index 04888a0..0afeef9 100644
--- a/examples/hf_transformers/prune_vit.py
+++ b/examples/hf_transformers/prune_vit.py
@@ -34,7 +34,7 @@
     global_pruning=False, # If False, a uniform sparsity will be assigned to different layers.
     importance=imp, # importance criterion for parameter selection
     iterative_steps=1, # the number of iterations to achieve target sparsity
-    ch_sparsity=0.2,
+    ch_sparsity=0.5,
     channel_groups=channel_groups,
     output_transform=lambda out: out.logits.sum(),
     ignored_layers=[model.classifier],
diff --git a/examples/timm_models/timm_beit.py b/examples/timm_models/timm_beit.py
index de86784..b6d063c 100644
--- a/examples/timm_models/timm_beit.py
+++ b/examples/timm_models/timm_beit.py
@@ -54,7 +54,6 @@ def forward(self, x, shared_rel_pos_bias = None):
 
 # torch==1.12.1
 timm_models = timm.list_models()
-print(timm_models)
 example_inputs = torch.randn(1,3,224,224)
 imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
 prunable_list = []
diff --git a/examples/timm_models/timm_vit.py b/examples/timm_models/timm_vit.py
index 9ed6bb1..3f0ff4a 100644
--- a/examples/timm_models/timm_vit.py
+++ b/examples/timm_models/timm_vit.py
@@ -38,7 +38,6 @@ def timm_attention_forward(self, x):
 
 # torch==1.12.1
 timm_models = timm.list_models()
-print(timm_models)
 example_inputs = torch.randn(1,3,224,224)
 imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")
 prunable_list = []
diff --git a/examples/torchvision_models/torchvision_global_pruning.py b/examples/torchvision_models/torchvision_global_pruning.py
index bd90d0c..b799726 100644
--- a/examples/torchvision_models/torchvision_global_pruning.py
+++ b/examples/torchvision_models/torchvision_global_pruning.py
@@ -163,8 +163,11 @@ def my_prune(model, example_inputs, output_transform, model_name):
         ignored_layers.extend([model.head.classification_head.cls_logits, model.head.regression_head.bbox_reg])
     # For ViT: Rounding the number of channels to the nearest multiple of num_heads
     round_to = None
-    if isinstance( model, VisionTransformer): round_to = model.encoder.layers[0].num_heads
-
+    channel_groups = {}
+    if isinstance( model, VisionTransformer):
+        for m in model.modules():
+            if isinstance(m, nn.MultiheadAttention):
+                channel_groups[m] = m.num_heads
     #########################################
     # (Optional) Register unwrapped nn.Parameters
     # TP will automatically detect unwrapped parameters and prune the last dim for you by default.
@@ -195,6 +198,7 @@ def my_prune(model, example_inputs, output_transform, model_name): round_to=round_to, unwrapped_parameters=unwrapped_parameters, ignored_layers=ignored_layers, + channel_groups=channel_groups, ) diff --git a/examples/torchvision_models/torchvision_pruning.py b/examples/torchvision_models/torchvision_pruning.py index e36197d..2da20b9 100644 --- a/examples/torchvision_models/torchvision_pruning.py +++ b/examples/torchvision_models/torchvision_pruning.py @@ -163,7 +163,12 @@ def my_prune(model, example_inputs, output_transform, model_name): ignored_layers.extend([model.head.classification_head.cls_logits, model.head.regression_head.bbox_reg]) # For ViT: Rounding the number of channels to the nearest multiple of num_heads round_to = None - if isinstance( model, VisionTransformer): round_to = model.encoder.layers[0].num_heads + channel_groups = {} + if isinstance( model, VisionTransformer): + for m in model.modules(): + if isinstance(m, nn.MultiheadAttention): + channel_groups[m] = m.num_heads + #round_to = model.encoder.layers[0].num_heads ######################################### # (Optional) Register unwrapped nn.Parameters @@ -195,6 +200,7 @@ def my_prune(model, example_inputs, output_transform, model_name): round_to=round_to, unwrapped_parameters=unwrapped_parameters, ignored_layers=ignored_layers, + channel_groups=channel_groups, ) diff --git a/examples/torchvision_models/torchvision_pruning_test.py b/examples/torchvision_models/torchvision_pruning_test.py new file mode 100644 index 0000000..5ea2dc0 --- /dev/null +++ b/examples/torchvision_models/torchvision_pruning_test.py @@ -0,0 +1,323 @@ +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) + +from torchvision.models.resnet import ( + resnext50_32x4d, + resnext101_32x8d, +) + +# torchvision==0.13.1 +from torchvision.models.vision_transformer import ( + vit_b_16, + vit_b_32, + vit_l_16, + vit_l_32, + vit_h_14, +) +########################################### +# Prunable Models +############################################ +from torchvision.models.detection.ssdlite import ssdlite320_mobilenet_v3_large +from torchvision.models.detection.ssd import ssd300_vgg16 +from torchvision.models.detection.faster_rcnn import ( + fasterrcnn_resnet50_fpn, + fasterrcnn_resnet50_fpn_v2, + fasterrcnn_mobilenet_v3_large_320_fpn, + fasterrcnn_mobilenet_v3_large_fpn +) +from torchvision.models.detection.fcos import fcos_resnet50_fpn +from torchvision.models.detection.keypoint_rcnn import keypointrcnn_resnet50_fpn +from torchvision.models.detection.mask_rcnn import maskrcnn_resnet50_fpn_v2 +from torchvision.models.detection.retinanet import retinanet_resnet50_fpn_v2 +from torchvision.models.alexnet import alexnet + +from torchvision.models.vision_transformer import ( + vit_b_16, + vit_b_32, + vit_l_16, + vit_l_32, + vit_h_14, +) + +from torchvision.models.convnext import ( + convnext_tiny, + convnext_small, + convnext_base, + convnext_large, +) + +from torchvision.models.densenet import ( + densenet121, + densenet169, + densenet201, + densenet161, +) +from torchvision.models.efficientnet import ( + efficientnet_b0, + efficientnet_b1, + efficientnet_b2, + efficientnet_b3, + efficientnet_b4, + efficientnet_b5, + efficientnet_b6, + efficientnet_b7, + efficientnet_v2_s, + efficientnet_v2_m, + efficientnet_v2_l, +) +from torchvision.models.googlenet import googlenet +from torchvision.models.inception import inception_v3 +from torchvision.models.mnasnet import mnasnet0_5, 
mnasnet0_75, mnasnet1_0, mnasnet1_3 +from torchvision.models.mobilenetv2 import mobilenet_v2 +from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small +from torchvision.models.regnet import ( + regnet_y_400mf, + regnet_y_800mf, + regnet_y_1_6gf, + regnet_y_3_2gf, + regnet_y_8gf, + regnet_y_16gf, + regnet_y_32gf, + regnet_y_128gf, +) +from torchvision.models.resnet import ( + resnet18, + resnet34, + resnet50, + resnet101, + resnet152, + resnext50_32x4d, + resnext101_32x8d, + wide_resnet50_2, + wide_resnet101_2, +) +from torchvision.models.segmentation import ( + fcn_resnet50, + fcn_resnet101, + deeplabv3_resnet50, + deeplabv3_resnet101, + deeplabv3_mobilenet_v3_large, + lraspp_mobilenet_v3_large, +) +from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 +from torchvision.models.vgg import ( + vgg11, + vgg13, + vgg16, + vgg19, + vgg11_bn, + vgg13_bn, + vgg16_bn, + vgg19_bn, +) + + + +########################################### +# Failue cases in this script +############################################ +from torchvision.models.optical_flow import raft_large +from torchvision.models.swin_transformer import swin_t, swin_s, swin_b # TODO: support Swin ops +from torchvision.models.shufflenetv2 import ( # TODO: support channel shuffling + shufflenet_v2_x0_5, + shufflenet_v2_x1_0, + shufflenet_v2_x1_5, + shufflenet_v2_x2_0, +) + + +if __name__ == "__main__": + + entries = globals().copy() + + import torch + import torch.nn as nn + import torch_pruning as tp + import random + + def my_prune(model, example_inputs, output_transform, model_name): + + from torchvision.models.vision_transformer import VisionTransformer + from torchvision.models.convnext import CNBlock, ConvNeXt + + device = 'cuda' if torch.cuda.is_available() else 'cpu' + ori_size = tp.utils.count_params(model) + model.cpu().eval() + ignored_layers = [] + for p in model.parameters(): + p.requires_grad_(True) + ######################################### + # Ignore unprunable modules + ######################################### + for m in model.modules(): + if isinstance(m, nn.Linear) and m.out_features == 1000: + ignored_layers.append(m) + #elif isinstance(m, nn.modules.linear.NonDynamicallyQuantizableLinear): + # ignored_layers.append(m) # this module is used in Self-Attention + if 'ssd' in model_name: + ignored_layers.append(model.head) + if model_name=='raft_large': + ignored_layers.extend( + [model.corr_block, model.update_block, model.mask_predictor] + ) + if 'fasterrcnn' in model_name: + ignored_layers.extend([ + model.rpn.head.cls_logits, model.rpn.head.bbox_pred, model.backbone.fpn, model.roi_heads + ]) + if model_name=='fcos_resnet50_fpn': + ignored_layers.extend([model.head.classification_head.cls_logits, model.head.regression_head.bbox_reg, model.head.regression_head.bbox_ctrness]) + if model_name=='keypointrcnn_resnet50_fpn': + ignored_layers.extend([model.rpn.head.cls_logits, model.backbone.fpn.layer_blocks, model.rpn.head.bbox_pred, model.roi_heads.box_head, model.roi_heads.box_predictor, model.roi_heads.keypoint_predictor]) + if model_name=='maskrcnn_resnet50_fpn_v2': + ignored_layers.extend([model.rpn.head.cls_logits, model.rpn.head.bbox_pred, model.roi_heads.box_predictor, model.roi_heads.mask_predictor]) + if model_name=='retinanet_resnet50_fpn_v2': + ignored_layers.extend([model.head.classification_head.cls_logits, model.head.regression_head.bbox_reg]) + # For ViT: Rounding the number of channels to the nearest multiple of num_heads + round_to = None + #if isinstance( model, 
VisionTransformer): round_to = model.encoder.layers[0].num_heads + channel_groups = {} + if isinstance( model, VisionTransformer): + for m in model.modules(): + if isinstance(m, nn.MultiheadAttention): + channel_groups[m] = m.num_heads + + ######################################### + # (Optional) Register unwrapped nn.Parameters + # TP will automatically detect unwrapped parameters and prune the last dim for you by default. + # If you want to prune other dims, you can register them here. + ######################################### + unwrapped_parameters = None + #if model_name=='ssd300_vgg16': + # unwrapped_parameters=[ (model.backbone.scale_weight, 0) ] # pruning the 0-th dim of scale_weight + #if isinstance( model, VisionTransformer): + # unwrapped_parameters = [ (model.class_token, 0), (model.encoder.pos_embedding, 0)] + #elif isinstance(model, ConvNeXt): + # unwrapped_parameters = [] + # for m in model.modules(): + # if isinstance(m, CNBlock): + # unwrapped_parameters.append( (m.layer_scale, 0) ) + + ######################################### + # Build network pruners + ######################################### + importance = tp.importance.MagnitudeImportance(p=1) + ch_sparsity = 0.2 + pruner = tp.pruner.MagnitudePruner( + model, + example_inputs=example_inputs, + importance=importance, + iterative_steps=1, + ch_sparsity=ch_sparsity, + global_pruning=False, + round_to=round_to, + unwrapped_parameters=unwrapped_parameters, + ignored_layers=ignored_layers, + channel_groups=channel_groups, + ) + + + ######################################### + # Pruning + ######################################### + print("==============Before pruning=================") + print("Model Name: {}".format(model_name)) + print(model) + + layer_channel_cfg = {} + for module in model.modules(): + if module not in pruner.ignored_layers: + #print(module) + if isinstance(module, nn.Conv2d): + layer_channel_cfg[module] = module.out_channels + elif isinstance(module, nn.Linear): + layer_channel_cfg[module] = module.out_features + + pruner.step() + if isinstance( + model, VisionTransformer + ): # Torchvision relies on the hidden_dim variable for forwarding, so we have to modify this varaible after pruning + model.hidden_dim = model.conv_proj.out_channels + print(model.class_token.shape, model.encoder.pos_embedding.shape) + print("==============After pruning=================") + print(model) + + ######################################### + # Testing + ######################################### + with torch.no_grad(): + if isinstance(example_inputs, dict): + out = model(**example_inputs) + else: + out = model(example_inputs) + if output_transform: + out = output_transform(out) + print("{} Pruning: ".format(model_name)) + params_after_prune = tp.utils.count_params(model) + print(" Params: %s => %s" % (ori_size, params_after_prune)) + + + if isinstance(out, (dict,list,tuple)): + print(" Output:") + for o in tp.utils.flatten_as_list(out): + print(o.shape) + else: + print(" Output:", out.shape) + print("------------------------------------------------------\n") + + successful = [] + unsuccessful = [] + for model_name, entry in entries.items(): + if 'swin' in model_name.lower() or 'raft' in model_name.lower() or 'shufflenet' in model_name.lower(): # stuck + unsuccessful.append(model_name) + continue + + if not callable(entry): + continue + if "inception" in model_name: + example_inputs = torch.randn(1, 3, 299, 299) + elif "raft" in model_name: + example_inputs = { + "image1": torch.randn(1, 3, 224, 224), + "image2": torch.randn(1, 
3, 224, 224), + } + elif 'fasterrcnn' in model_name: + example_inputs = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + else: + example_inputs = torch.randn(1, 3, 224, 224) + + if "googlenet" in model_name or "inception" in model_name: + model = entry(aux_logits=False) + elif "fcn" in model_name or "deeplabv3" in model_name: + model = entry(aux_loss=None) + elif 'fasterrcnn' in model_name: + model = entry(weights_backbone=None, trainable_backbone_layers=5) # TP does not support FrozenBN. + elif 'fcos' in model_name: + model = entry(weights_backbone=None, trainable_backbone_layers=5) # TP does not support FrozenBN. + elif 'rcnn' in model_name: + model = entry(weights=None, weights_backbone=None, trainable_backbone_layers=5) # TP does not support FrozenBN. + else: + model = entry() + + if "fcn" in model_name or "deeplabv3" in model_name: + output_transform = lambda x: x["out"] + else: + output_transform = None + + #try: + my_prune( + model, example_inputs=example_inputs, output_transform=output_transform, model_name=model_name + ) + successful.append(model_name) + #except Exception as e: + # print(e) + # unsuccessful.append(model_name) + print("Successful Pruning: %d Models\n"%(len(successful)), successful) + print("") + print("Unsuccessful Pruning: %d Models\n"%(len(unsuccessful)), unsuccessful) + sys.stdout.flush() + +print("Finished!") + +print("Successful Pruning: %d Models\n"%(len(successful)), successful) +print("") +print("Unsuccessful Pruning: %d Models\n"%(len(unsuccessful)), unsuccessful) \ No newline at end of file diff --git a/examples/yolov8/readme.md b/examples/yolov8/readme.md index 76cefcd..1241b06 100644 --- a/examples/yolov8/readme.md +++ b/examples/yolov8/readme.md @@ -37,7 +37,7 @@ This is modified to save the model with full precision because changing model to YOLO v8 replaces saved checkpoint file to half precision after training is done using ```strip_optimizer```. Half precision saving is changed with same reason above. #### Training -``` +```bash # This example will craft yolov8-half and fine-tune it on the coco128 toy set. python yolov8_pruning.py ``` diff --git a/tests/test_pruner.py b/tests/test_pruner.py index 9e839ae..4ca4a73 100644 --- a/tests/test_pruner.py +++ b/tests/test_pruner.py @@ -3,7 +3,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) import torch -from torchvision.models import resnet18 as entry +from torchvision.models import resnet50 as entry import torch_pruning as tp from torch import nn import torch.nn.functional as F @@ -21,11 +21,12 @@ def test_pruner(): if isinstance(m, torch.nn.Linear) and m.out_features == 1000: ignored_layers.append(m) - iterative_steps = 5 + iterative_steps = 1 pruner = tp.pruner.MagnitudePruner( model, example_inputs, importance=imp, + global_pruning=True, iterative_steps=iterative_steps, ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256} ignored_layers=ignored_layers, diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index ef94e0c..3f1cc8d 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -84,7 +84,7 @@ def __init__( source: Node, target: Node, ): - """Layer dependency (Edge of DepGraph) in structral neural network pruning. + """Layer dependency (Edge of DepGraph). 
         Args:
             trigger (Callable): a pruning function that triggers this dependency
             handler (Callable): a pruning function that can fix the broken dependency
@@ -143,19 +143,19 @@ def __hash__(self):
 
 class Group(object):
     """A group that contains dependencies and pruning indices.
-    Each element is defined as a namedtuple('_helpers.GroupItem', ['dep', 'idxs']).
-    A group is a iterable List just like
-    [ [Dep1, Indices1], [Dep2, Indices2], ..., [DepK, IndicesK] ]
+    Each element is a namedtuple('_helpers.GroupItem', ['dep', 'idxs']).
+
+    group := [ (Dep1, Indices1), (Dep2, Indices2), ..., (DepK, IndicesK) ]
     """
 
     def __init__(self):
         self._group = list()
-        self._DG = None # link to the DependencyGraph that produces this group. Will be filled by DependencyGraph.get_pruning_group.
+        self._DG = None # the dependency graph that this group belongs to
 
     def prune(self, idxs=None, record_history=True):
         """Prune all coupled layers in the group
         """
-        if idxs is not None: # prune the group with the specified indices
+        if idxs is not None: # prune the group with user-specified indices
             module = self._group[0].dep.target.module
             pruning_fn = self._group[0].dep.handler
             new_group = self._DG.get_pruning_group(module, pruning_fn, idxs) # create a new group with the specified indices
@@ -281,7 +281,7 @@ def __init__(self):
         # cache pruning functions for fast lookup
         self._in_channel_pruning_fn = set([p.prune_in_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_in_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None])
         self._out_channel_pruning_fn = set([p.prune_out_channels for p in self.REGISTERED_PRUNERS.values() if p is not None] + [p.prune_out_channels for p in self.CUSTOMIZED_PRUNERS.values() if p is not None])
-        self._op_id = 0 # operatior id
+        self._op_id = 0 # operator id, will be increased by 1 for each new operator
 
         # Pruning History
         self._pruning_history = []
diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py
index c592dd2..3befdca 100644
--- a/torch_pruning/ops.py
+++ b/torch_pruning/ops.py
@@ -80,7 +80,10 @@ def get_out_channels(self, layer):
     def get_in_channels(self, layer):
         return None
 
-    def get_channel_groups(self, layer):
+    def get_in_channel_groups(self, layer):
+        return 1
+
+    def get_out_channel_groups(self, layer):
         return 1
 
 
diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py
index e0a7042..31ea047 100644
--- a/torch_pruning/pruner/algorithms/metapruner.py
+++ b/torch_pruning/pruner/algorithms/metapruner.py
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import typing
+import typing, warnings
 
 from .scheduler import linear_scheduler
 from ..import function
@@ -50,12 +50,16 @@ def __init__(
         round_to: int = None, # round channels to the nearest multiple of round_to
 
         # Advanced
-        channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers like group convs & group norms
+        in_channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layer inputs
+        out_channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layer outputs
         customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner}
         unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims.
For example, {ViT.pos_emb: 0} root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group forward_fn: typing.Callable = None, # a function to execute model.forward output_transform: typing.Callable = None, # a function to transform network outputs + + # deprecated + channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers ): self.model = model self.importance = importance @@ -63,10 +67,18 @@ def __init__( self.ch_sparsity_dict = ch_sparsity_dict if ch_sparsity_dict is not None else {} self.max_ch_sparsity = max_ch_sparsity self.global_pruning = global_pruning - self.channel_groups = channel_groups + + if len(channel_groups) > 0: + warnings.warn("channel_groups is deprecated. Please use in_channel_groups and out_channel_groups instead.") + out_channel_groups.update(channel_groups) + + self.in_channel_groups = in_channel_groups + self.out_channel_groups = out_channel_groups + self.root_module_types = root_module_types self.round_to = round_to + ############################################### # Build dependency graph self.DG = dependency.DependencyGraph().build_dependency( model, @@ -77,33 +89,27 @@ def __init__( customized_pruners=customized_pruners, ) - # Ignored layers + ############################################### + # Ignored layers and submodules self.ignored_layers = [] - if ignored_layers: + if ignored_layers is not None: for layer in ignored_layers: self.ignored_layers.extend(list(layer.modules())) + ############################################### # Iterative pruning # The pruner will prune the model iteratively for several steps to achieve the target sparsity # E.g., if iterative_steps=5, ch_sparsity=0.5, the sparsity of each step will be [0.1, 0.2, 0.3, 0.4, 0.5] self.iterative_steps = iterative_steps self.iterative_sparsity_scheduler = iterative_sparsity_scheduler self.current_step = 0 - - # initial channels/dims for each layer - self.layer_init_out_ch = {} - self.layer_init_in_ch = {} - for m in self.DG.module2node.keys(): - if ops.module2type(m) in self.DG.REGISTERED_PRUNERS: - self.layer_init_out_ch[m] = self.DG.get_out_channels(m) - self.layer_init_in_ch[m] = self.DG.get_in_channels(m) - # channel sparsity for each iterative step self.per_step_ch_sparsity = self.iterative_sparsity_scheduler( self.ch_sparsity, self.iterative_steps ) - # The layer-specific sparsity will cover the global sparsity if specified + ############################################### + # Layer-specific sparsity. 
Will cover the global sparsity if specified self.ch_sparsity_dict = {} if ch_sparsity_dict is not None: for module in ch_sparsity_dict: @@ -115,26 +121,51 @@ def __init__( self.ch_sparsity_dict[submodule] = self.iterative_sparsity_scheduler( sparsity, self.iterative_steps ) - - # detect group convs & group norms + + ############################################### + # Detect group convs & group norms for m in self.model.modules(): layer_pruner = self.DG.get_pruner_of_module(m) - channel_groups = layer_pruner.get_channel_groups(m) - if channel_groups > 1: - if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels: + in_ch_group = layer_pruner.get_in_channel_groups(m) + out_ch_group = layer_pruner.get_out_channel_groups(m) + if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels: continue - self.channel_groups[m] = channel_groups + if in_ch_group > 1: + self.in_channel_groups[m] = in_ch_group + if out_ch_group > 1: + self.out_channel_groups[m] = out_ch_group + + ############################################### + # Initial channels/dims of each layer + self.layer_init_out_ch = {} + self.layer_init_in_ch = {} + for m in self.DG.module2node.keys(): + if ops.module2type(m) in self.DG.REGISTERED_PRUNERS: + self.layer_init_out_ch[m] = self.DG.get_out_channels(m) + self.layer_init_in_ch[m] = self.DG.get_in_channels(m) - # count the number of total channels at initialization + ############################################### + # Count the number of total channels at initialization if self.global_pruning: initial_total_channels = 0 for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): - #ch_groups = self.get_channel_groups(group) group = self._downstream_node_as_root_if_unbind(group) - # utils.count_prunable_out_channels( group[0][0].target.module ) - initial_total_channels += (self.DG.get_out_channels(group[0][0].target.module) ) + initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) ) self.initial_total_channels = initial_total_channels + def step(self, interactive=False)-> typing.Union[typing.Generator, None]: + self.current_step += 1 + pruning_method = self.prune_global if self.global_pruning else self.prune_local + + if interactive: # yield groups for interactive pruning + return pruning_method() + else: + for group in pruning_method(): + group.prune() + + def estimate_importance(self, group, ch_groups=1) -> torch.Tensor: + return self.importance(group, ch_groups=ch_groups) + def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]: return self.DG.pruning_history() @@ -149,23 +180,10 @@ def reset(self) -> None: self.current_step = 0 def regularize(self, model, loss) -> typing.Any: - """ Model regularizor + """ Model regularizor for sparse training """ pass - def step(self, interactive=False)-> typing.Union[typing.Generator, None]: - self.current_step += 1 - pruning_method = self.prune_global if self.global_pruning else self.prune_local - - if interactive: # yield groups for interactive pruning - return pruning_method() - else: - for group in pruning_method(): - group.prune() - - def estimate_importance(self, group, ch_groups=1) -> torch.Tensor: - return self.importance(group, ch_groups=ch_groups) - def _check_sparsity(self, group) -> bool: for dep, _ in group: module = dep.target.module @@ -189,21 +207,23 @@ def _check_sparsity(self, group) -> bool: return False return True - def get_channel_groups(self, group) -> int: - if 
isinstance(self.channel_groups, int): - return self.channel_groups + def _get_channel_groups(self, group) -> int: ch_groups = 1 has_unbind = False unbind_node = None + for dep, _ in group: module = dep.target.module - if module in self.channel_groups: - if self.DG.is_in_channel_pruning_fn(dep.handler) and not isinstance(module, (ops.TORCH_CONV, ops.TORCH_GROUPNORM)): - continue - ch_groups = self.channel_groups[module] + pruning_fn = dep.handler + channel_groups = self.out_channel_groups if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.in_channel_groups + + if module in channel_groups: + ch_groups = channel_groups[module] + if dep.source.type==ops.OPTYPE.UNBIND: has_unbind = True unbind_node = dep.source + if has_unbind and ch_groups>1: ch_groups = ch_groups // len(unbind_node.outputs) return ch_groups # no channel grouping @@ -221,17 +241,24 @@ def _downstream_node_as_root_if_unbind(self, group): group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs) return group + def _round_to(self, n_pruned, current_channels, round_to): + rounded_channels = current_channels - n_pruned + rounded_channels = rounded_channels + (round_to - rounded_channels % round_to) + n_pruned = current_channels - rounded_channels + return n_pruned + def prune_local(self) -> typing.Generator: if self.current_step > self.iterative_steps: + warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.") return for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): if self._check_sparsity(group): # check pruning ratio - + group = self._downstream_node_as_root_if_unbind(group) module = group[0][0].target.module pruning_fn = group[0][0].handler - ch_groups = self.get_channel_groups(group) + ch_groups = self._get_channel_groups(group) imp = self.estimate_importance(group, ch_groups=ch_groups) if imp is None: continue @@ -251,16 +278,14 @@ def prune_local(self) -> typing.Generator: (1 - target_sparsity) ) + # round to the nearest multiple of round_to if self.round_to: - rounded_channels = current_channels - n_pruned - # round to the nearest multiple of round_to - rounded_channels = rounded_channels - \ - (rounded_channels % self.round_to) - n_pruned = current_channels - rounded_channels - + n_pruned = self._round_to(n_pruned, current_channels, self.round_to) + if n_pruned <= 0: continue - if ch_groups > 1: # independent pruning for each channel group + + if ch_groups > 1: # independent pruning for each group group_size = current_channels // ch_groups pruning_idxs = [] n_pruned_per_group = n_pruned // ch_groups # max(1, n_pruned // ch_groups) @@ -275,7 +300,6 @@ def prune_local(self) -> typing.Generator: imp_argsort = torch.argsort(imp) pruning_idxs = imp_argsort[:n_pruned] - group = self.DG.get_pruning_group( module, pruning_fn, pruning_idxs.tolist()) @@ -284,42 +308,53 @@ def prune_local(self) -> typing.Generator: def prune_global(self) -> typing.Generator: if self.current_step > self.iterative_steps: + warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.") return + + # Pre-compute importance for each group global_importance = [] for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types): if self._check_sparsity(group): group = self._downstream_node_as_root_if_unbind(group) - ch_groups = self.get_channel_groups(group) + ch_groups = self._get_channel_groups(group) imp = self.estimate_importance(group, 
ch_groups=ch_groups) if imp is None: continue if ch_groups > 1: - imp = imp.view(ch_groups, -1).mean(dim=0) # average importance across groups + imp = imp.view(ch_groups, -1).mean(dim=0) # average importance across groups. TODO: find a better way to handle grouped channels global_importance.append((group, ch_groups, imp)) - if len(global_importance) == 0: return - - imp = torch.cat([local_imp[-1] - for local_imp in global_importance], dim=0) + + # Find the threshold for global pruning + imp = torch.cat([local_imp[-1] for local_imp in global_importance], dim=0) target_sparsity = self.per_step_ch_sparsity[self.current_step] n_pruned = len(imp) - int( self.initial_total_channels * (1 - target_sparsity) ) + if n_pruned <= 0: return - topk_imp, _ = torch.topk(imp, k=n_pruned, largest=False) - thres = topk_imp[-1] # global pruning through thresholding + thres = topk_imp[-1] + # Group-by-group pruning for group, ch_groups, imp in global_importance: module = group[0][0].target.module pruning_fn = group[0][0].handler + get_channel_fn = self.DG.get_out_channels if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.DG.get_in_channels pruning_indices = (imp <= thres).nonzero().view(-1) if ch_groups > 1: # re-compute importance for each channel group if channel grouping is enabled n_pruned_per_group = len(pruning_indices) if n_pruned_per_group == 0: continue # skip + + if self.round_to: + n_pruned = n_pruned_per_group * ch_groups + current_channels = get_channel_fn(module) + n_pruned = self._round_to(n_pruned, current_channels, self.round_to) + n_pruned_per_group = n_pruned // ch_groups + imp = self.estimate_importance(group, ch_groups=ch_groups) # re-compute importance group_size = len(imp) // ch_groups pruning_indices = [] @@ -329,11 +364,12 @@ def prune_global(self) -> typing.Generator: sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size pruning_indices.append(sub_pruning_idxs) pruning_indices = torch.cat(pruning_indices, 0) - - if self.round_to: # round to the nearest multiple of round_to - n_pruned = len(pruning_indices) - n_pruned = n_pruned - (n_pruned % self.round_to) - pruning_indices = pruning_indices[:n_pruned] + else: + if self.round_to: + n_pruned = len(pruning_indices) + current_channels = get_channel_fn(module) + n_pruned = self._round_to(n_pruned, current_channels, self.round_to) + pruning_indices = pruning_indices[:n_pruned] group = self.DG.get_pruning_group( module, pruning_fn, pruning_indices.tolist()) diff --git a/torch_pruning/pruner/function.py b/torch_pruning/pruner/function.py index 10b3965..394791c 100644 --- a/torch_pruning/pruner/function.py +++ b/torch_pruning/pruner/function.py @@ -101,7 +101,10 @@ def __call__(self, layer: nn.Module, idxs: Sequence[int], to_output: bool = True layer = pruning_fn(layer, idxs) return layer - def get_channel_groups(self, layer): + def get_in_channel_groups(self, layer): + return 1 + + def get_out_channel_groups(self, layer): return 1 def _prune_parameter_and_grad(self, weight, keep_idxs, pruning_dim): @@ -146,7 +149,10 @@ def get_out_channels(self, layer): def get_in_channels(self, layer): return layer.in_channels - def get_channel_groups(self, layer): + def get_in_channel_groups(self, layer): + return layer.groups + + def get_out_channel_groups(self, layer): return layer.groups @@ -275,7 +281,10 @@ def get_out_channels(self, layer): def get_in_channels(self, layer): return layer.num_channels - def get_channel_groups(self, layer): + def get_in_channel_groups(self, layer): + return layer.num_groups + + def 
get_out_channel_groups(self, layer):
         return layer.num_groups
 
 class InstanceNormPruner(BasePruningFunc):
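
Usage sketch: this patch replaces the single channel_groups argument of MetaPruner with in_channel_groups / out_channel_groups (the old keyword is still accepted but deprecated and forwarded to out_channel_groups), and the torchvision example scripts now pass per-attention-layer head counts instead of round_to. The following minimal sketch is condensed from the torchvision_pruning_test.py script added by this patch; the model choice and sparsity value here are arbitrary, and the exact keyword set assumes the torch_pruning API at this commit.

    import torch
    import torch.nn as nn
    import torch_pruning as tp
    from torchvision.models.vision_transformer import vit_b_16

    model = vit_b_16().eval()
    example_inputs = torch.randn(1, 3, 224, 224)

    # Group channels by the number of attention heads so pruning keeps the
    # head dimension consistent, instead of rounding channels with round_to.
    channel_groups = {m: m.num_heads for m in model.modules()
                      if isinstance(m, nn.MultiheadAttention)}
    # Keep the 1000-way classifier intact.
    ignored_layers = [m for m in model.modules()
                      if isinstance(m, nn.Linear) and m.out_features == 1000]

    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=example_inputs,
        importance=tp.importance.MagnitudeImportance(p=1),
        iterative_steps=1,
        ch_sparsity=0.5,
        global_pruning=True,            # the code path this patch fixes
        channel_groups=channel_groups,  # deprecated alias, forwarded to out_channel_groups
        ignored_layers=ignored_layers,
    )
    pruner.step()
    # torchvision's ViT reads hidden_dim during forward, so update it after pruning.
    model.hidden_dim = model.conv_proj.out_channels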