From 6d0e16c15cab51fb8fc5a48c3d6e273d1f35430a Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Sat, 22 Jul 2023 17:28:01 +0800 Subject: [PATCH 01/13] Fixed a bug in BNScalePruner --- .../pruner/algorithms/batchnorm_scale_pruner.py | 6 +++--- torch_pruning/pruner/algorithms/growing_reg_pruner.py | 6 ++---- torch_pruning/pruner/importance.py | 10 +--------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py b/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py index dd28a2d..9cd18dc 100644 --- a/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py +++ b/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py @@ -46,7 +46,7 @@ def __init__( ) self.reg = reg self._groups = list(self.DG.get_all_groups()) - self.group_lasso = True + self.group_lasso = group_lasso if self.group_lasso: self._l2_imp = MagnitudeImportance(p=2, group_reduction='mean', normalizer=None, target_types=[nn.modules.batchnorm._BatchNorm]) @@ -60,10 +60,10 @@ def regularize(self, model, reg=None): m.weight.grad.data.add_(reg*torch.sign(m.weight.data)) else: for group in self._groups: - group_l2norm_sq, group_size = self._l2_imp(group, return_group_size=True) + group_l2norm_sq = self._l2_imp(group) if group_l2norm_sq is None: continue for dep, _ in group: layer = dep.layer if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine==True and layer not in self.ignored_layers: - layer.weight.grad.data.add_(reg * math.sqrt(group_size) * (1 / group_l2norm_sq.sqrt()) * layer.weight.data) # Group Lasso https://tibshirani.su.domains/ftp/sparse-grlasso.pdf \ No newline at end of file + layer.weight.grad.data.add_(reg * (1 / group_l2norm_sq.sqrt()) * layer.weight.data) # Group Lasso https://tibshirani.su.domains/ftp/sparse-grlasso.pdf \ No newline at end of file diff --git a/torch_pruning/pruner/algorithms/growing_reg_pruner.py b/torch_pruning/pruner/algorithms/growing_reg_pruner.py index 272b544..355af9e 100644 --- a/torch_pruning/pruner/algorithms/growing_reg_pruner.py +++ b/torch_pruning/pruner/algorithms/growing_reg_pruner.py @@ -27,7 +27,6 @@ def __init__( customized_pruners=None, unwrapped_parameters=None, output_transform=None, - target_types=[nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm], ): super(GrowingRegPruner, self).__init__( model=model, @@ -48,7 +47,6 @@ def __init__( self.base_reg = reg self._groups = list(self.DG.get_all_groups()) self.group_lasso = True - self._l2_imp = GroupNormImportance() group_reg = {} for group in self._groups: @@ -58,7 +56,7 @@ def __init__( def update_reg(self): for group in self._groups: - group_l2norm_sq = self._l2_imp(group) + group_l2norm_sq = self.estimate_importance(group) if group_l2norm_sq is None: continue reg = self.group_reg[group] @@ -68,7 +66,7 @@ def update_reg(self): def regularize(self, model): for i, group in enumerate(self._groups): - group_l2norm_sq = self._l2_imp(group) + group_l2norm_sq = self.estimate_importance(group) if group_l2norm_sq is None: continue diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py index 32dd832..24efee6 100644 --- a/torch_pruning/pruner/importance.py +++ b/torch_pruning/pruner/importance.py @@ -94,10 +94,9 @@ def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[ return reduced_imp @torch.no_grad() - def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): + def __call__(self, group: Group, ch_groups: int=1): group_imp = [] group_idxs = [] - group_size = 0 # Iterate over all groups and estimate group importance for i, (dep, idxs) in enumerate(group): layer = dep.layer @@ -117,7 +116,6 @@ def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): else: w = layer.weight.data[idxs].flatten(1) local_imp = w.abs().pow(self.p).sum(1) - group_size += w.shape[1] if ch_groups > 1: local_imp = local_imp.view(ch_groups, -1).sum(0) local_imp = local_imp.repeat(ch_groups) @@ -135,7 +133,6 @@ def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): w = (layer.weight.data).flatten(1) else: w = (layer.weight.data).transpose(0, 1).flatten(1) - group_size += w.shape[1] if ch_groups > 1 and prune_fn == function.prune_conv_in_channels and layer.groups == 1: # non-grouped conv followed by a group conv w = w.view(w.shape[0] // group_imp[0].shape[0], group_imp[0].shape[0], w.shape[1]).transpose(0, 1).flatten(1) @@ -157,7 +154,6 @@ def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): if layer.affine: w = layer.weight.data[idxs] local_imp = w.abs().pow(self.p) - group_size += 1 if ch_groups > 1: local_imp = local_imp.view(ch_groups, -1).sum(0) local_imp = local_imp.repeat(ch_groups) @@ -166,13 +162,9 @@ def __call__(self, group: Group, ch_groups: int=1, return_group_size=False): #elif prune_fn == function.prune_multihead_attention_out_channels: if len(group_imp) == 0: # skip groups without parameterized layers - if return_group_size: - return None, 0 return None group_imp = self._reduce(group_imp, group_idxs) group_imp = self._normalize(group_imp, self.normalizer) - if return_group_size: - return group_imp, group_size return group_imp From 973e8fa8c494e18791273532516f9853a93c396b Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Sat, 22 Jul 2023 17:28:26 +0800 Subject: [PATCH 02/13] Growing Regularization for ImageNet --- benchmarks/main_imagenet.py | 36 +- .../cifar10-global-growing_reg-resnet56.txt | 450 ++++++++++++++++++ 2 files changed, 472 insertions(+), 14 deletions(-) diff --git a/benchmarks/main_imagenet.py b/benchmarks/main_imagenet.py index 760ad0d..55d7a42 100644 --- a/benchmarks/main_imagenet.py +++ b/benchmarks/main_imagenet.py @@ -89,6 +89,7 @@ def get_args_parser(add_help=True): parser.add_argument("--target-flops", type=float, default=2.0, help="GFLOPs of pruned model") parser.add_argument("--soft-keeping-ratio", type=float, default=0.0) parser.add_argument("--reg", type=float, default=1e-4) + parser.add_argument("--delta_reg", type=float, default=1e-4) parser.add_argument("--max-ch-sparsity", default=1.0, type=float, help="maximum channel sparsity") parser.add_argument("--sl-epochs", type=int, default=None) parser.add_argument("--sl-resume", type=str, default=None) @@ -131,6 +132,10 @@ def get_pruner(model, example_inputs, args): elif args.method == "group_norm": imp = tp.importance.GroupNormImportance(p=2) pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning) + elif args.method == "group_greg": + sparsity_learning = True + imp = tp.importance.GroupNormImportance(p=2) + pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning) elif args.method == "group_sl": sparsity_learning = True imp = tp.importance.GroupNormImportance(p=2) @@ -163,7 +168,7 @@ def get_pruner(model, example_inputs, args): -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, regularizer=None, recover=None): +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, pruner=None, recover=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) @@ -180,17 +185,17 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg optimizer.zero_grad() if scaler is not None: scaler.scale(loss).backward() - if regularizer: + if pruner: scaler.unscale_(optimizer) - regularizer(model) + pruner.regularize(model) #if recover: # recover(model.module) scaler.step(optimizer) scaler.update() else: loss.backward() - if regularizer: - regularizer(model) + if pruner is not None: + pruner.regularize(model) if recover: recover(model.module) if args.clip_grad_norm is not None: @@ -202,7 +207,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg if epoch < args.lr_warmup_epochs: # Reset ema buffer to keep copying weights during warmup period model_ema.n_averaged.fill_(0) - + + if pruner is not None and isinstance(pruner, tp.pruner.GroupNormPruner): + pruner.update_reg() + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) @@ -381,7 +389,7 @@ def collate_fn(batch): train(model, args.sl_epochs, lr=args.sl_lr, lr_step_size=args.sl_lr_step_size, lr_warmup_epochs=args.sl_lr_warmup_epochs, train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test, - device=device, args=args, regularizer=pruner.regularize, state_dict_only=True) + device=device, args=args, pruner=pruner, state_dict_only=True) #model.load_state_dict( torch.load('regularized_{:.4f}_best.pth'.format(args.reg), map_location='cpu')['model'] ) #utils.save_on_master( # model_without_ddp.state_dict(), @@ -403,14 +411,14 @@ def collate_fn(batch): train(model, args.epochs, lr=args.lr, lr_step_size=args.lr_step_size, lr_warmup_epochs=args.lr_warmup_epochs, train_sampler=train_sampler, data_loader=data_loader, data_loader_test=data_loader_test, - device=device, args=args, regularizer=None, state_dict_only=(not args.prune)) + device=device, args=args, pruner=None, state_dict_only=(not args.prune)) def train( model, epochs, lr, lr_step_size, lr_warmup_epochs, train_sampler, data_loader, data_loader_test, - device, args, regularizer=None, state_dict_only=True, recover=None): + device, args, pruner=None, state_dict_only=True, recover=None): model.to(device) if args.distributed and args.sync_bn: @@ -421,9 +429,9 @@ def train( else: criterion = nn.CrossEntropyLoss() - weight_decay = args.weight_decay if regularizer is None else 0 - bias_weight_decay = args.bias_weight_decay if regularizer is None else 0 - norm_weight_decay = args.norm_weight_decay if regularizer is None else 0 + weight_decay = args.weight_decay if pruner is None else 0 + bias_weight_decay = args.bias_weight_decay if pruner is None else 0 + norm_weight_decay = args.norm_weight_decay if pruner is None else 0 custom_keys_weight_decay = [] if bias_weight_decay is not None: @@ -534,11 +542,11 @@ def train( start_time = time.time() best_acc = 0 - prefix = '' if regularizer is None else 'regularized_{:e}_'.format(args.reg) + prefix = '' if pruner is None else 'regularized_{:e}_'.format(args.reg) for epoch in range(args.start_epoch, epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, regularizer, recover=recover) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, pruner, recover=recover) lr_scheduler.step() acc = evaluate(model, criterion, data_loader_test, device=device) if model_ema: diff --git a/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt b/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt index e7b8b33..b1083dd 100644 --- a/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt +++ b/benchmarks/run/cifar10/prune/cifar10-global-growing_reg-resnet56/cifar10-global-growing_reg-resnet56.txt @@ -478,3 +478,453 @@ [07/20 22:37:24] cifar10-global-growing_reg-resnet56 INFO: Epoch 98/100, Acc=0.9354, Val Loss=0.2627, lr=0.0001 [07/20 22:37:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 99/100, Acc=0.9348, Val Loss=0.2656, lr=0.0001 [07/20 22:37:41] cifar10-global-growing_reg-resnet56 INFO: Best Acc=0.9355 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: mode: prune +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: model: resnet56 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: verbose: False +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: dataset: cifar10 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: batch_size: 128 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: total_epochs: 100 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: lr_decay_milestones: 60,80 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: lr_decay_gamma: 0.1 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: lr: 0.01 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: restore: cifar10_resnet56.pth +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: output_dir: run/cifar10/prune/cifar10-global-growing_reg-resnet56 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: method: growing_reg +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: speed_up: 2.11 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: max_sparsity: 1.0 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: soft_keeping_ratio: 0.0 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: reg: 0.0001 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: delta_reg: 1e-05 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: weight_decay: 0.0005 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: seed: None +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: global_pruning: True +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: sl_total_epochs: 100 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: sl_lr: 0.01 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: sl_lr_decay_milestones: 60,80 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: sl_reg_warmup: 0 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: sl_restore: None +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: iterative_steps: 400 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: logger: +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: device: cuda +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: num_classes: 10 +[07/22 01:25:58] cifar10-global-growing_reg-resnet56 INFO: Loading model from cifar10_resnet56.pth +[07/22 01:26:01] cifar10-global-growing_reg-resnet56 INFO: Regularizing... +[07/22 01:26:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 0/100, Acc=0.8918, Val Loss=0.4039, lr=0.0100 +[07/22 01:27:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 1/100, Acc=0.9072, Val Loss=0.3326, lr=0.0100 +[07/22 01:27:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 2/100, Acc=0.9120, Val Loss=0.3198, lr=0.0100 +[07/22 01:28:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 3/100, Acc=0.9119, Val Loss=0.3161, lr=0.0100 +[07/22 01:28:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 4/100, Acc=0.9131, Val Loss=0.3105, lr=0.0100 +[07/22 01:29:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 5/100, Acc=0.9079, Val Loss=0.3481, lr=0.0100 +[07/22 01:30:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 6/100, Acc=0.9107, Val Loss=0.3434, lr=0.0100 +[07/22 01:30:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 7/100, Acc=0.8889, Val Loss=0.4145, lr=0.0100 +[07/22 01:31:13] cifar10-global-growing_reg-resnet56 INFO: Epoch 8/100, Acc=0.9129, Val Loss=0.3244, lr=0.0100 +[07/22 01:31:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 9/100, Acc=0.9130, Val Loss=0.3284, lr=0.0100 +[07/22 01:32:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 10/100, Acc=0.9096, Val Loss=0.3294, lr=0.0100 +[07/22 01:32:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 11/100, Acc=0.9094, Val Loss=0.3398, lr=0.0100 +[07/22 01:33:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 12/100, Acc=0.9073, Val Loss=0.3438, lr=0.0100 +[07/22 01:34:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 13/100, Acc=0.9079, Val Loss=0.3467, lr=0.0100 +[07/22 01:34:37] cifar10-global-growing_reg-resnet56 INFO: Epoch 14/100, Acc=0.9202, Val Loss=0.3115, lr=0.0100 +[07/22 01:35:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 15/100, Acc=0.9011, Val Loss=0.3489, lr=0.0100 +[07/22 01:35:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 16/100, Acc=0.9158, Val Loss=0.3149, lr=0.0100 +[07/22 01:36:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 17/100, Acc=0.9097, Val Loss=0.3410, lr=0.0100 +[07/22 01:36:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 18/100, Acc=0.9119, Val Loss=0.3371, lr=0.0100 +[07/22 01:37:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 19/100, Acc=0.9143, Val Loss=0.3335, lr=0.0100 +[07/22 01:38:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 20/100, Acc=0.9135, Val Loss=0.3244, lr=0.0100 +[07/22 01:38:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 21/100, Acc=0.9124, Val Loss=0.3542, lr=0.0100 +[07/22 01:39:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 22/100, Acc=0.9147, Val Loss=0.3223, lr=0.0100 +[07/22 01:39:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 23/100, Acc=0.9085, Val Loss=0.3612, lr=0.0100 +[07/22 01:40:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 24/100, Acc=0.9107, Val Loss=0.3333, lr=0.0100 +[07/22 01:40:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 25/100, Acc=0.9130, Val Loss=0.3433, lr=0.0100 +[07/22 01:41:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 26/100, Acc=0.9061, Val Loss=0.3435, lr=0.0100 +[07/22 01:42:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 27/100, Acc=0.9112, Val Loss=0.3367, lr=0.0100 +[07/22 01:42:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 28/100, Acc=0.9041, Val Loss=0.3696, lr=0.0100 +[07/22 01:43:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 29/100, Acc=0.9064, Val Loss=0.3655, lr=0.0100 +[07/22 01:43:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 30/100, Acc=0.9114, Val Loss=0.3327, lr=0.0100 +[07/22 01:44:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 31/100, Acc=0.9196, Val Loss=0.2976, lr=0.0100 +[07/22 01:44:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 32/100, Acc=0.9131, Val Loss=0.3386, lr=0.0100 +[07/22 01:45:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 33/100, Acc=0.9151, Val Loss=0.3207, lr=0.0100 +[07/22 01:46:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 34/100, Acc=0.9142, Val Loss=0.3258, lr=0.0100 +[07/22 01:46:35] cifar10-global-growing_reg-resnet56 INFO: Epoch 35/100, Acc=0.9146, Val Loss=0.3139, lr=0.0100 +[07/22 01:47:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 36/100, Acc=0.9083, Val Loss=0.3474, lr=0.0100 +[07/22 01:47:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 37/100, Acc=0.9117, Val Loss=0.3453, lr=0.0100 +[07/22 01:48:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 38/100, Acc=0.9028, Val Loss=0.3681, lr=0.0100 +[07/22 01:48:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 39/100, Acc=0.9148, Val Loss=0.3275, lr=0.0100 +[07/22 01:49:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 40/100, Acc=0.9130, Val Loss=0.3342, lr=0.0100 +[07/22 01:49:59] cifar10-global-growing_reg-resnet56 INFO: Epoch 41/100, Acc=0.9126, Val Loss=0.3329, lr=0.0100 +[07/22 01:50:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 42/100, Acc=0.9143, Val Loss=0.3380, lr=0.0100 +[07/22 01:51:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 43/100, Acc=0.9151, Val Loss=0.3400, lr=0.0100 +[07/22 01:51:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 44/100, Acc=0.9157, Val Loss=0.3277, lr=0.0100 +[07/22 01:52:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 45/100, Acc=0.9144, Val Loss=0.3160, lr=0.0100 +[07/22 01:52:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 46/100, Acc=0.9145, Val Loss=0.3295, lr=0.0100 +[07/22 01:53:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 47/100, Acc=0.9086, Val Loss=0.3452, lr=0.0100 +[07/22 01:54:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 48/100, Acc=0.9045, Val Loss=0.4000, lr=0.0100 +[07/22 01:54:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 49/100, Acc=0.9100, Val Loss=0.3247, lr=0.0100 +[07/22 01:55:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 50/100, Acc=0.9142, Val Loss=0.3201, lr=0.0100 +[07/22 01:55:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 51/100, Acc=0.9128, Val Loss=0.3145, lr=0.0100 +[07/22 01:56:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 52/100, Acc=0.9052, Val Loss=0.3385, lr=0.0100 +[07/22 01:56:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 53/100, Acc=0.9061, Val Loss=0.3435, lr=0.0100 +[07/22 01:57:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 54/100, Acc=0.9024, Val Loss=0.3672, lr=0.0100 +[07/22 01:57:59] cifar10-global-growing_reg-resnet56 INFO: Epoch 55/100, Acc=0.8907, Val Loss=0.4212, lr=0.0100 +[07/22 01:58:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 56/100, Acc=0.8909, Val Loss=0.4301, lr=0.0100 +[07/22 01:59:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 57/100, Acc=0.8981, Val Loss=0.3696, lr=0.0100 +[07/22 01:59:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 58/100, Acc=0.9016, Val Loss=0.3592, lr=0.0100 +[07/22 02:00:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 59/100, Acc=0.9140, Val Loss=0.3138, lr=0.0100 +[07/22 02:00:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 60/100, Acc=0.9316, Val Loss=0.2547, lr=0.0010 +[07/22 02:01:23] cifar10-global-growing_reg-resnet56 INFO: Epoch 61/100, Acc=0.9346, Val Loss=0.2533, lr=0.0010 +[07/22 02:01:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 62/100, Acc=0.9328, Val Loss=0.2540, lr=0.0010 +[07/22 02:02:31] cifar10-global-growing_reg-resnet56 INFO: Epoch 63/100, Acc=0.9343, Val Loss=0.2547, lr=0.0010 +[07/22 02:03:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 64/100, Acc=0.9343, Val Loss=0.2612, lr=0.0010 +[07/22 02:03:39] cifar10-global-growing_reg-resnet56 INFO: Epoch 65/100, Acc=0.9335, Val Loss=0.2614, lr=0.0010 +[07/22 02:04:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 66/100, Acc=0.9337, Val Loss=0.2627, lr=0.0010 +[07/22 02:04:48] cifar10-global-growing_reg-resnet56 INFO: Epoch 67/100, Acc=0.9349, Val Loss=0.2614, lr=0.0010 +[07/22 02:05:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 68/100, Acc=0.9339, Val Loss=0.2660, lr=0.0010 +[07/22 02:05:56] cifar10-global-growing_reg-resnet56 INFO: Epoch 69/100, Acc=0.9360, Val Loss=0.2641, lr=0.0010 +[07/22 02:06:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 70/100, Acc=0.9341, Val Loss=0.2697, lr=0.0010 +[07/22 02:07:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 71/100, Acc=0.9354, Val Loss=0.2683, lr=0.0010 +[07/22 02:07:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 72/100, Acc=0.9348, Val Loss=0.2694, lr=0.0010 +[07/22 02:08:13] cifar10-global-growing_reg-resnet56 INFO: Epoch 73/100, Acc=0.9339, Val Loss=0.2704, lr=0.0010 +[07/22 02:08:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 74/100, Acc=0.9353, Val Loss=0.2733, lr=0.0010 +[07/22 02:09:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 75/100, Acc=0.9353, Val Loss=0.2732, lr=0.0010 +[07/22 02:09:56] cifar10-global-growing_reg-resnet56 INFO: Epoch 76/100, Acc=0.9350, Val Loss=0.2737, lr=0.0010 +[07/22 02:10:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 77/100, Acc=0.9346, Val Loss=0.2778, lr=0.0010 +[07/22 02:11:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 78/100, Acc=0.9362, Val Loss=0.2772, lr=0.0010 +[07/22 02:11:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 79/100, Acc=0.9368, Val Loss=0.2765, lr=0.0010 +[07/22 02:12:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 80/100, Acc=0.9358, Val Loss=0.2789, lr=0.0001 +[07/22 02:12:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 81/100, Acc=0.9363, Val Loss=0.2753, lr=0.0001 +[07/22 02:13:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 82/100, Acc=0.9369, Val Loss=0.2759, lr=0.0001 +[07/22 02:13:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 83/100, Acc=0.9363, Val Loss=0.2766, lr=0.0001 +[07/22 02:14:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 84/100, Acc=0.9366, Val Loss=0.2777, lr=0.0001 +[07/22 02:15:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 85/100, Acc=0.9376, Val Loss=0.2757, lr=0.0001 +[07/22 02:15:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 86/100, Acc=0.9370, Val Loss=0.2769, lr=0.0001 +[07/22 02:16:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 87/100, Acc=0.9372, Val Loss=0.2755, lr=0.0001 +[07/22 02:16:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 88/100, Acc=0.9367, Val Loss=0.2766, lr=0.0001 +[07/22 02:17:20] cifar10-global-growing_reg-resnet56 INFO: Epoch 89/100, Acc=0.9365, Val Loss=0.2775, lr=0.0001 +[07/22 02:17:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 90/100, Acc=0.9363, Val Loss=0.2769, lr=0.0001 +[07/22 02:18:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 91/100, Acc=0.9372, Val Loss=0.2778, lr=0.0001 +[07/22 02:19:02] cifar10-global-growing_reg-resnet56 INFO: Epoch 92/100, Acc=0.9373, Val Loss=0.2777, lr=0.0001 +[07/22 02:19:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 93/100, Acc=0.9369, Val Loss=0.2764, lr=0.0001 +[07/22 02:20:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 94/100, Acc=0.9370, Val Loss=0.2784, lr=0.0001 +[07/22 02:20:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 95/100, Acc=0.9381, Val Loss=0.2756, lr=0.0001 +[07/22 02:21:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 96/100, Acc=0.9356, Val Loss=0.2775, lr=0.0001 +[07/22 02:21:53] cifar10-global-growing_reg-resnet56 INFO: Epoch 97/100, Acc=0.9369, Val Loss=0.2773, lr=0.0001 +[07/22 02:22:27] cifar10-global-growing_reg-resnet56 INFO: Epoch 98/100, Acc=0.9375, Val Loss=0.2774, lr=0.0001 +[07/22 02:23:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 99/100, Acc=0.9360, Val Loss=0.2790, lr=0.0001 +[07/22 02:23:01] cifar10-global-growing_reg-resnet56 INFO: Best Acc=0.9381 +[07/22 02:23:01] cifar10-global-growing_reg-resnet56 INFO: Loading the sparse model from run/cifar10/prune/cifar10-global-growing_reg-resnet56/reg_cifar10_resnet56_growing_reg_0.0001.pth... +[07/22 02:23:02] cifar10-global-growing_reg-resnet56 INFO: Pruning... +[07/22 02:23:09] cifar10-global-growing_reg-resnet56 INFO: ResNet( + (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (layer1): Sequential( + (0): BasicBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (1): BasicBlock( + (conv1): Conv2d(8, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(8, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(6, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(8, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(10, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(8, 13, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(13, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(13, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(8, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(10, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(8, 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(7, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer2): Sequential( + (0): BasicBlock( + (conv1): Conv2d(8, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(24, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(8, 30, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(30, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(9, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(30, 29, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(29, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(29, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(30, 29, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(29, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(29, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(30, 29, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(29, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(29, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(30, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(8, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(30, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(14, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(14, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(30, 11, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(11, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(11, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(30, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(5, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(30, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (layer3): Sequential( + (0): BasicBlock( + (conv1): Conv2d(30, 56, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(56, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(56, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (downsample): Sequential( + (0): Conv2d(30, 47, kernel_size=(1, 1), stride=(2, 2), bias=False) + (1): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (1): BasicBlock( + (conv1): Conv2d(47, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(63, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(63, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (2): BasicBlock( + (conv1): Conv2d(47, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(61, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (3): BasicBlock( + (conv1): Conv2d(47, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(61, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (4): BasicBlock( + (conv1): Conv2d(47, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (5): BasicBlock( + (conv1): Conv2d(47, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(62, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(62, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (6): BasicBlock( + (conv1): Conv2d(47, 57, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(57, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(57, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (7): BasicBlock( + (conv1): Conv2d(47, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(47, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + (8): BasicBlock( + (conv1): Conv2d(47, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn1): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + (relu): ReLU(inplace=True) + (conv2): Conv2d(47, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (bn2): BatchNorm2d(47, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) + ) + ) + (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0) + (fc): Linear(in_features=47, out_features=10, bias=True) +) +[07/22 02:23:10] cifar10-global-growing_reg-resnet56 INFO: Params: 0.86 M => 0.52 M (61.28%) +[07/22 02:23:10] cifar10-global-growing_reg-resnet56 INFO: FLOPs: 127.12 M => 59.84 M (47.07%, 2.12X ) +[07/22 02:23:10] cifar10-global-growing_reg-resnet56 INFO: Acc: 0.9381 => 0.1119 +[07/22 02:23:10] cifar10-global-growing_reg-resnet56 INFO: Val Loss: 0.2756 => 9.0222 +[07/22 02:23:10] cifar10-global-growing_reg-resnet56 INFO: Finetuning... +[07/22 02:23:28] cifar10-global-growing_reg-resnet56 INFO: Epoch 0/100, Acc=0.8742, Val Loss=0.3933, lr=0.0100 +[07/22 02:23:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 1/100, Acc=0.8852, Val Loss=0.3753, lr=0.0100 +[07/22 02:24:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 2/100, Acc=0.8661, Val Loss=0.4365, lr=0.0100 +[07/22 02:24:17] cifar10-global-growing_reg-resnet56 INFO: Epoch 3/100, Acc=0.8905, Val Loss=0.3537, lr=0.0100 +[07/22 02:24:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 4/100, Acc=0.8797, Val Loss=0.3890, lr=0.0100 +[07/22 02:24:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 5/100, Acc=0.8910, Val Loss=0.3505, lr=0.0100 +[07/22 02:25:10] cifar10-global-growing_reg-resnet56 INFO: Epoch 6/100, Acc=0.8989, Val Loss=0.3354, lr=0.0100 +[07/22 02:25:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 7/100, Acc=0.8974, Val Loss=0.3435, lr=0.0100 +[07/22 02:25:43] cifar10-global-growing_reg-resnet56 INFO: Epoch 8/100, Acc=0.8927, Val Loss=0.3625, lr=0.0100 +[07/22 02:26:01] cifar10-global-growing_reg-resnet56 INFO: Epoch 9/100, Acc=0.8998, Val Loss=0.3302, lr=0.0100 +[07/22 02:26:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 10/100, Acc=0.8933, Val Loss=0.3635, lr=0.0100 +[07/22 02:26:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 11/100, Acc=0.8969, Val Loss=0.3341, lr=0.0100 +[07/22 02:26:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 12/100, Acc=0.8904, Val Loss=0.3709, lr=0.0100 +[07/22 02:27:12] cifar10-global-growing_reg-resnet56 INFO: Epoch 13/100, Acc=0.9003, Val Loss=0.3277, lr=0.0100 +[07/22 02:27:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 14/100, Acc=0.8918, Val Loss=0.3707, lr=0.0100 +[07/22 02:27:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 15/100, Acc=0.9053, Val Loss=0.3068, lr=0.0100 +[07/22 02:28:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 16/100, Acc=0.8918, Val Loss=0.3580, lr=0.0100 +[07/22 02:28:20] cifar10-global-growing_reg-resnet56 INFO: Epoch 17/100, Acc=0.8938, Val Loss=0.3611, lr=0.0100 +[07/22 02:28:37] cifar10-global-growing_reg-resnet56 INFO: Epoch 18/100, Acc=0.8902, Val Loss=0.3977, lr=0.0100 +[07/22 02:28:54] cifar10-global-growing_reg-resnet56 INFO: Epoch 19/100, Acc=0.8994, Val Loss=0.3405, lr=0.0100 +[07/22 02:29:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 20/100, Acc=0.9094, Val Loss=0.3150, lr=0.0100 +[07/22 02:29:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 21/100, Acc=0.8825, Val Loss=0.4182, lr=0.0100 +[07/22 02:29:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 22/100, Acc=0.8907, Val Loss=0.3861, lr=0.0100 +[07/22 02:30:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 23/100, Acc=0.8986, Val Loss=0.3507, lr=0.0100 +[07/22 02:30:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 24/100, Acc=0.9082, Val Loss=0.3024, lr=0.0100 +[07/22 02:30:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 25/100, Acc=0.8863, Val Loss=0.3911, lr=0.0100 +[07/22 02:30:55] cifar10-global-growing_reg-resnet56 INFO: Epoch 26/100, Acc=0.9041, Val Loss=0.3338, lr=0.0100 +[07/22 02:31:13] cifar10-global-growing_reg-resnet56 INFO: Epoch 27/100, Acc=0.8967, Val Loss=0.3447, lr=0.0100 +[07/22 02:31:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 28/100, Acc=0.8908, Val Loss=0.3892, lr=0.0100 +[07/22 02:31:47] cifar10-global-growing_reg-resnet56 INFO: Epoch 29/100, Acc=0.8727, Val Loss=0.4386, lr=0.0100 +[07/22 02:32:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 30/100, Acc=0.9017, Val Loss=0.3354, lr=0.0100 +[07/22 02:32:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 31/100, Acc=0.8933, Val Loss=0.3861, lr=0.0100 +[07/22 02:32:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 32/100, Acc=0.8716, Val Loss=0.4683, lr=0.0100 +[07/22 02:32:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 33/100, Acc=0.9012, Val Loss=0.3370, lr=0.0100 +[07/22 02:33:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 34/100, Acc=0.8968, Val Loss=0.3715, lr=0.0100 +[07/22 02:33:32] cifar10-global-growing_reg-resnet56 INFO: Epoch 35/100, Acc=0.9003, Val Loss=0.3457, lr=0.0100 +[07/22 02:33:50] cifar10-global-growing_reg-resnet56 INFO: Epoch 36/100, Acc=0.8886, Val Loss=0.4046, lr=0.0100 +[07/22 02:34:07] cifar10-global-growing_reg-resnet56 INFO: Epoch 37/100, Acc=0.8912, Val Loss=0.3739, lr=0.0100 +[07/22 02:34:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 38/100, Acc=0.9001, Val Loss=0.3478, lr=0.0100 +[07/22 02:34:42] cifar10-global-growing_reg-resnet56 INFO: Epoch 39/100, Acc=0.9067, Val Loss=0.3225, lr=0.0100 +[07/22 02:35:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 40/100, Acc=0.8995, Val Loss=0.3535, lr=0.0100 +[07/22 02:35:18] cifar10-global-growing_reg-resnet56 INFO: Epoch 41/100, Acc=0.8851, Val Loss=0.4250, lr=0.0100 +[07/22 02:35:36] cifar10-global-growing_reg-resnet56 INFO: Epoch 42/100, Acc=0.9077, Val Loss=0.3242, lr=0.0100 +[07/22 02:35:53] cifar10-global-growing_reg-resnet56 INFO: Epoch 43/100, Acc=0.9042, Val Loss=0.3331, lr=0.0100 +[07/22 02:36:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 44/100, Acc=0.8865, Val Loss=0.3998, lr=0.0100 +[07/22 02:36:29] cifar10-global-growing_reg-resnet56 INFO: Epoch 45/100, Acc=0.8863, Val Loss=0.3913, lr=0.0100 +[07/22 02:36:46] cifar10-global-growing_reg-resnet56 INFO: Epoch 46/100, Acc=0.9015, Val Loss=0.3451, lr=0.0100 +[07/22 02:37:04] cifar10-global-growing_reg-resnet56 INFO: Epoch 47/100, Acc=0.8942, Val Loss=0.3699, lr=0.0100 +[07/22 02:37:22] cifar10-global-growing_reg-resnet56 INFO: Epoch 48/100, Acc=0.8992, Val Loss=0.3428, lr=0.0100 +[07/22 02:37:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 49/100, Acc=0.8810, Val Loss=0.4253, lr=0.0100 +[07/22 02:38:00] cifar10-global-growing_reg-resnet56 INFO: Epoch 50/100, Acc=0.9068, Val Loss=0.3200, lr=0.0100 +[07/22 02:38:19] cifar10-global-growing_reg-resnet56 INFO: Epoch 51/100, Acc=0.9025, Val Loss=0.3292, lr=0.0100 +[07/22 02:38:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 52/100, Acc=0.8983, Val Loss=0.3616, lr=0.0100 +[07/22 02:38:57] cifar10-global-growing_reg-resnet56 INFO: Epoch 53/100, Acc=0.9063, Val Loss=0.3364, lr=0.0100 +[07/22 02:39:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 54/100, Acc=0.8970, Val Loss=0.3707, lr=0.0100 +[07/22 02:39:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 55/100, Acc=0.9065, Val Loss=0.3400, lr=0.0100 +[07/22 02:39:53] cifar10-global-growing_reg-resnet56 INFO: Epoch 56/100, Acc=0.9070, Val Loss=0.3331, lr=0.0100 +[07/22 02:40:11] cifar10-global-growing_reg-resnet56 INFO: Epoch 57/100, Acc=0.9005, Val Loss=0.3485, lr=0.0100 +[07/22 02:40:30] cifar10-global-growing_reg-resnet56 INFO: Epoch 58/100, Acc=0.8969, Val Loss=0.3844, lr=0.0100 +[07/22 02:40:48] cifar10-global-growing_reg-resnet56 INFO: Epoch 59/100, Acc=0.9020, Val Loss=0.3527, lr=0.0100 +[07/22 02:41:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 60/100, Acc=0.9281, Val Loss=0.2527, lr=0.0010 +[07/22 02:41:23] cifar10-global-growing_reg-resnet56 INFO: Epoch 61/100, Acc=0.9293, Val Loss=0.2529, lr=0.0010 +[07/22 02:41:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 62/100, Acc=0.9310, Val Loss=0.2531, lr=0.0010 +[07/22 02:41:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 63/100, Acc=0.9323, Val Loss=0.2484, lr=0.0010 +[07/22 02:42:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 64/100, Acc=0.9322, Val Loss=0.2507, lr=0.0010 +[07/22 02:42:34] cifar10-global-growing_reg-resnet56 INFO: Epoch 65/100, Acc=0.9306, Val Loss=0.2518, lr=0.0010 +[07/22 02:42:52] cifar10-global-growing_reg-resnet56 INFO: Epoch 66/100, Acc=0.9323, Val Loss=0.2516, lr=0.0010 +[07/22 02:43:09] cifar10-global-growing_reg-resnet56 INFO: Epoch 67/100, Acc=0.9328, Val Loss=0.2553, lr=0.0010 +[07/22 02:43:26] cifar10-global-growing_reg-resnet56 INFO: Epoch 68/100, Acc=0.9316, Val Loss=0.2551, lr=0.0010 +[07/22 02:43:44] cifar10-global-growing_reg-resnet56 INFO: Epoch 69/100, Acc=0.9326, Val Loss=0.2542, lr=0.0010 +[07/22 02:44:03] cifar10-global-growing_reg-resnet56 INFO: Epoch 70/100, Acc=0.9344, Val Loss=0.2537, lr=0.0010 +[07/22 02:44:21] cifar10-global-growing_reg-resnet56 INFO: Epoch 71/100, Acc=0.9331, Val Loss=0.2572, lr=0.0010 +[07/22 02:44:38] cifar10-global-growing_reg-resnet56 INFO: Epoch 72/100, Acc=0.9321, Val Loss=0.2611, lr=0.0010 +[07/22 02:44:56] cifar10-global-growing_reg-resnet56 INFO: Epoch 73/100, Acc=0.9330, Val Loss=0.2601, lr=0.0010 +[07/22 02:45:14] cifar10-global-growing_reg-resnet56 INFO: Epoch 74/100, Acc=0.9324, Val Loss=0.2597, lr=0.0010 +[07/22 02:45:31] cifar10-global-growing_reg-resnet56 INFO: Epoch 75/100, Acc=0.9322, Val Loss=0.2671, lr=0.0010 +[07/22 02:45:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 76/100, Acc=0.9326, Val Loss=0.2623, lr=0.0010 +[07/22 02:46:06] cifar10-global-growing_reg-resnet56 INFO: Epoch 77/100, Acc=0.9319, Val Loss=0.2644, lr=0.0010 +[07/22 02:46:23] cifar10-global-growing_reg-resnet56 INFO: Epoch 78/100, Acc=0.9328, Val Loss=0.2618, lr=0.0010 +[07/22 02:46:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 79/100, Acc=0.9333, Val Loss=0.2638, lr=0.0010 +[07/22 02:46:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 80/100, Acc=0.9337, Val Loss=0.2627, lr=0.0001 +[07/22 02:47:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 81/100, Acc=0.9341, Val Loss=0.2631, lr=0.0001 +[07/22 02:47:32] cifar10-global-growing_reg-resnet56 INFO: Epoch 82/100, Acc=0.9343, Val Loss=0.2613, lr=0.0001 +[07/22 02:47:50] cifar10-global-growing_reg-resnet56 INFO: Epoch 83/100, Acc=0.9343, Val Loss=0.2619, lr=0.0001 +[07/22 02:48:06] cifar10-global-growing_reg-resnet56 INFO: Epoch 84/100, Acc=0.9345, Val Loss=0.2622, lr=0.0001 +[07/22 02:48:24] cifar10-global-growing_reg-resnet56 INFO: Epoch 85/100, Acc=0.9332, Val Loss=0.2629, lr=0.0001 +[07/22 02:48:41] cifar10-global-growing_reg-resnet56 INFO: Epoch 86/100, Acc=0.9337, Val Loss=0.2618, lr=0.0001 +[07/22 02:48:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 87/100, Acc=0.9339, Val Loss=0.2632, lr=0.0001 +[07/22 02:49:15] cifar10-global-growing_reg-resnet56 INFO: Epoch 88/100, Acc=0.9327, Val Loss=0.2646, lr=0.0001 +[07/22 02:49:32] cifar10-global-growing_reg-resnet56 INFO: Epoch 89/100, Acc=0.9333, Val Loss=0.2599, lr=0.0001 +[07/22 02:49:49] cifar10-global-growing_reg-resnet56 INFO: Epoch 90/100, Acc=0.9346, Val Loss=0.2612, lr=0.0001 +[07/22 02:50:05] cifar10-global-growing_reg-resnet56 INFO: Epoch 91/100, Acc=0.9343, Val Loss=0.2594, lr=0.0001 +[07/22 02:50:23] cifar10-global-growing_reg-resnet56 INFO: Epoch 92/100, Acc=0.9335, Val Loss=0.2602, lr=0.0001 +[07/22 02:50:40] cifar10-global-growing_reg-resnet56 INFO: Epoch 93/100, Acc=0.9337, Val Loss=0.2629, lr=0.0001 +[07/22 02:50:58] cifar10-global-growing_reg-resnet56 INFO: Epoch 94/100, Acc=0.9333, Val Loss=0.2623, lr=0.0001 +[07/22 02:51:16] cifar10-global-growing_reg-resnet56 INFO: Epoch 95/100, Acc=0.9345, Val Loss=0.2592, lr=0.0001 +[07/22 02:51:33] cifar10-global-growing_reg-resnet56 INFO: Epoch 96/100, Acc=0.9335, Val Loss=0.2624, lr=0.0001 +[07/22 02:51:51] cifar10-global-growing_reg-resnet56 INFO: Epoch 97/100, Acc=0.9353, Val Loss=0.2602, lr=0.0001 +[07/22 02:52:08] cifar10-global-growing_reg-resnet56 INFO: Epoch 98/100, Acc=0.9336, Val Loss=0.2606, lr=0.0001 +[07/22 02:52:25] cifar10-global-growing_reg-resnet56 INFO: Epoch 99/100, Acc=0.9341, Val Loss=0.2620, lr=0.0001 +[07/22 02:52:25] cifar10-global-growing_reg-resnet56 INFO: Best Acc=0.9353 From 52180cd102bcf17c90dc5ffee57cdb743865ff36 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Sat, 22 Jul 2023 21:50:11 +0800 Subject: [PATCH 03/13] Clean up --- .../pruner/algorithms/growing_reg_pruner.py | 29 +++++++++---------- torch_pruning/pruner/importance.py | 5 ++++ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/torch_pruning/pruner/algorithms/growing_reg_pruner.py b/torch_pruning/pruner/algorithms/growing_reg_pruner.py index 355af9e..cc993af 100644 --- a/torch_pruning/pruner/algorithms/growing_reg_pruner.py +++ b/torch_pruning/pruner/algorithms/growing_reg_pruner.py @@ -3,10 +3,8 @@ import typing import torch import torch.nn as nn - -from ..importance import MagnitudeImportance, GroupNormImportance from .. import function -import math + class GrowingRegPruner(MetaPruner): def __init__( @@ -15,7 +13,7 @@ def __init__( example_inputs, importance, reg=1e-5, - delta_reg = 1e-5, + delta_reg=1e-5, iterative_steps=1, iterative_sparsity_scheduler: typing.Callable = linear_scheduler, ch_sparsity=0.5, @@ -46,11 +44,10 @@ def __init__( ) self.base_reg = reg self._groups = list(self.DG.get_all_groups()) - self.group_lasso = True group_reg = {} for group in self._groups: - group_reg[group] = torch.ones( len(group[0].idxs) ) * self.base_reg + group_reg[group] = torch.ones(len(group[0].idxs)) * self.base_reg self.group_reg = group_reg self.delta_reg = delta_reg @@ -58,9 +55,10 @@ def update_reg(self): for group in self._groups: group_l2norm_sq = self.estimate_importance(group) if group_l2norm_sq is None: - continue + continue reg = self.group_reg[group] - standarized_imp = (group_l2norm_sq.max() - group_l2norm_sq) / (group_l2norm_sq.max() - group_l2norm_sq.min() + 1e-8) + standarized_imp = (group_l2norm_sq.max() - group_l2norm_sq) / \ + (group_l2norm_sq.max() - group_l2norm_sq.min() + 1e-8) # => [0, 1] reg = reg + self.delta_reg * standarized_imp.to(reg.device) self.group_reg[group] = reg @@ -69,19 +67,18 @@ def regularize(self, model): group_l2norm_sq = self.estimate_importance(group) if group_l2norm_sq is None: continue - reg = self.group_reg[group] for dep, idxs in group: layer = dep.layer pruning_fn = dep.pruning_fn - if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine==True and layer not in self.ignored_layers: + if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine == True and layer not in self.ignored_layers: layer.weight.grad.data.add_(reg.to(layer.weight.device) * layer.weight.data) - elif isinstance(layer, (nn.modules.conv._ConvNd, nn.Linear)) and layer not in self.ignored_layers: - if pruning_fn in [function.prune_conv_out_channels, function.prune_linear_out_channels]: + elif isinstance(layer, (nn.modules.conv._ConvNd, nn.Linear)): + if pruning_fn in [function.prune_conv_out_channels, function.prune_linear_out_channels] and layer not in self.ignored_layers: w = layer.weight.data[idxs] - g = w * reg.to(layer.weight.device).view( -1, *([1]*(len(w.shape)-1)) ) #/ group_norm.view( -1, *([1]*(len(w.shape)-1)) ) * group_size #group_size #* scale.view( -1, *([1]*(len(w.shape)-1)) ) - layer.weight.grad.data[idxs]+= g + g = w * reg.to(layer.weight.device).view(-1, *([1]*(len(w.shape)-1))) + layer.weight.grad.data[idxs] += g elif pruning_fn in [function.prune_conv_in_channels, function.prune_linear_in_channels]: w = layer.weight.data[:, idxs] - g = w * reg.to(layer.weight.device).view( 1, -1, *([1]*(len(w.shape)-2)) ) #/ gn.view( 1, -1, *([1]*(len(w.shape)-2)) ) * group_size #* scale.view( 1, -1, *([1]*(len(w.shape)-2)) ) - layer.weight.grad.data[:, idxs]+=g \ No newline at end of file + g = w * reg.to(layer.weight.device).view(1, -1, *([1]*(len(w.shape)-2))) + layer.weight.grad.data[:, idxs] += g diff --git a/torch_pruning/pruner/importance.py b/torch_pruning/pruner/importance.py index 24efee6..83e816d 100644 --- a/torch_pruning/pruner/importance.py +++ b/torch_pruning/pruner/importance.py @@ -229,6 +229,8 @@ def __call__(self, group, **kwargs): class GroupNormImportance(MagnitudeImportance): + """ A magnitude-based importance in the group level. Only for reproducing the results in the paper. It may not be ready for practical use. + """ def __init__(self, p=2, normalizer='max'): super().__init__(p=p, group_reduction=None, normalizer=normalizer) self.p = p @@ -357,6 +359,9 @@ def __call__(self, group, ch_groups=1): class TaylorImportance(MagnitudeImportance): + """First-order taylor expansion of the loss function. + https://openaccess.thecvf.com/content_CVPR_2019/papers/Molchanov_Importance_Estimation_for_Neural_Network_Pruning_CVPR_2019_paper.pdf + """ def __init__(self, group_reduction="mean", normalizer='mean', multivariable=False): self.group_reduction = group_reduction self.normalizer = normalizer From e3d1c112857aaf8204ed26417878a88f962e810d Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Sun, 23 Jul 2023 23:28:02 +0800 Subject: [PATCH 04/13] Fixed a bug in GrowinReg --- benchmarks/main_imagenet.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/benchmarks/main_imagenet.py b/benchmarks/main_imagenet.py index 55d7a42..186d800 100644 --- a/benchmarks/main_imagenet.py +++ b/benchmarks/main_imagenet.py @@ -134,7 +134,7 @@ def get_pruner(model, example_inputs, args): pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning) elif args.method == "group_greg": sparsity_learning = True - imp = tp.importance.GroupNormImportance(p=2) + imp = tp.importance.MagnitudeImportance(p=2) pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning) elif args.method == "group_sl": sparsity_learning = True @@ -207,9 +207,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg if epoch < args.lr_warmup_epochs: # Reset ema buffer to keep copying weights during warmup period model_ema.n_averaged.fill_(0) - - if pruner is not None and isinstance(pruner, tp.pruner.GroupNormPruner): - pruner.update_reg() acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] @@ -217,7 +214,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) - + + if pruner is not None and isinstance(pruner, tp.pruner.GrowingRegPruner): + pruner.update_reg() def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): model.eval() From 9ca2d951cb0a601a53ad9ba6a1a52b3dca6034e3 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Mon, 24 Jul 2023 20:30:16 +0800 Subject: [PATCH 05/13] Fixed a bug in torch.cat([x, x, y]) --- tests/test_concat.py | 4 ++-- tests/test_split.py | 35 ++++++++++++++++++----------------- torch_pruning/dependency.py | 36 +++++++++++++++++++----------------- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/tests/test_concat.py b/tests/test_concat.py index 3104052..a43a609 100644 --- a/tests/test_concat.py +++ b/tests/test_concat.py @@ -24,14 +24,14 @@ def __init__(self, in_dim): nn.BatchNorm2d(in_dim//2) ) self.block2 = nn.Sequential( - nn.Conv2d(in_dim + in_dim//2, in_dim, 1), + nn.Conv2d(in_dim * 2 + in_dim//2, in_dim, 1), nn.BatchNorm2d(in_dim) ) def forward(self, x): x = self.block1(x) x2 = self.parallel_path(x) - x = torch.cat([x, x2], dim=1) + x = torch.cat([x, x, x2], dim=1) x = self.block2(x) return x diff --git a/tests/test_split.py b/tests/test_split.py index 4aae116..9aff270 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -9,13 +9,13 @@ class Net(nn.Module): def __init__(self, in_dim): super().__init__() - + self.block1 = nn.Sequential( nn.Conv2d(in_dim, in_dim, 1), nn.BatchNorm2d(in_dim), nn.GELU(), - nn.Conv2d(in_dim, in_dim*3, 1), - nn.BatchNorm2d(in_dim*3) + nn.Conv2d(in_dim, in_dim*4, 1), + nn.BatchNorm2d(in_dim*4) ) self.block2_1 = nn.Sequential( @@ -27,23 +27,24 @@ def __init__(self, in_dim): nn.Conv2d(2*in_dim, in_dim, 1), nn.BatchNorm2d(in_dim) ) - + def forward(self, x): x = self.block1(x) num_ch = x.shape[1] - + c1, c2 = self.block2_1[0].in_channels, self.block2_2[0].in_channels - x1, x3 = torch.split(x, [c1, c2], dim=1) + x1, x2, x3 = torch.split(x, [c1, c1, c2], dim=1) x1 = self.block2_1(x1) - #x2 = self.block2_1(x2) + x2 = self.block2_1(x2) x3 = self.block2_2(x3) - return x1, x3 - + return x1, x2, x3 + def test_pruner(): - model = Net(10) + dim = 128 + model = Net(dim) print(model) # Global metrics - example_inputs = torch.randn(1, 10, 7, 7) + example_inputs = torch.randn(1, dim, 7, 7) imp = tp.importance.RandomImportance() ignored_layers = [] @@ -66,19 +67,19 @@ def test_pruner(): base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs) for i in range(iterative_steps): for g in pruner.step(interactive=True): - print(g.details()) + #print(g.details()) g.prune() print(model) macs, nparams = tp.utils.count_ops_and_params(model, example_inputs) - + print([o.shape for o in model(example_inputs)]) print( - " Iter %d/%d, Params: %.2f => %.2f" - % (i+1, iterative_steps, base_nparams, nparams) + " Iter %d/%d, Params: %.2f M => %.2f M" + % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6) ) print( - " Iter %d/%d, MACs: %.2f => %.2f" - % (i+1, iterative_steps, base_macs, macs ) + " Iter %d/%d, MACs: %.2f G => %.2f G" + % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9) ) # finetune your model here # finetune(model) diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index b4fec44..0af41ce 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -238,7 +238,13 @@ def __len__(self): def add_and_merge(self, dep, idxs): for i, (_dep, _idxs) in enumerate(self._group): if _dep.target == dep.target and _dep.handler == dep.handler: - self._group[i] = GroupItem(dep=_dep, idxs=list(set(_idxs + idxs))) + visited_idxs = set() + merged_idxs = [] + for index in _idxs + idxs: + if index.idx not in visited_idxs: + merged_idxs.append(index) + visited_idxs.add(index.idx) + self._group[i] = GroupItem(dep=_dep, idxs=merged_idxs) return self.add_dep(dep, idxs) @@ -259,9 +265,9 @@ def details(self): fmt += "\n" + "-" * 32 + "\n" for i, (dep, idxs) in enumerate(self._group): if i==0: - fmt += "[{}] {}, idxs={} (Pruning Root)\n".format(i, dep, idxs) + fmt += "[{}] {}, idxs ({}) ={} (Pruning Root)\n".format(i, dep, len(idxs), idxs) else: - fmt += "[{}] {}, idxs={}\n".format(i, dep, idxs) + fmt += "[{}] {}, idxs ({}) ={} \n".format(i, dep, len(idxs), idxs) fmt += "-" * 32 + "\n" return fmt @@ -462,7 +468,6 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): node, fn = dep.target, dep.handler visited_node.add(node) #print(dep) - #print(node.dependencies) for new_dep in node.dependencies: if new_dep.is_triggered_by(fn): new_indices = idxs @@ -485,7 +490,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): ) _fix_dependency_graph_non_recursive(*group[0]) - + # merge pruning ops merged_group = Group() for dep, idxs in group.items: @@ -724,9 +729,11 @@ def _record_grad_fn(module, inputs, outputs): if output_transform is not None: out = output_transform(out) module2node = {} # create a mapping from nn.Module to tp.dependency.Node + + visited = set() for o in utils.flatten_as_list(out): self._trace_computational_graph( - module2node, o.grad_fn, gradfn2module, reused) + module2node, o.grad_fn, gradfn2module, reused, visited=visited) # TODO: Improving ViT pruning # This is a corner case for pruning ViT, @@ -749,13 +756,11 @@ def _record_grad_fn(module, inputs, outputs): stack.append(ni) return module2node - def _trace_computational_graph(self, module2node, grad_fn_root, gradfn2module, reused): + def _trace_computational_graph(self, module2node, grad_fn_root, gradfn2module, reused, visited=set()): def create_node_if_not_exists(grad_fn): module = gradfn2module.get(grad_fn, None) - if module is not None \ - and module in module2node \ - and module not in reused: + if module is not None and module in module2node and module not in reused: return module2node[module] # 1. link grad_fns and modules @@ -804,13 +809,10 @@ def create_node_if_not_exists(grad_fn): # non-recursive construction of computational graph processing_stack = [grad_fn_root] - visited = set() - visited_as_output_node = set() while len(processing_stack) > 0: grad_fn = processing_stack.pop(-1) if grad_fn in visited: continue - node = create_node_if_not_exists(grad_fn=grad_fn) if hasattr(grad_fn, "next_functions"): for f in grad_fn.next_functions: @@ -844,12 +846,11 @@ def create_node_if_not_exists(grad_fn): # node.add_input(input_node, allow_dumplicated=allow_dumplicated) # input_node.add_output(node, allow_dumplicated=allow_dumplicated) #else: - node.add_input(input_node, allow_dumplicated=False) - input_node.add_output(node, allow_dumplicated=False) - + node.add_input(input_node, allow_dumplicated=True) + input_node.add_output(node, allow_dumplicated=True) processing_stack.append(f[0]) visited.add(grad_fn) - visited_as_output_node.add(node) + for (param, dim) in self.unwrapped_parameters: module2node[param].pruning_dim = dim @@ -1000,6 +1001,7 @@ def _update_concat_index_mapping(self, cat_node: Node): for n in cat_node.inputs: chs.append(self.infer_channels_between(n, cat_node)) cat_node.module.concat_sizes = chs + offsets = [0] for ch in chs: From cfe06ff59f52d0c2ded20b45be594486abcc5b1c Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Mon, 24 Jul 2023 21:48:35 +0800 Subject: [PATCH 06/13] update readme --- examples/yolov7/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/yolov7/readme.md b/examples/yolov7/readme.md index f62428c..660b91f 100644 --- a/examples/yolov7/readme.md +++ b/examples/yolov7/readme.md @@ -27,7 +27,7 @@ python yolov7_detect_pruned.py --weights yolov7.pt --conf 0.25 --img-size 640 -- # Training with pruned yolov7 (The training part is not validated) # Please download the pretrained yolov7_training.pt from https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt. -python yolov7_train_pruned.py --workers 8 --device 0 --batch-size 1 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'yolov7_training.pt' --name yolov7 --hyp data/hyp.scratch.p5.yaml +python yolov7_train_pruned.py --workers 8 --device 0 --batch-size 1 --data data/coco.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights 'yolov7.pt' --name yolov7 --hyp data/hyp.scratch.p5.yaml ``` #### Screenshot for yolov7_train_pruned.py: From 212d9c712ddea28bc28050e9898079ec2d00c0d9 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 18:35:21 +0800 Subject: [PATCH 07/13] support .get_channel_groups for auto grouping --- torch_pruning/dependency.py | 26 +++++-------------- torch_pruning/ops.py | 4 ++- torch_pruning/pruner/algorithms/metapruner.py | 14 +++++----- torch_pruning/pruner/function.py | 9 +++++++ 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index 0af41ce..b6c9071 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -467,7 +467,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): dep, idxs = processing_stack.pop(-1) node, fn = dep.target, dep.handler visited_node.add(node) - #print(dep) + for new_dep in node.dependencies: if new_dep.is_triggered_by(fn): new_indices = idxs @@ -475,8 +475,6 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args): if mapping is not None: new_indices = mapping(new_indices) - #print(len(new_indices)) - #print() if len(new_indices) == 0: continue if (new_dep.target in visited_node) and group.has_pruning_op( @@ -509,12 +507,13 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, o ignored_layers = ignored_layers+self.IGNORED_LAYERS for m in list(self.module2node.keys()): + if m in ignored_layers: continue if not isinstance(m, tuple(root_module_types)): continue - + pruner = self.get_pruner_of_module(m) if pruner is None or pruner.get_out_channels(m) is None: continue @@ -783,6 +782,9 @@ def create_node_if_not_exists(grad_fn): module = ops._SplitOp(self._op_id) self._op_id+=1 elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower(): + #if 'reshape' in grad_fn.name().lower(): + #print(grad_fn.__dir__()) + #print(grad_fn._saved_self_sizes) module = ops._ReshapeOp(self._op_id) self._op_id+=1 else: @@ -830,22 +832,6 @@ def create_node_if_not_exists(grad_fn): if not is_unwrapped_param: continue input_node = create_node_if_not_exists(f[0]) - - #allow_dumplicated = False - - # TODO: support duplicated concat/split like torch.cat([x, x], dim=1) - # The following implementation is can achieve this but will introduce some bugs. - # will be fixed in the future version - #if node.type == ops.OPTYPE.CONCAT: - # allow_dumplicated = (node not in visited_as_output_node) - # node.add_input(input_node, allow_dumplicated=allow_dumplicated) - # input_node.add_output(node, allow_dumplicated=allow_dumplicated) - # print(node, node.inputs) - #elif input_node.type == ops.OPTYPE.SPLIT: - # allow_dumplicated = (node not in visited_as_output_node) - # node.add_input(input_node, allow_dumplicated=allow_dumplicated) - # input_node.add_output(node, allow_dumplicated=allow_dumplicated) - #else: node.add_input(input_node, allow_dumplicated=True) input_node.add_output(node, allow_dumplicated=True) processing_stack.append(f[0]) diff --git a/torch_pruning/ops.py b/torch_pruning/ops.py index acba6b4..13fca89 100644 --- a/torch_pruning/ops.py +++ b/torch_pruning/ops.py @@ -1,7 +1,6 @@ import torch.nn as nn from enum import IntEnum - class DummyMHA(nn.Module): def __init__(self): super(DummyMHA, self).__init__() @@ -70,6 +69,9 @@ def get_out_channels(self, layer): def get_in_channels(self, layer): return None + def get_channel_groups(self, layer): + return 1 + class ConcatPruner(DummyPruner): def prune_out_channels(self, layer, idxs): diff --git a/torch_pruning/pruner/algorithms/metapruner.py b/torch_pruning/pruner/algorithms/metapruner.py index c225b19..0334411 100644 --- a/torch_pruning/pruner/algorithms/metapruner.py +++ b/torch_pruning/pruner/algorithms/metapruner.py @@ -118,13 +118,13 @@ def __init__( # detect group convs & group norms for m in self.model.modules(): - if isinstance(m, ops.TORCH_CONV) \ - and m.groups > 1 \ - and m.groups != m.out_channels: - self.channel_groups[m] = m.groups - if isinstance(m, ops.TORCH_GROUPNORM): - self.channel_groups[m] = m.num_groups - + layer_pruner = self.DG.get_pruner_of_module(m) + channel_groups = layer_pruner.get_channel_groups(m) + if channel_groups > 1: + if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels: + continue + self.channel_groups[m] = channel_groups + # count the number of total channels at initialization if self.global_pruning: initial_total_channels = 0 diff --git a/torch_pruning/pruner/function.py b/torch_pruning/pruner/function.py index c92d132..84dbca6 100644 --- a/torch_pruning/pruner/function.py +++ b/torch_pruning/pruner/function.py @@ -81,6 +81,9 @@ def __call__(self, layer: nn.Module, idxs: Sequence[int], to_output: bool = True layer = pruning_fn(layer, idxs) return layer + def get_channel_groups(self, layer): + return 1 + def _prune_parameter_and_grad(self, weight, keep_idxs, pruning_dim): pruned_weight = torch.nn.Parameter(torch.index_select(weight, pruning_dim, torch.LongTensor(keep_idxs).to(weight.device))) if weight.grad is not None: @@ -123,6 +126,9 @@ def get_out_channels(self, layer): def get_in_channels(self, layer): return layer.in_channels + def get_channel_groups(self, layer): + return layer.groups + class DepthwiseConvPruner(ConvPruner): TARGET_MODULE = ops.TORCH_CONV @@ -249,6 +255,9 @@ def get_out_channels(self, layer): def get_in_channels(self, layer): return layer.num_channels + def get_channel_groups(self, layer): + return layer.num_groups + class InstanceNormPruner(BasePruningFunc): def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module: keep_idxs = list(set(range(layer.num_features)) - set(idxs)) From a07659e8bd02c06b04bc073c2c3990d2d33c4300 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 18:35:57 +0800 Subject: [PATCH 08/13] An example for Timm ViT --- examples/timm_models/timm_vit.py | 150 +++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 examples/timm_models/timm_vit.py diff --git a/examples/timm_models/timm_vit.py b/examples/timm_models/timm_vit.py new file mode 100644 index 0000000..fad2db5 --- /dev/null +++ b/examples/timm_models/timm_vit.py @@ -0,0 +1,150 @@ +import os, sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))) + +import torch +import torch.nn as nn +import timm +import torch_pruning as tp +from typing import Sequence + +from timm.models.vision_transformer import Attention +import torch.nn.functional as F + +def timm_attention_forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + #x = x.transpose(1, 2).reshape(B, N, C) # this line forces the input and output channels to be identical. + x = x.transpose(0, 1).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class TimmAttentionPruner(tp.function.BasePruningFunc): + """ The implementation of timm Attention requires identical input and output channels. + So in this case, we prune all input channels and output channels at the same time. + """ + def prune_in_channels(self, layer: nn.Module, idxs: Sequence[int]): + tp.prune_linear_in_channels(layer.qkv, idxs) + return layer + + def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]): + tp.prune_linear_out_channels(layer.proj, idxs) + return layer + + def get_out_channels(self, layer: nn.Module): + return layer.proj.out_features + + def get_in_channels(self, layer: nn.Module): + return layer.qkv.in_features + + def get_channel_groups(self, layer): + return 1 + +# timm==0.9.2 +# torch==1.12.1 + +timm_models = timm.list_models() +print(timm_models) +example_inputs = torch.randn(1,3,224,224) +imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") +prunable_list = [] +unprunable_list = [] +problem_with_input_shape = [] + +timm_atten_pruner = TimmAttentionPruner() + + +from transformers import ViTImageProcessor, ViTForImageClassification +from PIL import Image +import requests + +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) + +processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') +print(model) +inputs = processor(images=image, return_tensors="pt") +outputs = model(**inputs) +logits = outputs.logits +# model predicts one of the 1000 ImageNet classes +predicted_class_idx = logits.argmax(-1).item() +print("Predicted class:", model.config.id2label[predicted_class_idx]) + + + +for i, model_name in enumerate(timm_models): + if not model_name=='vit_base_patch8_224': + continue + + print("Pruning %s..."%model_name) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + #if 'rexnet' in model_name or 'sequencer' in model_name or 'botnet' in model_name: # pruning process stuck with that architectures - skip them. + # unprunable_list.append(model_name) + # continue + try: + model = timm.create_model(model_name, pretrained=False, no_jit=True).eval().to(device) + except: # out of memory error + model = timm.create_model(model_name, pretrained=False, no_jit=True).eval() + device = 'cpu' + ch_groups = {} + for m in model.modules(): + if isinstance(m, timm.models.vision_transformer.Attention): + m.forward = timm_attention_forward.__get__(m, Attention) # https://stackoverflow.com/questions/50599045/python-replacing-a-function-within-a-class-of-a-module + + input_size = model.default_cfg['input_size'] + example_inputs = torch.randn(1, *input_size).to(device) + test_output = model(example_inputs) + + print(model) + prunable = True + #try: + if True: + base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) + pruner = tp.pruner.MagnitudePruner( + model, + example_inputs, + global_pruning=True, # If False, a uniform sparsity will be assigned to different layers. + importance=imp, # importance criterion for parameter selection + iterative_steps=1, # the number of iterations to achieve target sparsity + ch_sparsity=0.5, + ignored_layers=[], + channel_groups=ch_groups, + customized_pruners={Attention: timm_atten_pruner}, + root_module_types=(Attention, nn.Linear, nn.Conv2d), + ) + for g in pruner.step(interactive=True): + #print(g) + g.prune() + print(model) + test_output = model(example_inputs) + pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) + print("Base MACs: %d, Pruned MACs: %d"%(base_macs, pruned_macs)) + print("Base Params: %d, Pruned Params: %d"%(base_params, pruned_params)) + #except Exception as e: + # prunable = False + + + + if prunable: + prunable_list.append(model_name) + else: + unprunable_list.append(model_name) + + print("Prunable: %d models, \n %s\n"%(len(prunable_list), prunable_list)) + print("Unprunable: %d models, \n %s\n"%(len(unprunable_list), unprunable_list)) \ No newline at end of file From 71028c146a10b2757f0d7e5e60d0faf7a19f5445 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 18:46:06 +0800 Subject: [PATCH 09/13] An example for Timm ViT --- examples/timm_models/timm_vit.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/examples/timm_models/timm_vit.py b/examples/timm_models/timm_vit.py index fad2db5..569489d 100644 --- a/examples/timm_models/timm_vit.py +++ b/examples/timm_models/timm_vit.py @@ -68,26 +68,6 @@ def get_channel_groups(self, layer): timm_atten_pruner = TimmAttentionPruner() - -from transformers import ViTImageProcessor, ViTForImageClassification -from PIL import Image -import requests - -url = 'http://images.cocodataset.org/val2017/000000039769.jpg' -image = Image.open(requests.get(url, stream=True).raw) - -processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') -model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') -print(model) -inputs = processor(images=image, return_tensors="pt") -outputs = model(**inputs) -logits = outputs.logits -# model predicts one of the 1000 ImageNet classes -predicted_class_idx = logits.argmax(-1).item() -print("Predicted class:", model.config.id2label[predicted_class_idx]) - - - for i, model_name in enumerate(timm_models): if not model_name=='vit_base_patch8_224': continue From a8027de1abcfc3a16190bf7f4d23291d4144a7d0 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 19:28:55 +0800 Subject: [PATCH 10/13] An example for Timm ViT --- examples/hf_transformers/prune_vit.py | 57 +++++++++++++++++++++++++++ examples/hf_transformers/readme.md | 6 +++ torch_pruning/dependency.py | 5 +-- 3 files changed, 64 insertions(+), 4 deletions(-) create mode 100644 examples/hf_transformers/prune_vit.py create mode 100644 examples/hf_transformers/readme.md diff --git a/examples/hf_transformers/prune_vit.py b/examples/hf_transformers/prune_vit.py new file mode 100644 index 0000000..3852d58 --- /dev/null +++ b/examples/hf_transformers/prune_vit.py @@ -0,0 +1,57 @@ +from transformers import ViTImageProcessor, ViTForImageClassification +from transformers.models.vit.modeling_vit import ViTSelfAttention +import torch_pruning as tp +from PIL import Image +import requests + +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) + +processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') +example_inputs = processor(images=image, return_tensors="pt")["pixel_values"] +#outputs = model(example_inputs) +#logits = outputs.logits +# model predicts one of the 1000 ImageNet classes +#predicted_class_idx = logits.argmax(-1).item() +#print("Predicted class:", model.config.id2label[predicted_class_idx]) + +print(model) +imp = tp.importance.MagnitudeImportance(p=2, group_reduction="mean") +base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs) +channel_groups = {} + +# All heads should be pruned simultaneously, so we group channels by head. +for m in model.modules(): + if isinstance(m, ViTSelfAttention): + channel_groups[m.query] = m.num_attention_heads + channel_groups[m.key] = m.num_attention_heads + channel_groups[m.value] = m.num_attention_heads + +pruner = tp.pruner.MagnitudePruner( + model, + example_inputs, + global_pruning=False, # If False, a uniform sparsity will be assigned to different layers. + importance=imp, # importance criterion for parameter selection + iterative_steps=1, # the number of iterations to achieve target sparsity + ch_sparsity=0.5, + channel_groups=channel_groups, + output_transform=lambda out: out.logits.sum(), + ignored_layers=[model.classifier], + ) + +for g in pruner.step(interactive=True): + #print(g) + g.prune() + +# Modify the attention head size and all head size aftering pruning +for m in model.modules(): + if isinstance(m, ViTSelfAttention): + m.attention_head_size = m.query.out_features // m.num_attention_heads + m.all_head_size = m.query.out_features + +print(model) +test_output = model(example_inputs) +pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs) +print("Base MACs: %d G, Pruned MACs: %d G"%(base_macs/1e9, pruned_macs/1e9)) +print("Base Params: %d M, Pruned Params: %d M"%(base_params/1e6, pruned_params/1e6)) \ No newline at end of file diff --git a/examples/hf_transformers/readme.md b/examples/hf_transformers/readme.md new file mode 100644 index 0000000..e13b33d --- /dev/null +++ b/examples/hf_transformers/readme.md @@ -0,0 +1,6 @@ +# Example for HuggingFace ViT + +## Pruning +```bash +python prune_vit.py +``` \ No newline at end of file diff --git a/torch_pruning/dependency.py b/torch_pruning/dependency.py index b6c9071..dfed889 100644 --- a/torch_pruning/dependency.py +++ b/torch_pruning/dependency.py @@ -634,7 +634,7 @@ def _detect_unwrapped_parameters(self, unwrapped_parameters): unwrapped_parameters = [] unwrapped_detected = list( set(unwrapped_detected) - set([p for (p, _) in unwrapped_parameters]) ) if len(unwrapped_detected)>0 and self.verbose: - warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of a parameter. If you wish to customize this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected]) + warning_str = "Unwrapped parameters detected: {}.\n Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.".format([_param_to_name[p] for p in unwrapped_detected]) warnings.warn(warning_str) # set default pruning dim for unwrapped parameters @@ -782,9 +782,6 @@ def create_node_if_not_exists(grad_fn): module = ops._SplitOp(self._op_id) self._op_id+=1 elif "view" in grad_fn.name().lower() or 'reshape' in grad_fn.name().lower(): - #if 'reshape' in grad_fn.name().lower(): - #print(grad_fn.__dir__()) - #print(grad_fn._saved_self_sizes) module = ops._ReshapeOp(self._op_id) self._op_id+=1 else: From ca8f847b816d589ebb19b3a3175679efaf9b0266 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 19:29:20 +0800 Subject: [PATCH 11/13] An example for Huggingface ViT --- examples/hf_transformers/readme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hf_transformers/readme.md b/examples/hf_transformers/readme.md index e13b33d..4179999 100644 --- a/examples/hf_transformers/readme.md +++ b/examples/hf_transformers/readme.md @@ -3,4 +3,4 @@ ## Pruning ```bash python prune_vit.py -``` \ No newline at end of file +``` From 4bfd8d2679686f90851bd42c659ff17d58aee209 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 19:35:17 +0800 Subject: [PATCH 12/13] v1.2.1 --- README.md | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0b2b6b3..4547951 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ Tested PyTorch Versions License Downloads - Latest Version + Latest Version Open In Colab @@ -21,7 +21,7 @@ Torch-Pruning (TP) is a library for structural pruning with the following features: -* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [ViT](examples/torchvision_models/), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters physically. Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick start. +* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks, including *[Large Language Models (LLMs)](https://github.com/horseee/LLM-Pruner), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [ViT](examples/hf_transformers/), FasterRCNN, SSD, ResNe(X)t, ConvNext, DenseNet, ConvNext, RegNet, DeepLab, etc*. Different from [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) that zeroizes parameters through masking, Torch-Pruning deploys a (non-deep) graph algorithm called **DepGraph** to remove parameters physically. Currently, TP is able to prune approximately **81/85=95.3%** of the models from Torchvision 0.13.1. Try this [Colab Demo](https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing) for a quick start. * **[Performance Benchmark](benchmarks)**: Reproduce the our results in the DepGraph paper. * **[Tutorials and Documents](https://github.com/VainF/Torch-Pruning/wiki)** are available at the GitHub Wiki. diff --git a/setup.py b/setup.py index 35430d5..d8224e3 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="torch-pruning", - version="v1.2.0", + version="v1.2.1", author="Gongfan Fang", author_email="gongfan@u.nus.edu", description="Towards Any Structural Pruning", From 407c77e69493f77ff30a78bc033e050f4f51be24 Mon Sep 17 00:00:00 2001 From: VainF <2218880241@qq.com> Date: Wed, 26 Jul 2023 19:36:33 +0800 Subject: [PATCH 13/13] Support GrowingReg --- benchmarks/main_imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/main_imagenet.py b/benchmarks/main_imagenet.py index 186d800..c88172c 100644 --- a/benchmarks/main_imagenet.py +++ b/benchmarks/main_imagenet.py @@ -134,7 +134,7 @@ def get_pruner(model, example_inputs, args): pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=args.global_pruning) elif args.method == "group_greg": sparsity_learning = True - imp = tp.importance.MagnitudeImportance(p=2) + imp = tp.importance.GroupNormImportance(p=2) pruner_entry = partial(tp.pruner.GrowingRegPruner, reg=args.reg, delta_reg=args.delta_reg, global_pruning=args.global_pruning) elif args.method == "group_sl": sparsity_learning = True