diff --git a/blazingai/params.py b/blazingai/params.py index f0796f8..edc193e 100644 --- a/blazingai/params.py +++ b/blazingai/params.py @@ -5,8 +5,12 @@ def load_cfg(fpath: Path, cfg_name: str) -> DictConfig: - cfg = load_hparams_from_yaml(config_yaml=fpath, use_omegaconf=True) - cfg = cfg.get(cfg_name) - if hasattr(cfg, "lr"): + all_configs = load_hparams_from_yaml(config_yaml=fpath, use_omegaconf=True) + cfg = all_configs.get(cfg_name) + if cfg is None: + raise ValueError(f"Could not find config {cfg_name} in {fpath}") + + if hasattr(cfg, "lr"): # convert exponential notation to float cfg.lr = float(cfg.lr) # type: ignore + return cfg