
Commit

Attempting to introduce enough typing in module_registry that mypy in pre-commit github actions will pass. (#201)
drewoldag authored Feb 6, 2025
1 parent 78d0af6 commit d4127e1
Showing 3 changed files with 22 additions and 10 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -154,5 +154,6 @@ module = [
     'astropy.table',
     'GPUtil',
     'toml',
+    'umap',
 ]
 ignore_missing_imports = true
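For readers unfamiliar with this fragment: a module list paired with ignore_missing_imports = true is mypy's per-module override mechanism, which suppresses "missing library stubs" errors for the listed third-party packages. A minimal sketch of the kind of table this entry lives in (the table header and placement are assumed from the standard mypy pyproject.toml form, not copied from fibad's file):

# Sketch of a mypy per-module override in pyproject.toml; only the entries
# visible in the hunk above are taken from the actual file.
[[tool.mypy.overrides]]
module = [
    'astropy.table',
    'GPUtil',
    'toml',
    'umap',
]
ignore_missing_imports = true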
8 changes: 7 additions & 1 deletion src/fibad/data_sets/__init__.py
@@ -3,4 +3,10 @@
 from .hsc_data_set import HSCDataSet
 from .inference_dataset import InferenceDataSet
 
-__all__ = ["fibad_data_set", "DATA_SET_REGISTRY", "CifarDataSet", "HSCDataSet"]
+__all__ = [
+    "fibad_data_set",
+    "DATA_SET_REGISTRY",
+    "CifarDataSet",
+    "HSCDataSet",
+    "InferenceDataSet",
+]
23 changes: 14 additions & 9 deletions src/fibad/models/model_registry.py
@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from typing import Any, cast
 
 import torch.nn as nn
 
@@ -27,13 +28,15 @@ def _torch_criterion(self: nn.Module):
     """Load the criterion class using the name defined in the config and
     instantiate it with the arguments defined in the config."""
 
+    config = cast(dict[str, Any], self.config)
+
     # Load the class and get any parameters from the config dictionary
-    criterion_cls = get_or_load_class(self.config["criterion"])
-    criterion_name = self.config["criterion"]["name"]
+    criterion_cls = get_or_load_class(config["criterion"])
+    criterion_name = config["criterion"]["name"]
 
     arguments = {}
-    if criterion_name in self.config:
-        arguments = self.config[criterion_name]
+    if criterion_name in config:
+        arguments = config[criterion_name]
 
     # Print some information about the criterion function and parameters used
     log_string = f"Using criterion: {criterion_name} "
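The cast call added above is the heart of this change: typing.cast performs no conversion at runtime, it only tells mypy to treat self.config as a plain dictionary so the subscript lookups below it type-check. A minimal, self-contained sketch of the idea (the function name and config contents are hypothetical, for illustration only, not fibad code):

from typing import Any, cast


def get_criterion_name(untyped_config: object) -> str:
    # Without the cast, mypy rejects the subscripting because object is not
    # indexable; cast() narrows the static type and is a no-op at runtime.
    config = cast(dict[str, Any], untyped_config)
    return config["criterion"]["name"]


print(get_criterion_name({"criterion": {"name": "CrossEntropyLoss"}}))  # CrossEntropyLoss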
@@ -50,13 +53,15 @@ def _torch_optimizer(self: nn.Module):
     """Load the optimizer class using the name defined in the config and
     instantiate it with the arguments defined in the config."""
 
+    config = cast(dict[str, Any], self.config)
+
     # Load the class and get any parameters from the config dictionary
-    optimizer_cls = get_or_load_class(self.config["optimizer"])
-    optimizer_name = self.config["optimizer"]["name"]
+    optimizer_cls = get_or_load_class(config["optimizer"])
+    optimizer_name = config["optimizer"]["name"]
 
     arguments = {}
-    if optimizer_name in self.config:
-        arguments = self.config[optimizer_name]
+    if optimizer_name in config:
+        arguments = config[optimizer_name]
 
     # Print some information about the optimizer function and parameters used
     log_string = f"Using optimizer: {optimizer_name} "
@@ -127,7 +132,7 @@ def fetch_model_class(runtime_config: dict) -> type[nn.Module]:
     model_cls = None
 
     try:
-        model_cls = get_or_load_class(model_config, MODEL_REGISTRY)
+        model_cls = cast(type[nn.Module], get_or_load_class(model_config, MODEL_REGISTRY))
     except ValueError as exc:
         raise ValueError("Error fetching model class") from exc
 
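The second cast serves the same purpose for the return value: a generic class loader can only promise an arbitrary type object, so the caller asserts the more specific type[nn.Module] that fetch_model_class is annotated to return. A hedged sketch under assumed names (load_class_by_path stands in for get_or_load_class, whose definition is not shown in this diff):

import importlib
from typing import cast

import torch.nn as nn


def load_class_by_path(path: str) -> type:
    # Generic helper: resolve a dotted path such as "torch.nn.Linear" to a class.
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)


def fetch_model_class_sketch(path: str) -> type[nn.Module]:
    # cast() verifies nothing at runtime; it only narrows the static type so
    # the annotated return type satisfies mypy.
    return cast(type[nn.Module], load_class_by_path(path))


linear_cls = fetch_model_class_sketch("torch.nn.Linear")
print(linear_cls(4, 2))  # Linear(in_features=4, out_features=2, bias=True)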
