diff --git a/src/lightning/fabric/loggers/tensorboard.py b/src/lightning/fabric/loggers/tensorboard.py
index 14bc3d6b76220..208244dc38cd3 100644
--- a/src/lightning/fabric/loggers/tensorboard.py
+++ b/src/lightning/fabric/loggers/tensorboard.py
@@ -220,15 +220,19 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
     @override
     @rank_zero_only
     def log_hyperparams(
-        self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None
+        self,
+        params: Union[dict[str, Any], Namespace],
+        metrics: Optional[dict[str, Any]] = None,
+        step: Optional[int] = None,
     ) -> None:
         """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
         hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
         display the new ones with hyperparameters.
 
         Args:
-            params: a dictionary-like container with the hyperparameters
+            params: A dictionary-like container with the hyperparameters
             metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Optional global step number for the logged metrics
 
         """
         params = _convert_params(params)
@@ -244,7 +248,7 @@ def log_hyperparams(
             metrics = {"hp_metric": metrics}
 
         if metrics:
-            self.log_metrics(metrics, 0)
+            self.log_metrics(metrics, step)
 
             if _TENSORBOARD_AVAILABLE:
                 from torch.utils.tensorboard.summary import hparams
@@ -253,9 +257,9 @@ def log_hyperparams(
 
             exp, ssi, sei = hparams(params, metrics)
             writer = self.experiment._get_file_writer()
-            writer.add_summary(exp)
-            writer.add_summary(ssi)
-            writer.add_summary(sei)
+            writer.add_summary(exp, step)
+            writer.add_summary(ssi, step)
+            writer.add_summary(sei, step)
 
     @override
     @rank_zero_only
diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py
index e70c89269b166..f9cc41c67045c 100644
--- a/src/lightning/pytorch/loggers/tensorboard.py
+++ b/src/lightning/pytorch/loggers/tensorboard.py
@@ -153,15 +153,19 @@ def save_dir(self) -> str:
     @override
     @rank_zero_only
     def log_hyperparams(
-        self, params: Union[dict[str, Any], Namespace], metrics: Optional[dict[str, Any]] = None
+        self,
+        params: Union[dict[str, Any], Namespace],
+        metrics: Optional[dict[str, Any]] = None,
+        step: Optional[int] = None,
    ) -> None:
         """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the
         hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs to
         display the new ones with hyperparameters.
 
         Args:
-            params: a dictionary-like container with the hyperparameters
+            params: A dictionary-like container with the hyperparameters
             metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Optional global step number for the logged metrics
 
         """
         if _OMEGACONF_AVAILABLE:
@@ -175,7 +179,7 @@ def log_hyperparams(
         else:
             self.hparams.update(params)
 
-        return super().log_hyperparams(params=params, metrics=metrics)
+        return super().log_hyperparams(params=params, metrics=metrics, step=step)
 
     @override
     @rank_zero_only
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index b90d767a23caf..40c82bec2fd10 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -544,7 +544,7 @@ def test_step(self, batch, batch_idx):
         "valid_loss_1",
     }
     assert mock_log_metrics.mock_calls == [
-        call({"hp_metric": -1}, 0),
+        call({"hp_metric": -1}, None),
         call(metrics={"train_loss": ANY, "epoch": 0}, step=0),
         call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0),
         call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1),
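
Below is a minimal usage sketch of the new optional `step` argument, assuming this patch is applied; the save directory, hyperparameter names, and values are illustrative only.

from lightning.pytorch.loggers import TensorBoardLogger

# Illustrative logger; "logs" and "demo" are placeholder directory/experiment names.
logger = TensorBoardLogger("logs", name="demo")

# Previously the accompanying hp_metric was always written at step 0.
# With the new signature, an explicit global step (or None) can be passed through.
logger.log_hyperparams(
    params={"lr": 1e-3, "batch_size": 32},
    metrics={"hp_metric": 0.0},
    step=100,
)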