[sparse]update support for arbitrary N:M settings sparse #1631

Open · wants to merge 1 commit into master
25 changes: 18 additions & 7 deletions apex/contrib/sparsity/asp.py
@@ -31,13 +31,14 @@ class ASP:
     __optimizer = None
     __sparse_parameters = []
     __calculate_mask = None
+    __layer_mute_count = 0
     __allow_permutation = True
     __all_parameters = []
     __save_permutation_graph = False
     __permutation_output_dir = ''
 
     @classmethod
-    def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
+    def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", density=0.5,
                                verbosity=3,
                                whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention],
                                allowed_layer_names=None, disallowed_layer_names=[],
@@ -75,6 +76,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
           allow_recompute_mask   If True, stores pruned values so that dense weights can be restored.
                                  Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage.
           custom_layer_dict      Dictionary of additional layer parameters to sparsify. e.g. {CustomLinear: ['weight']}
+          density                Fraction of weights to keep when mask_calculator uses an unstructured pattern.
           allow_permutation      If True, allow the input channel permutation to ease the influence of weight pruning.
 
           [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe.
@@ -86,7 +88,7 @@ def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
 
         if isinstance(mask_calculator, str):
             def create_mask_from_pattern(param):
-                return create_mask(param, mask_calculator).bool()
+                return create_mask(param, mask_calculator, density).bool()
             cls.__calculate_mask = create_mask_from_pattern
         else:
             cls.__calculate_mask = mask_calculator #user defined function
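With this change, the string path forwards `density` straight through to `sparse_masklib.create_mask` alongside the pattern name. As a rough sketch of what a density-controlled unstructured mask computes (the `unstructured_mask` helper below is an illustration with assumed semantics, not the library's code; ties at the magnitude threshold are broken arbitrarily by `topk`):

```python
import torch

def unstructured_mask(param: torch.Tensor, density: float = 0.5) -> torch.Tensor:
    # Keep the `density` fraction of entries with the largest magnitude, drop the rest.
    k = max(1, int(density * param.numel()))
    flat = param.detach().abs().flatten()
    keep = flat.topk(k).indices              # indices of the k largest magnitudes
    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[keep] = True
    return mask.view(param.shape)

# e.g. a 64x64 weight at density=0.25 keeps 1024 of 4096 entries
print(unstructured_mask(torch.randn(64, 64), density=0.25).sum())  # tensor(1024)
```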
@@ -210,7 +212,7 @@ def __step(opt_self, *args, **kwargs):
         cls.__optimizer.step = types.MethodType(__step, cls.__optimizer)
 
     @classmethod
-    def compute_sparse_masks(cls):
+    def compute_sparse_masks(cls, layer_wise_ratio=0.0):
         """Call this method to enable sparsity.
         If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None.
         """
@@ -237,12 +239,21 @@ def compute_sparse_masks(cls):
                 print("[compute_sparse_masks] Take {:.4f} seconds to find and apply permutations.".format(duration_build_offline_permutation_graph))
 
 
+            layer_wise_mute_count = int(layer_wise_ratio * len(cls.__sparse_parameters))
+            _count = 0
             for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters:
                 if mask.sum() < mask.numel(): # when recalculating masks
                     # restore dense parameter if allow_recompute_mask is enabled
                     assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False"
                     p.add_(pruned.cuda())
 
-                mask.set_(cls.__calculate_mask(p))
+                if _count < layer_wise_mute_count:
+                    def create_mask_from_pattern(param):
+                        return create_mask(param, "unstructured", 0.5).bool()
+                    calculate_mask = create_mask_from_pattern
+                    mask.set_(calculate_mask(p))
+                else:
+                    mask.set_(cls.__calculate_mask(p))
+                _count += 1
 
                 if pruned is not None: # stow away pruned weights to cpu
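The effect of `layer_wise_ratio`: the first `int(layer_wise_ratio * len(cls.__sparse_parameters))` prunable tensors, in registration order, receive unstructured masks, while the remainder keep the configured mask calculator. (The `__layer_mute_count` class attribute added in the first hunk is not referenced by this loop.) A toy illustration of the split, with names invented for the example:

```python
# Hypothetical layer names; only the counting logic mirrors the patch.
layer_wise_ratio = 0.25
sparse_params = [f"layer{i}.weight" for i in range(8)]

layer_wise_mute_count = int(layer_wise_ratio * len(sparse_params))  # 2 of 8
for i, name in enumerate(sparse_params):
    rule = "unstructured (density=0.5)" if i < layer_wise_mute_count else "m4n2_1d"
    print(f"{name}: {rule}")
```

Note that the fallback density is hardcoded to 0.5 here rather than reusing the `density` passed to `init_model_for_pruning`, and that which tensors land in the "muted" prefix depends only on module traversal order, not on any per-layer sensitivity measure.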
@@ -289,11 +301,11 @@ def is_sparsity_enabled(cls):
         return True
 
     @classmethod
-    def prune_trained_model(cls, model, optimizer):
+    def prune_trained_model(cls, model, optimizer, mask_calculator="m4n2_1d", sparsity=0.5, layer_wise_ratio=0.25):
         # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
-        cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
+        cls.init_model_for_pruning(model, mask_calculator, density=(1 - sparsity), verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
         cls.init_optimizer_for_pruning(optimizer)
-        cls.compute_sparse_masks()
+        cls.compute_sparse_masks(layer_wise_ratio)
 
     @classmethod
     def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'):
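A minimal usage sketch of the updated entry point (assumes a CUDA device and this patch applied; the toy model is illustrative, and the keyword values shown simply restate the new defaults):

```python
import torch
from apex.contrib.sparsity import ASP

# Tiny stand-in for a trained model; dimensions chosen so ASP will sparsify them.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 2:4 structured masks (sparsity=0.5 becomes density=0.5) on most layers;
# the first 25% of prunable tensors get unstructured masks instead.
ASP.prune_trained_model(model, optimizer,
                        mask_calculator="m4n2_1d",
                        sparsity=0.5,
                        layer_wise_ratio=0.25)
```

Internally, `sparsity` is converted to `density=(1 - sparsity)` before being handed to `init_model_for_pruning`. Note that the new default `layer_wise_ratio=0.25` changes the behavior of `prune_trained_model` for existing callers, since a quarter of the prunable layers no longer receive hardware-friendly 2:4 masks unless the ratio is set to 0.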