diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 863a3a4a7e939..e58f6c669ad5a 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186)) +- Fixed `last.ckpt` being created and not linked to another checkpoint ([#21244](https://github.com/Lightning-AI/pytorch-lightning/pull/21244)) + --- ## [2.5.5] - 2025-09-05 diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index fc83c0a4513a2..8a5d9dcdf786f 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -380,7 +380,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu monitor_candidates = self._monitor_candidates(trainer) if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) - self._save_last_checkpoint(trainer, monitor_candidates) + # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link" + if self._last_global_step_saved == trainer.global_step or ( + self.save_last == "link" and self._last_checkpoint_saved + ): + self._save_last_checkpoint(trainer, monitor_candidates) @override def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -397,7 +401,11 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) - self._save_last_checkpoint(trainer, monitor_candidates) + # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link" + if self._last_global_step_saved == trainer.global_step or ( + self.save_last == "link" and self._last_checkpoint_saved + ): + self._save_last_checkpoint(trainer, monitor_candidates) @override def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d2bbea7ecdafe..feeb9a55b91d9 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -2124,3 +2124,59 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): # save_last=True should always save last.ckpt assert (tmp_path / "last.ckpt").exists() + + +def test_save_last_only_when_checkpoint_saved(tmp_path): + """Test that save_last only creates last.ckpt when another checkpoint is actually saved.""" + + class SelectiveModel(BoringModel): + def __init__(self): + super().__init__() + self.validation_step_outputs = [] + + def validation_step(self, batch, batch_idx): + outputs = super().validation_step(batch, batch_idx) + epoch = self.trainer.current_epoch + loss = torch.tensor(1.0 - epoch * 0.1) if epoch % 2 == 0 else torch.tensor(1.0 + epoch * 0.1) + outputs["val_loss"] = loss + self.validation_step_outputs.append(outputs) + return outputs + + def on_validation_epoch_end(self): + if self.validation_step_outputs: + avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean() + self.log("val_loss", avg_loss) + self.validation_step_outputs.clear() + + model = SelectiveModel() + + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + filename="best-{epoch}-{val_loss:.2f}", + monitor="val_loss", + save_last=True, + save_top_k=1, + mode="min", + every_n_epochs=1, + save_on_train_epoch_end=False, + ) + + trainer = Trainer( + max_epochs=4, + callbacks=[checkpoint_callback], + logger=False, + enable_progress_bar=False, + limit_train_batches=2, + limit_val_batches=2, + enable_checkpointing=True, + ) + + trainer.fit(model) + + checkpoint_files = list(tmp_path.glob("*.ckpt")) + checkpoint_names = [f.name for f in checkpoint_files] + assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved" + expected_files = 2 # best checkpoint + last.ckpt + assert len(checkpoint_files) == expected_files, ( + f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}" + )