Skip to content

Commit

Permalink
Merge pull request #130 from geometric-intelligence/logger
Browse files Browse the repository at this point in the history
Allow all hp logging
  • Loading branch information
levtelyatnikov authored Dec 26, 2024
2 parents 58083bd + 491af81 commit 5151fe0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ cython_debug/

# Data
/datasets/
/data/
/notebooks/*.csv
notebooks/tmp
*.pickle
Expand Down
15 changes: 2 additions & 13 deletions topobenchmark/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None:
log.warning("Logger not found! Skipping hyperparameter logging...")
return

hparams["model"] = cfg["model"]

# save number of model parameters
hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
hparams["model/params/trainable"] = sum(
Expand All @@ -46,17 +44,8 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None:
p.numel() for p in model.parameters() if not p.requires_grad
)

hparams["dataset"] = cfg["dataset"]
hparams["trainer"] = cfg["trainer"]

hparams["callbacks"] = cfg.get("callbacks")
hparams["extras"] = cfg.get("extras")

hparams["task_name"] = cfg.get("task_name")
hparams["tags"] = cfg.get("tags")
hparams["ckpt_path"] = cfg.get("ckpt_path")
hparams["seed"] = cfg.get("seed")
hparams["paths"] = cfg.get("paths")
for key in cfg:
hparams[key] = cfg[key]

# send hparams to all loggers
for logger in trainer.loggers:
Expand Down

0 comments on commit 5151fe0

Please sign in to comment.