21
21
from compressed_tensors import CompressionFormat , SparsityCompressionConfig
22
22
from compressed_tensors .quantization .utils import is_model_quantized
23
23
from sparseml .pytorch .utils import ModuleSparsificationInfo
24
+ from sparseml .transformers .compression .helpers import (
25
+ infer_sparsity_structure_from_model ,
26
+ infer_sparsity_structure_from_stage_modifiers ,
27
+ )
24
28
25
29
26
30
class SparsityConfigMetadata :
@@ -47,26 +51,34 @@ def infer_global_sparsity(
47
51
return global_sparsity
48
52
49
53
@staticmethod
50
- def infer_sparsity_structure () -> str :
54
+ def infer_sparsity_structure (model : Optional [ Module ] = None ) -> str :
51
55
"""
52
- Determines what sparsity structure, if any, was applied in the currently active
53
- sparse session
56
+ Determines what sparsity structure, if any, was applied.
57
+
58
+ First, there is an attempt to deduce the sparsity structure
59
+ from the currently active sparse session.
60
+
61
+ If that fails, the sparsity structure is inferred from the
62
+ model (if provided)
63
+
64
+ Finally, if both fail, the sparsity structure is set to
65
+ "unstructured"
54
66
55
67
:return: sparsity structure as a string
56
68
"""
69
+ sparsity_structure = None
70
+
57
71
current_session = sparseml .active_session ()
58
72
stage_modifiers = current_session .lifecycle .modifiers
59
- sparsity_structure = "unstructured"
73
+ if stage_modifiers :
74
+ sparsity_structure = infer_sparsity_structure_from_stage_modifiers (
75
+ stage_modifiers
76
+ )
60
77
61
- # check for applied pruning modifiers
62
- for stage in stage_modifiers :
63
- if stage .applied :
64
- for modifier in stage .modifiers :
65
- if hasattr (modifier , "mask_structure" ):
66
- sparsity_structure = modifier .mask_structure
67
- break
78
+ if model and sparsity_structure is None :
79
+ sparsity_structure = infer_sparsity_structure_from_model (model )
68
80
69
- return sparsity_structure
81
+ return sparsity_structure or "unstructured"
70
82
71
83
@staticmethod
72
84
def from_pretrained (
@@ -91,7 +103,9 @@ def from_pretrained(
91
103
if global_sparsity < 0.05 :
92
104
return None
93
105
94
- sparsity_structure = SparsityConfigMetadata .infer_sparsity_structure ()
106
+ sparsity_structure = SparsityConfigMetadata .infer_sparsity_structure (
107
+ model = model
108
+ )
95
109
if is_model_quantized (model ):
96
110
# compressing a sparse quantized model is not supported yet
97
111
format = CompressionFormat .dense .value
0 commit comments