Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))


- Fixed `ModelPruning` sparsity logging bug that caused incorrect sparsity percentages ([#21223](https://github.com/Lightning-AI/pytorch-lightning/pull/21223))


---

## [2.5.5] - 2025-09-05
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None:
def _log_sparsity_stats(
self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0
) -> None:
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
total_params = sum(total for _, total in curr)
prev_total_zeros = sum(zeros for zeros, _ in prev)
curr_total_zeros = sum(zeros for zeros, _ in curr)
log.info(
Expand Down
100 changes: 94 additions & 6 deletions tests/tests_pytorch/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
actual = [m for m in actual if m.startswith("Applied")]
percentage = r"\(\d+(?:\.\d+)?%\)"
expected = [
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}", # noqa: E501
rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}", # noqa: E501
]
Expand Down Expand Up @@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
actual = [m for m in actual if m.startswith("Applied")]
percentage = r"\(\d+(?:\.\d+)?%\)"
expected = [
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
]
expected = [re.compile(s) for s in expected]
assert all(regex.match(s) for s, regex in zip(actual, expected))
Expand Down Expand Up @@ -463,3 +463,91 @@ def __init__(self):
f"Actual weight_orig: {weight_orig}\n"
f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
)


@pytest.mark.parametrize("pruning_amount", [0.1, 0.2, 0.3, 0.5])
@pytest.mark.parametrize("model_type", ["simple", "complex"])
def test_sparsity_calculation(tmp_path, caplog, pruning_amount: float, model_type: str):
    """Check the totals and percentage reported in the global pruning summary log line.

    A one-epoch, one-batch fit is run with a verbose ``ModelPruning`` callback, the single
    ``Applied `L1Unstructured`. Pruned: a/b (x%) -> c/d (y%)`` message is parsed, and the test
    asserts that both totals equal the size of the pruned parameter set and that the reported
    sparsity is close to the requested ``pruning_amount``.

    """

    class SimpleModel(BoringModel):
        """Model with one linear layer of 66 parameters (64 weights + 2 biases)."""

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)  # 32 * 2 weights + 2 biases = 66

    class ComplexModel(BoringModel):
        """Two-layer model; this test lists only ``layer1`` for pruning."""

        def __init__(self):
            super().__init__()
            self.layer1 = nn.Linear(32, 64)  # 32 * 64 weights + 64 biases = 2112 (pruned)
            self.layer2 = nn.Linear(64, 2)  # 64 * 2 weights + 2 biases = 130 (not listed for pruning)

        def forward(self, x):
            hidden = torch.relu(self.layer1(x))
            return self.layer2(hidden)

    if model_type != "simple":
        model = ComplexModel()
        # Only layer1's weight and bias are handed to the callback, so the logged total is 2112.
        expected_total_params = 2112
        parameters_to_prune = [(model.layer1, name) for name in ("weight", "bias")]
    else:
        model = SimpleModel()
        expected_total_params = 66
        # None leaves the choice of parameters to the callback.
        parameters_to_prune = None

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",
        amount=pruning_amount,
        parameters_to_prune=parameters_to_prune,
        use_global_unstructured=True,
        verbose=1,
    )

    # Keep the run minimal and quiet: one batch, one epoch, CPU only, no extra outputs.
    quiet_single_step = {
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "enable_checkpointing": False,
        "logger": False,
        "limit_train_batches": 1,
        "max_epochs": 1,
        "accelerator": "cpu",
    }
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[pruning], **quiet_single_step)

    with caplog.at_level(INFO):
        trainer.fit(model)

    summary_marker = "Applied `L1Unstructured`. Pruned:"
    sparsity_logs = [message for message in caplog.messages if summary_marker in message]
    found = len(sparsity_logs)
    assert found == 1, f"Expected 1 sparsity log, got {found}"

    sparsity_log = sparsity_logs[0]
    parsed = re.search(
        r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)",
        sparsity_log,
    )
    assert parsed, f"Could not parse sparsity log: {sparsity_log}"

    # Groups: 1 = total before, 2 = zeroed count after, 3 = total after, 4 = percentage after.
    total_params_before, pruned_count, total_params_after = (int(parsed.group(idx)) for idx in (1, 2, 3))
    sparsity_percentage = float(parsed.group(4))

    assert total_params_before == expected_total_params, (
        f"Total parameter count mismatch for {model_type} model. "
        f"Expected {expected_total_params}, got {total_params_before}"
    )
    assert total_params_after == expected_total_params, (
        f"Total parameter count should be consistent. Before: {total_params_before}, After: {total_params_after}"
    )

    # The reported percentage must land within 5 percentage points of the requested amount.
    tolerance = 5.0
    expected_sparsity = pruning_amount * 100
    assert abs(sparsity_percentage - expected_sparsity) <= tolerance

    # The zeroed-parameter count must be within 5% of the total (at least 1) of the expected count.
    pruned_tolerance = max(1, int(expected_total_params * 0.05))
    expected_pruned_count = int(expected_total_params * pruning_amount)
    assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance
Loading