From c251681452e82d0bda08b2d9d32f2a762572ae1d Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Sun, 2 Apr 2023 20:52:42 +0800 Subject: [PATCH] update support for arbitrary sparse --- apex/contrib/sparsity/asp.py | 24 +- apex/contrib/sparsity/sparse_masklib.py | 300 +++++++++++++++++++++++- 2 files changed, 307 insertions(+), 17 deletions(-) diff --git a/apex/contrib/sparsity/asp.py b/apex/contrib/sparsity/asp.py index 42de945c5..1ffae8693 100644 --- a/apex/contrib/sparsity/asp.py +++ b/apex/contrib/sparsity/asp.py @@ -31,13 +31,14 @@ class ASP: __optimizer = None __sparse_parameters = [] __calculate_mask = None + __layer_mute_count = 0 __allow_permutation = True __all_parameters = [] __save_permutation_graph = False __permutation_output_dir = '' @classmethod - def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", + def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", density=0.5, verbosity=3, whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention], allowed_layer_names=None, disallowed_layer_names=[], @@ -75,6 +76,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", allow_recompute_mask If True, stores pruned values so that dense weights can be restored. Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']} + density set weight's density, if mask_calculator is unstructured pattern allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning. [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe. @@ -86,7 +88,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", if isinstance(mask_calculator, str): def create_mask_from_pattern(param): - return create_mask(param, mask_calculator).bool() + return create_mask(param, mask_calculator, density).bool() cls.__calculate_mask = create_mask_from_pattern else: cls.__calculate_mask = mask_calculator #user defined function @@ -210,7 +212,7 @@ def __step(opt_self, *args, **kwargs): cls.__optimizer.step = types.MethodType(__step, cls.__optimizer) @classmethod - def compute_sparse_masks(cls): + def compute_sparse_masks(cls, layer_wise_ratio): """Call this method to enable sparsity. If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None. """ @@ -237,12 +239,22 @@ def compute_sparse_masks(cls): print("[compute_sparse_masks] Take {:.4f} seconds to find and apply permutations.".format(duration_build_offline_permutation_graph)) + layer_wise_mute_count= int(layer_wise_ratio * len(cls.__sparse_parameters)) + _count = 0 for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: if mask.sum() < mask.numel(): # when recalculating masks # restore dense parameter if allow_recompute_mask is enabled assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False" p.add_(pruned.cuda()) + if _count < layer_wise_mute_count: + def create_mask_from_pattern(param): + return create_mask(param, "unstructured", 0.5).bool() + calculate_mask = create_mask_from_pattern + mask.set_(calculate_mask(p)) + else: + mask.set_(cls.__calculate_mask(p)) + _count += 1 mask.set_(cls.__calculate_mask(p)) if pruned is not None: # stow away pruned weights to cpu @@ -289,11 +301,11 @@ def is_sparsity_enabled(cls): return True @classmethod - def prune_trained_model(cls, model, optimizer): + def prune_trained_model(cls, model, optimizer, mask_calculator="m4n2_1d", sparsity=0.5, layer_wise_ratio=0.25): # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks) - cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False) + cls.init_model_for_pruning(model, mask_calculator, density=(1 - sparsity), verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False) cls.init_optimizer_for_pruning(optimizer) - cls.compute_sparse_masks() + cls.compute_sparse_masks(layer_wise_ratio) @classmethod def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'): diff --git a/apex/contrib/sparsity/sparse_masklib.py b/apex/contrib/sparsity/sparse_masklib.py index fce77cce5..5601c5a51 100644 --- a/apex/contrib/sparsity/sparse_masklib.py +++ b/apex/contrib/sparsity/sparse_masklib.py @@ -1,8 +1,9 @@ +from math import floor import sys import torch import numpy as np import collections -from itertools import permutations +from itertools import permutations, combinations """ compute density (helper fn to compute % NNZs in a tensor) """ @@ -22,33 +23,127 @@ def reshape_1d(matrix, m): """ return all possible m:n patterns in a 1d vector """ valid_m4n2_1d_patterns = None +valid_m16n8_1d_patterns = None +valid_m32n8_1d_patterns = None +valid_m32n16_1d_patterns = None def compute_valid_1d_patterns(m,n): # Early exit if patterns was already created. global valid_m4n2_1d_patterns - + global valid_m16n8_1d_patterns + global valid_m32n8_1d_patterns + global valid_m32n16_1d_patterns if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns - patterns = torch.zeros(m) - patterns[:n] = 1 - valid_patterns = torch.tensor(list(set(permutations(patterns.tolist())))) - if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns + if m==16 and n==8 and valid_m16n8_1d_patterns is not None: return valid_m16n8_1d_patterns + if m==32 and n==8 and valid_m32n8_1d_patterns is not None: return valid_m32n8_1d_patterns + if m==32 and n==16 and valid_m32n16_1d_patterns is not None: return valid_m32n16_1d_patterns + valid_patterns = [] + for i in list(combinations(range(0, m), n)): + cur_pattern = np.zeros(m, dtype=np.int32) + cur_pattern[list(i)] = 1 + valid_patterns.append(cur_pattern) + valid_patterns = torch.Tensor(np.array(valid_patterns)) + # patterns = torch.zeros(m) + # patterns[:n] = 1 + # valid_patterns = torch.Tensor(list(set(permutations(patterns.tolist())))) + if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns + if m == 16 and n == 8: valid_m16n8_1d_patterns = valid_patterns + if m == 32 and n == 8: valid_m32n8_1d_patterns = valid_patterns + if m == 32 and n == 16: valid_m32n16_1d_patterns = valid_patterns return valid_patterns """ m:n 1d structured best """ def mn_1d_best(matrix, m, n): # Find all possible patterns. patterns = compute_valid_1d_patterns(m,n).cuda() - # Find the best m:n pattern (sum of non-masked weights). mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m) mat,shape = reshape_1d(matrix,m) - pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1) - mask[:] = patterns[pmax[:]] + + _dynamic, m = mat.shape + factor = 256 + for start in range(0, _dynamic, factor): + pmax = torch.argmax(torch.matmul(mat[start : start + factor].abs(), patterns.t()), dim=1) + mask[start: start + factor] = patterns[pmax[:]] mask = mask.view(matrix.shape) + return mask +""" m:n 1d structured pruning: greedy method to select mask """ +def mn_1d_greedy(matrix, m, n): + mat, shape = reshape_1d(matrix,m) + mask = torch.cuda.IntTensor(matrix.shape).fill_(0).view(-1,m) + + values, indices = torch.abs(mat).topk(n, dim=1, largest=True) + indexes = torch.arange(0, indices.shape[0], step=1, dtype=torch.long).view(-1, 1) + + mask[indexes, indices] = 1 + mask = mask.view(matrix.shape) + + return mask.cuda() + + +def m32n3_1d_best(mat, density): + return mn_1d_best(mat, 32, 3) + +def m32n4_1d_best(mat, density): + return mn_1d_best(mat, 32, 4) + +def m32n8_1d_best(mat, density): + return mn_1d_best(mat, 32, 8) + +def m32n16_1d_best(mat, density): + return mn_1d_best(mat, 32, 16) + +def m32n4_1d_greedy(mat, density): + return mn_1d_greedy(mat, 32, 4) + +def m32n16_1d_greedy(mat, density): + return mn_1d_greedy(mat, 32, 16) + +def m32n24_1d_best(mat, density): + return mn_1d_best(mat, 32, 24) + +def m16n8_1d_best(mat, density): + return mn_1d_best(mat, 16, 8) + +def m16n4_1d_best(mat, density): + return mn_1d_best(mat, 16, 4) + +def m8n4_1d_best(mat, density): + return mn_1d_best(mat, 8, 4) + def m4n2_1d(mat, density): return mn_1d_best(mat, 4, 2) +def m4n2_1d_greedy(mat, density): + return mn_1d_greedy(mat, 4, 2) + +def unstructured(mat, density): + mat_1d = mat.flatten() + (m,) = mat_1d.size() + n = int(m * density) + + mask = torch.cuda.IntTensor(mat_1d.shape).fill_(0) + + values, indices = torch.abs(mat_1d).topk(n, dim=0, largest=True) + + mask[indices] = 1; + mask = mask.view(mat.shape) + return mask + +def unstructured_element_wise(mat, density): + mat_1d = mat.flatten() + (m,) = mat_1d.size() + n = int(m * density) + + mask = torch.cuda.IntTensor(mat_1d.shape).fill_(0) + + values, indices = torch.abs(mat_1d).topk(n, dim=0, largest=True) + + mask[indices] = 1 + mask = mask.view(mat.shape) + return mask + """ Below 2d-masking related code is targeted more for training (from scratch). 2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop @@ -109,10 +204,10 @@ def compute_valid_2d_patterns(m,n): patterns[:n] = 1 patterns = list(set(permutations(patterns.tolist()))) patterns = patterns + patterns - patterns = torch.empty(list(set(permutations(patterns,m)))) + patterns = torch.Tensor(list(set(permutations(patterns,m)))) valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1) - valid_patterns = torch.empty(valid.shape[0],m,m) + valid_patterns = torch.Tensor(valid.shape[0],m,m) valid_patterns[:] = patterns[valid[:]] if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns @@ -141,6 +236,189 @@ def m4n2_2d_best(mat, density): return mn_2d_best(mat, 4, 2) +def tuple_of_tensors_to_tensor(tuple_of_tensors): + return torch.stack(list(tuple_of_tensors), dim=0) + +def tensor_block_partition(matrix, m, n): + if matrix.shape[0] % m > 0 or matrix.shape[1] % n > 0: + print("matrix shape must be divisible by m and n, try to extend") + m_pad = 0 if matrix.shape[0] % m == 0 else m - matrix.shape[0] % m + n_pad = 0 if matrix.shape[1] % n == 0 else n - matrix.shape[1] % n + mat = torch.nn.functional.pad(matrix, (0, n_pad, 0, m_pad)) + shape = mat.shape + first_tile = tuple_of_tensors_to_tensor(torch.split(mat, m, 0)) + second_tile = tuple_of_tensors_to_tensor(torch.split(first_tile, n, 2)) + mat = second_tile + return mat, shape + else: + first_tile = tuple_of_tensors_to_tensor(torch.split(matrix, m, 0)) + second_tile = tuple_of_tensors_to_tensor(torch.split(first_tile, n, 2)) + mat = second_tile + return mat, matrix.shape + +def unstructured_vector_wise(matrix, density, v): + mat = matrix.view(-1, v) + (m, v) = mat.shape; + n = int(m * density) + + mask = torch.cuda.IntTensor(mat.shape).fill_(0) + mat_reduce = torch.sum(mat, dim=-1) + values, indices = torch.abs(mat_reduce).topk(n, dim=0, largest=True) + + mask[indices, :] = 1; + mask = mask.view(matrix.shape) + return mask + +def unstructured_v4(matrix, density): + return unstructured_vector_wise(matrix, density, 4) + +def unstructured_v32(matrix, density): + return unstructured_vector_wise(matrix, density, 32) + +def unstructured_v64(matrix, density): + return unstructured_vector_wise(matrix, density, 64) + +def mnv_vector_wise_greedy(matrix, m, n, v): + ''' + m -> length + n -> width + v -> valid vector + ''' + # split into tensor blocks + raw_shape = matrix.shape + # print("raw shape ", raw_shape) + mat, pad_shape = tensor_block_partition(matrix, v, m) + # print("extend shape ", pad_shape) + mask = torch.cuda.IntTensor(mat.shape).fill_(0) + mat_abs = torch.abs(mat) + mat_reduce = torch.sum(mat_abs, dim=2) + + values, indices = torch.topk(mat_reduce, n, dim=2, largest=True) + + # todo: this can be optimize, currently is slow. + for d0 in range(0, indices.shape[0]): + for d1 in range(0, indices.shape[1]): + mask[d0][d1][:, indices[d0][d1]] = 1 + # mask[0, 0, 0, indices] = 1; + mask = torch.cat(tuple(mask), 2).view(pad_shape) + mask = mask[0:raw_shape[0], 0:raw_shape[1]] + return mask.cuda() + +def m4n2v4_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 4, 2, 4) + +def m32n16v4_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 16, 4) + +def m32n8v4_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 8, 4) + +def m32n4v4_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 4, 4) + +def m32n3v4_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 3, 4) + +def m4n2v32_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 4, 2, 32) + +def m32n16v32_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 16, 32) + +def m32n8v32_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 8, 32) + +def m32n4v32_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 4, 32) + +def m32n3v32_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 3, 32) + +def m4n2v64_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 4, 2, 64) + +def m32n16v64_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 16, 64) + +def m32n8v64_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 8, 64) + +def m32n4v64_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 4, 64) + +def m32n3v64_2d_greedy(mat, density): + return mnv_vector_wise_greedy(mat, 32, 3, 64) + +def unstructured_block_wise(matrix, density, bh, bw): + # split into tensor blocks + mat, shape = tensor_block_partition(matrix, bh, bw) + (bm, bn, bh, bw) = mat.shape + mask = torch.cuda.IntTensor(mat.shape).fill_(0) + mat_abs = torch.abs(mat) + mat_reduce = torch.sum(torch.sum(mat_abs, dim=-1), dim=-1) + mat_reduce_recover = torch.stack(tuple(mat_reduce), dim=-1).view(-1) + # n = int(bm * bn * density) + n = int(bm * bn * density) + values, indices = torch.topk(mat_reduce_recover, n, dim=-1, largest=True) + # todo: this can be optimize, currently is slow. + for d0 in indices: + mask[d0 // bn][d0 % bn][:][:] = 1 + # mask[0, 0, 0, indices] = 1; + mask = torch.cat(tuple(mask), 2).view(matrix.shape) + + return mask.cuda() + + +def unstructured_b4(matrix, density): + return unstructured_block_wise(matrix, density, 4, 4) + + +def mnb_block_wise_greedy(matrix, m, n, bh, bw): + ''' + m -> length + n -> width + v -> valid vector + ''' + # split into tensor blocks + raw_shape = matrix.shape + print("raw shape ", raw_shape) + mat, pad_shape = tensor_block_partition(matrix, bh, bw * m) + print("extend shape ", pad_shape) + mask = torch.cuda.IntTensor(mat.shape).fill_(0) + mat_abs = torch.abs(mat) + mat_reduce = torch.sum(mat_abs, dim=2) + mat_reduce_bw = tuple_of_tensors_to_tensor(torch.split(mat_reduce, bw, dim=-1)) + mat_reduce_bw_reduce = torch.sum(mat_reduce_bw, dim=-1) + mat_reduce_bw_reduce_recover = torch.stack(tuple(mat_reduce_bw_reduce), dim=-1).view(mask.shape[0], mask.shape[1], m) + # print(mat_reduce_bw_reduce) + # print(third_tile) + # print(mask.shape) + values, indices = torch.topk(mat_reduce_bw_reduce_recover, n, dim=2, largest=True) + # todo: this can be optimize, currently is slow. + for d0 in range(0, indices.shape[0]): + for d1 in range(0, indices.shape[1]): + for _bw in range(0, bw): + mask[d0][d1][:, indices[d0][d1]*bw+_bw] = 1 + # mask[0, 0, 0, indices] = 1; + mask = torch.cat(tuple(mask), 2).view(pad_shape) + mask = mask[0:raw_shape[0], 0:raw_shape[1]] + return mask.cuda() + +def m4n2b4_2d_greedy(mat, density): + return mnb_block_wise_greedy(mat, 4, 2, 4, 4) + +def m32n3b4_2d_greedy(mat, density): + return mnb_block_wise_greedy(mat, 32, 3, 4, 4) + +def m32n4b4_2d_greedy(mat, density): + return mnb_block_wise_greedy(mat, 32, 4, 4, 4) + +def m32n8b4_2d_greedy(mat, density): + return mnb_block_wise_greedy(mat, 32, 8, 4, 4) + +def m32n16b4_2d_greedy(mat, density): + return mnb_block_wise_greedy(mat, 32, 16, 4, 4) + """ returns a sparse mask """ def create_mask(tensor, pattern="m4n2_1d", density=0.5): # Reshape tensor and mask.