From b2b9efe906a9fef9ac984ddd1b3688009e51a033 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 22 Sep 2025 14:25:07 +0200
Subject: [PATCH 1/6] fix implementation

---
 src/lightning/pytorch/callbacks/model_checkpoint.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 415e1dcac309b..6eae15cc2acf9 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -380,7 +380,10 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
         monitor_candidates = self._monitor_candidates(trainer)
         if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
             self._save_topk_checkpoint(trainer, monitor_candidates)
-        self._save_last_checkpoint(trainer, monitor_candidates)
+        # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
+        if (self._last_global_step_saved == trainer.global_step or
+                (self.save_last == "link" and self._last_checkpoint_saved)):
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -397,7 +400,10 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
 
         if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
             self._save_topk_checkpoint(trainer, monitor_candidates)
-        self._save_last_checkpoint(trainer, monitor_candidates)
+        # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
+        if (self._last_global_step_saved == trainer.global_step or
+                (self.save_last == "link" and self._last_checkpoint_saved)):
+            self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
     def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
@@ -902,3 +908,5 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
     def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
         """Calls the strategy to remove the checkpoint file."""
         trainer.strategy.remove_checkpoint(filepath)
+
+ 
\ No newline at end of file
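For context, a minimal sketch (not part of the patch) of how the fixed guard behaves from the user's side, assuming a stock BoringModel; the dirpath and hyperparameters below are illustrative, not taken from the patch. With every_n_epochs=2, top-k checkpoints are written only on every second epoch, and under the patched guard last.ckpt is refreshed only at those same steps instead of after every epoch:

    # Sketch under the stated assumptions; "checkpoints" is a made-up path.
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.demos.boring_classes import BoringModel

    ckpt = ModelCheckpoint(dirpath="checkpoints", save_last=True, every_n_epochs=2)
    trainer = Trainer(max_epochs=4, callbacks=[ckpt], limit_train_batches=2, logger=False)
    trainer.fit(BoringModel())
    # Checkpoints (and a refreshed last.ckpt) now appear only after epochs 2 and 4.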
From d8cb135f64e49a6f1414c79b3924d0acada4e141 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 22 Sep 2025 14:25:16 +0200
Subject: [PATCH 2/6] add testing

---
 .../checkpointing/test_model_checkpoint.py | 54 +++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index d2bbea7ecdafe..a3c743be20e03 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -2124,3 +2124,57 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
 
     # save_last=True should always save last.ckpt
     assert (tmp_path / "last.ckpt").exists()
+
+
+def test_save_last_only_when_checkpoint_saved(tmp_path):
+    """Test that save_last only creates last.ckpt when another checkpoint is actually saved."""
+    
+    class SelectiveModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_step_outputs = []
+    
+        def validation_step(self, batch, batch_idx):
+            outputs = super().validation_step(batch, batch_idx)
+            epoch = self.trainer.current_epoch
+            loss = torch.tensor(1.0 - epoch * 0.1) if epoch % 2 == 0 else torch.tensor(1.0 + epoch * 0.1)
+            outputs["val_loss"] = loss
+            self.validation_step_outputs.append(outputs)
+            return outputs
+    
+        def on_validation_epoch_end(self):
+            if self.validation_step_outputs:
+                avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean()
+                self.log("val_loss", avg_loss)
+                self.validation_step_outputs.clear()
+
+    model = SelectiveModel()
+    
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        filename="best-{epoch}-{val_loss:.2f}",
+        monitor="val_loss", 
+        save_last=True,
+        save_top_k=1,
+        mode="min",
+        every_n_epochs=1,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        enable_checkpointing=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_files = list(tmp_path.glob("*.ckpt"))
+    checkpoint_names = [f.name for f in checkpoint_files]
+    assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved"
+    expected_files = 2  # best checkpoint + last.ckpt
+    assert len(checkpoint_files) == expected_files, f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
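A note on why the test asserts exactly two files: the ternary in validation_step makes only even epochs improve the monitored metric, so with save_top_k=1 and mode="min" a single best checkpoint (epoch 2) survives alongside last.ckpt. A tiny standalone reproduction of that arithmetic (not part of the test suite):

    # Per-epoch val_loss produced by SelectiveModel's ternary over 4 epochs.
    losses = [1.0 - e * 0.1 if e % 2 == 0 else 1.0 + e * 0.1 for e in range(4)]
    print(losses)                     # approximately [1.0, 1.1, 0.8, 1.3]
    print(losses.index(min(losses)))  # 2 -> epoch 2 holds the single kept "best" checkpoint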
From 0ba7d1b7eb693cabeb334f9c81dd3aa04b9da705 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 23 Sep 2025 07:04:06 +0200
Subject: [PATCH 3/6] smaller fix

---
 src/lightning/pytorch/callbacks/model_checkpoint.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 6eae15cc2acf9..c96a3453701b7 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -908,5 +908,3 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
     def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
         """Calls the strategy to remove the checkpoint file."""
         trainer.strategy.remove_checkpoint(filepath)
-
- 
\ No newline at end of file

From de3b0f7347fc99cf431f0684316f911c7698bc63 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 23 Sep 2025 05:04:32 +0000
Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../pytorch/callbacks/model_checkpoint.py    | 10 ++++++----
 .../checkpointing/test_model_checkpoint.py   | 14 ++++++++------
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index c96a3453701b7..e7073b73bfbf2 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -381,8 +381,9 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
         if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
             self._save_topk_checkpoint(trainer, monitor_candidates)
         # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
-        if (self._last_global_step_saved == trainer.global_step or
-                (self.save_last == "link" and self._last_checkpoint_saved)):
+        if self._last_global_step_saved == trainer.global_step or (
+            self.save_last == "link" and self._last_checkpoint_saved
+        ):
             self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
@@ -401,8 +402,9 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
             self._save_topk_checkpoint(trainer, monitor_candidates)
         # Only save last checkpoint if a checkpoint was actually saved in this step or if save_last="link"
-        if (self._last_global_step_saved == trainer.global_step or
-                (self.save_last == "link" and self._last_checkpoint_saved)):
+        if self._last_global_step_saved == trainer.global_step or (
+            self.save_last == "link" and self._last_checkpoint_saved
+        ):
             self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
index a3c743be20e03..feeb9a55b91d9 100644
--- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
+++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -2128,12 +2128,12 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
 
 def test_save_last_only_when_checkpoint_saved(tmp_path):
     """Test that save_last only creates last.ckpt when another checkpoint is actually saved."""
-    
+
     class SelectiveModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.validation_step_outputs = []
-    
+
         def validation_step(self, batch, batch_idx):
             outputs = super().validation_step(batch, batch_idx)
             epoch = self.trainer.current_epoch
@@ -2141,7 +2141,7 @@ def validation_step(self, batch, batch_idx):
             outputs["val_loss"] = loss
             self.validation_step_outputs.append(outputs)
             return outputs
-    
+
         def on_validation_epoch_end(self):
             if self.validation_step_outputs:
                 avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean()
@@ -2149,11 +2149,11 @@ def on_validation_epoch_end(self):
                 self.validation_step_outputs.clear()
 
     model = SelectiveModel()
-    
+
     checkpoint_callback = ModelCheckpoint(
         dirpath=tmp_path,
         filename="best-{epoch}-{val_loss:.2f}",
-        monitor="val_loss", 
+        monitor="val_loss",
         save_last=True,
         save_top_k=1,
         mode="min",
@@ -2177,4 +2177,6 @@ def on_validation_epoch_end(self):
     checkpoint_names = [f.name for f in checkpoint_files]
     assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved"
     expected_files = 2  # best checkpoint + last.ckpt
-    assert len(checkpoint_files) == expected_files, f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
+    assert len(checkpoint_files) == expected_files, (
+        f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
+    )
From 9c11ca3cb7556a0fe454eabaf761045fae2db9d0 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 23 Sep 2025 07:05:49 +0200
Subject: [PATCH 5/6] changelog

---
 src/lightning/pytorch/CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b3ed00611c021..06a7d27fbe94c 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed preventing recursive symlink creation iwhen `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))
 
+- Fixed `last.ckpt` being created and not linked to another checkpoint ([#21244](https://github.com/Lightning-AI/pytorch-lightning/pull/21244))
+
 ---
 
 ## [2.5.5] - 2025-09-05

From 463587e243a740188cf2793b16545e360caff8de Mon Sep 17 00:00:00 2001
From: Justus Schock
Date: Thu, 25 Sep 2025 00:34:37 +0200
Subject: [PATCH 6/6] empty ci commit
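A closing note on the save_last="link" branch that both guards special-case: with that setting, last.ckpt is a symbolic link to the most recent real checkpoint rather than a second full copy of the weights, so the guard also fires whenever some checkpoint already exists to link to (the _last_checkpoint_saved check in the patch). A hedged configuration sketch, with an illustrative dirpath:

    from lightning.pytorch.callbacks import ModelCheckpoint

    # last.ckpt becomes a symlink to the newest saved checkpoint; under the
    # patched guard it is only (re)created when a link target actually exists.
    ckpt = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss", save_top_k=2, save_last="link")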