Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 535b30d

Browse files
authored
Update sparsity_config.py
1 parent ef0232e commit 535b30d

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

src/sparseml/transformers/compression/sparsity_config.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from compressed_tensors import CompressionFormat, SparsityCompressionConfig
2222
from compressed_tensors.quantization.utils import is_model_quantized
2323
from sparseml.pytorch.utils import ModuleSparsificationInfo
24+
from sparseml.transformers.compression.helpers import (
25+
infer_sparsity_structure_from_model,
26+
infer_sparsity_structure_from_stage_modifiers,
27+
)
2428

2529

2630
class SparsityConfigMetadata:
@@ -47,26 +51,34 @@ def infer_global_sparsity(
4751
return global_sparsity
4852

4953
@staticmethod
50-
def infer_sparsity_structure() -> str:
54+
def infer_sparsity_structure(model: Optional[Module] = None) -> str:
5155
"""
52-
Determines what sparsity structure, if any, was applied in the currently active
53-
sparse session
56+
Determines what sparsity structure, if any, was applied.
57+
58+
First, there is an attempt to dedue the sparsity structure
59+
from the currently active sparse session.
60+
61+
If that fails, the sparsity structure is inferred from the
62+
model (if provided)
63+
64+
Finally, if both fail, the sparsity structure is set to
65+
"unstructured"
5466
5567
:return: sparsity structure as a string
5668
"""
69+
sparsity_structure = None
70+
5771
current_session = sparseml.active_session()
5872
stage_modifiers = current_session.lifecycle.modifiers
59-
sparsity_structure = "unstructured"
73+
if stage_modifiers:
74+
sparsity_structure = infer_sparsity_structure_from_stage_modifiers(
75+
stage_modifiers
76+
)
6077

61-
# check for applied pruning modifiers
62-
for stage in stage_modifiers:
63-
if stage.applied:
64-
for modifier in stage.modifiers:
65-
if hasattr(modifier, "mask_structure"):
66-
sparsity_structure = modifier.mask_structure
67-
break
78+
if model and sparsity_structure is None:
79+
sparsity_structure = infer_sparsity_structure_from_model(model)
6880

69-
return sparsity_structure
81+
return sparsity_structure or "unstructured"
7082

7183
@staticmethod
7284
def from_pretrained(
@@ -91,7 +103,9 @@ def from_pretrained(
91103
if global_sparsity < 0.05:
92104
return None
93105

94-
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()
106+
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
107+
model=model
108+
)
95109
if is_model_quantized(model):
96110
# compressing a sparse quantized model is not supported yet
97111
format = CompressionFormat.dense.value

0 commit comments

Comments
 (0)