diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 863a3a4a7e939..fa6e92ebc7124 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed preventing recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
+- Fixed `ModelPruning` logging incorrect sparsity percentages ([#21223](https://github.com/Lightning-AI/pytorch-lightning/pull/21223))
+
+
 ---
 
 ## [2.5.5] - 2025-09-05
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py
index 1de693978acfa..f0e1bcbe49f99 100644
--- a/src/lightning/pytorch/callbacks/pruning.py
+++ b/src/lightning/pytorch/callbacks/pruning.py
@@ -349,7 +349,7 @@ def apply_pruning(self, amount: Union[int, float]) -> None:
     def _log_sparsity_stats(
         self, prev: list[tuple[int, int]], curr: list[tuple[int, int]], amount: Union[int, float] = 0
     ) -> None:
-        total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
+        total_params = sum(total for _, total in curr)
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
         log.info(
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index 20bb03bfdd941..1a23efd919171 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -262,13 +262,13 @@ def test_multiple_pruning_callbacks(tmp_path, caplog, make_pruning_permanent: bo
     actual = [m for m in actual if m.startswith("Applied")]
     percentage = r"\(\d+(?:\.\d+)?%\)"
     expected = [
-        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
         rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
-        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
+        rf"Applied `L1Unstructured`. Pruned: \d+\/1088 {percentage} -> \d+\/1088 {percentage}",
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
         rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
     ]
@@ -329,9 +329,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
     actual = [m for m in actual if m.startswith("Applied")]
     percentage = r"\(\d+(?:\.\d+)?%\)"
     expected = [
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
-        rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
+        rf"Applied `RandomUnstructured`. Pruned: \d+\/64 {percentage} -> \d+\/64 {percentage}",
     ]
     expected = [re.compile(s) for s in expected]
     assert all(regex.match(s) for s, regex in zip(actual, expected))
@@ -463,3 +463,91 @@ def __init__(self):
         f"Actual weight_orig: {weight_orig}\n"
         f"Max difference: {torch.max(torch.abs(weight_orig - original_weights))}"
     )
+
+
+@pytest.mark.parametrize("pruning_amount", [0.1, 0.2, 0.3, 0.5])
+@pytest.mark.parametrize("model_type", ["simple", "complex"])
+def test_sparsity_calculation(tmp_path, caplog, pruning_amount: float, model_type: str):
+    """Test that the sparsity calculation fix correctly reports percentages."""
+
+    class SimpleModel(BoringModel):
+        """Simple model with 66 parameters (64 weight + 2 bias)."""
+
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(32, 2)  # 32*2 + 2 = 66 params
+
+    class ComplexModel(BoringModel):
+        """Complex model with multiple layers."""
+
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(32, 64)  # 32*64 + 64 = 2112 params
+            self.layer2 = nn.Linear(64, 2)  # 64*2 + 2 = 130 params
+            # Total: 2112 + 130 = 2242 params, but only layer1 will be pruned
+            # layer1 params: 2112
+
+        def forward(self, x):
+            x = torch.relu(self.layer1(x))
+            return self.layer2(x)
+
+    if model_type == "simple":
+        model = SimpleModel()
+        expected_total_params = 66
+        parameters_to_prune = None
+    else:
+        model = ComplexModel()
+        expected_total_params = 2112
+        parameters_to_prune = [(model.layer1, "weight"), (model.layer1, "bias")]
+
+    pruning = ModelPruning(
+        pruning_fn="l1_unstructured",
+        parameters_to_prune=parameters_to_prune,
+        amount=pruning_amount,
+        verbose=1,
+        use_global_unstructured=True,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
+        limit_train_batches=1,
+        max_epochs=1,
+        accelerator="cpu",
+        callbacks=[pruning],
+    )
+
+    with caplog.at_level(INFO):
+        trainer.fit(model)
+
+    sparsity_logs = [msg for msg in caplog.messages if "Applied `L1Unstructured`. Pruned:" in msg]
+    assert len(sparsity_logs) == 1, f"Expected 1 sparsity log, got {len(sparsity_logs)}"
+    sparsity_log = sparsity_logs[0]
+    pattern = r"Applied `L1Unstructured`\. Pruned: \d+/(\d+) \(\d+\.\d+%\) -> (\d+)/(\d+) \((\d+\.\d+)%\)"
+    match = re.search(pattern, sparsity_log)
+    assert match, f"Could not parse sparsity log: {sparsity_log}"
+
+    total_params_before = int(match.group(1))
+    pruned_count = int(match.group(2))
+    total_params_after = int(match.group(3))
+    sparsity_percentage = float(match.group(4))
+    assert total_params_before == expected_total_params, (
+        f"Total parameter count mismatch for {model_type} model. "
+        f"Expected {expected_total_params}, got {total_params_before}"
+    )
+    assert total_params_after == expected_total_params, (
+        f"Total parameter count should be consistent. Before: {total_params_before}, After: {total_params_after}"
+    )
+
+    # Verify sparsity percentage is approximately correct
+    expected_sparsity = pruning_amount * 100
+    tolerance = 5.0
+    assert abs(sparsity_percentage - expected_sparsity) <= tolerance
+
+    # Verify the number of pruned parameters is reasonable
+    expected_pruned_count = int(expected_total_params * pruning_amount)
+    pruned_tolerance = max(1, int(expected_total_params * 0.05))
+    assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance