From 491af817166db458bbb681ba6be8684e811a8721 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Thu, 26 Dec 2024 09:03:18 -0800 Subject: [PATCH] hp updates --- .gitignore | 1 + topobenchmark/utils/logging_utils.py | 15 ++------------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 28088c16..02a6e40e 100755 --- a/.gitignore +++ b/.gitignore @@ -174,6 +174,7 @@ cython_debug/ # Data /datasets/ +/data/ /notebooks/*.csv notebooks/tmp *.pickle diff --git a/topobenchmark/utils/logging_utils.py b/topobenchmark/utils/logging_utils.py index 37735c06..0963953b 100755 --- a/topobenchmark/utils/logging_utils.py +++ b/topobenchmark/utils/logging_utils.py @@ -35,8 +35,6 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None: log.warning("Logger not found! Skipping hyperparameter logging...") return - hparams["model"] = cfg["model"] - # save number of model parameters hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) hparams["model/params/trainable"] = sum( @@ -46,17 +44,8 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None: p.numel() for p in model.parameters() if not p.requires_grad ) - hparams["dataset"] = cfg["dataset"] - hparams["trainer"] = cfg["trainer"] - - hparams["callbacks"] = cfg.get("callbacks") - hparams["extras"] = cfg.get("extras") - - hparams["task_name"] = cfg.get("task_name") - hparams["tags"] = cfg.get("tags") - hparams["ckpt_path"] = cfg.get("ckpt_path") - hparams["seed"] = cfg.get("seed") - hparams["paths"] = cfg.get("paths") + for key in cfg: + hparams[key] = cfg[key] # send hparams to all loggers for logger in trainer.loggers: