Skip to content

Commit

Permalink
Add separate classes for classification tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Nov 8, 2023
1 parent e5d0387 commit 0fc97cb
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 181 deletions.
19 changes: 19 additions & 0 deletions cyclops/models/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ class ModelWrapper(ABC):
"""

@abstractmethod
def model_name(self) -> str:
"""Name of the model.
Returns
-------
str
Name of the model.
"""
raise NotImplementedError

@abstractmethod
def partial_fit(
self,
Expand Down Expand Up @@ -48,6 +60,7 @@ def partial_fit(
The fitted model.
"""
raise NotImplementedError

@abstractmethod
def fit(
Expand Down Expand Up @@ -89,6 +102,7 @@ def fit(
The fitted model.
"""
raise NotImplementedError

@abstractmethod
def find_best(
Expand Down Expand Up @@ -140,6 +154,7 @@ def find_best(
self
"""
raise NotImplementedError

@abstractmethod
def predict(
Expand Down Expand Up @@ -187,6 +202,7 @@ def predict(
The output of the model.
"""
raise NotImplementedError

@abstractmethod
def predict_proba(
Expand Down Expand Up @@ -214,6 +230,7 @@ def predict_proba(
The probabilities of the output of the model.
"""
raise NotImplementedError

@abstractmethod
def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
Expand All @@ -234,6 +251,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
None
"""
raise NotImplementedError

@abstractmethod
def load_model(self, filepath: str, **kwargs):
Expand All @@ -252,6 +270,7 @@ def load_model(self, filepath: str, **kwargs):
self
"""
raise NotImplementedError

def get_params(self) -> Dict[str, Any]:
"""Get parameters for the wrapper.
Expand Down
14 changes: 13 additions & 1 deletion cyclops/models/wrappers/pt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ def __init__(
self.train_loss_ = LossMeter("train")
self.val_loss_ = LossMeter("val")

@property
def model_name(self) -> str:
"""The model name.
Returns
-------
str
The model name.
"""
return self.model_.__class__.__name__

def collect_params_for(self, prefix: str) -> Dict:
"""Collect parameters for a given prefix.
Expand Down Expand Up @@ -748,7 +760,7 @@ def _train_loop(
val_loader = self._get_dataloader(val_dataset, test=True)

save_dir = self.save_dir if self.save_dir else os.getcwd()
model_dir = join(save_dir, "saved_models", self.model_.__class__.__name__)
model_dir = join(save_dir, "saved_models", self.model_name)

best_loss = np.inf
for epoch in range(1, self.max_epochs + 1):
Expand Down
54 changes: 39 additions & 15 deletions cyclops/models/wrappers/sk_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ def __init__(
self.batch_size = params.pop("batch_size", 64)
self.initialize_model(**params)

@property
def model_name(self) -> str:
"""Model name.
Returns
-------
str
Model name.
"""
return self.model_.__class__.__name__

def initialize_model(self, **kwargs):
"""Initialize model.
Expand Down Expand Up @@ -381,8 +393,7 @@ def partial_fit(
splits_mapping = {"train": "train"}
if not hasattr(self.model_, "partial_fit"):
raise AttributeError(
f"Model {self.model_.__class__.__name__}"
"does not have a `partial_fit` method.",
f"Model {self.model_name}" "does not have a `partial_fit` method.",
)
# Train data is a Hugging Face Dataset Dictionary.
if isinstance(X, DatasetDict):
Expand Down Expand Up @@ -669,8 +680,7 @@ def predict_proba(
splits_mapping = {"test": "test"}
if not hasattr(self.model_, "predict_proba"):
raise AttributeError(
f"Model {self.model_.__class__.__name__}"
"does not have a `predict_proba` method.",
f"Model {self.model_name}" "does not have a `predict_proba` method.",
)
# Data is a Hugging Face Dataset Dictionary.
if isinstance(X, DatasetDict):
Expand All @@ -697,9 +707,7 @@ def predict_proba(
if model_name:
pred_column = f"{prediction_column_prefix}.{model_name}"
else:
pred_column = (
f"{prediction_column_prefix}.{self.model_.__class__.__name__}"
)
pred_column = f"{prediction_column_prefix}.{self.model_name}"

format_kwargs = {}
is_callable_transform = callable(transforms)
Expand Down Expand Up @@ -822,9 +830,7 @@ def predict(
if model_name:
pred_column = f"{prediction_column_prefix}.{model_name}"
else:
pred_column = (
f"{prediction_column_prefix}.{self.model_.__class__.__name__}"
)
pred_column = f"{prediction_column_prefix}.{self.model_name}"

format_kwargs = {}
is_callable_transform = callable(transforms)
Expand Down Expand Up @@ -873,19 +879,35 @@ def get_predictions(examples: Dict[str, Union[List, np.ndarray]]) -> dict:
output = self.model_.transform(X)
return output

def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
"""Save model to file."""
def save_model(self, filepath: str, overwrite: bool = True, **kwargs) -> str:
"""Save model to file.
Parameters
----------
filepath : str
The path to save the model.
overwrite : bool, optional
Whether to overwrite the existing model, by default True
**kwargs : dict, optional
Additional keyword arguments to be passed to the save function.
Returns
-------
str
The path to the saved model.
"""
# filepath could be a directory
if len(os.path.basename(filepath).split(".")) == 1:
process_dir_save_path(filepath)

if os.path.isdir(filepath):
filepath = join(filepath, self.model_.__class__.__name__, "model.pkl")
filepath = join(filepath, self.model_name, "model.pkl")

# filepath could be a file
dir_path = os.path.dirname(filepath)
if dir_path == "":
dir_path = f"./{self.model_.__class__.__name__}"
dir_path = f"./{self.model_name}"
filepath = join(dir_path, filepath)
process_dir_save_path(dir_path)

Expand All @@ -895,10 +917,12 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
"The file %s already exists and will not be overwritten.",
filepath,
)
return
return None

save_pickle(self.model_, filepath, log=kwargs.get("log", True))

return filepath

def load_model(self, filepath: str, **kwargs):
"""Load a saved model.
Expand Down
5 changes: 5 additions & 0 deletions cyclops/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Tasks package."""

from cyclops.tasks.classification import (
BinaryTabularClassificationTask,
MultilabelImageClassificationTask,
)
from cyclops.tasks.cxr_classification import CXRClassificationTask
from cyclops.tasks.mortality_prediction import MortalityPredictionTask
92 changes: 91 additions & 1 deletion cyclops/tasks/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Base task class."""

import logging
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from cyclops.models.utils import get_device
from cyclops.models.wrappers import WrappedModel
Expand Down Expand Up @@ -161,3 +162,92 @@ def get_model(self, model_name: Optional[str] = None) -> Tuple[str, WrappedModel
model = self.models[model_name]

return model_name, model

def save_model(
self,
filepath: Union[str, Dict[str, str]],
model_names: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> None:
"""Save the model to a specified filepath.
Parameters
----------
filepath : Union[str, Dict[str, str]]
The destination path(s) where the model(s) will be saved.
Can be a dictionary of model names and their corresponding paths
or a single parent dirctory.
model_name : Optional[Union[str, List[str]]], optional
Model name, required if more than one model exists, by default None.
**kwargs : Any
Additional keyword arguments to be passed to the model's save method.
Returns
-------
None
"""
if isinstance(model_names, str):
model_names = [model_names]
elif not model_names:
model_names = self.trained_models

if isinstance(filepath, Dict):
assert len(filepath) == len(model_names), (
"Number of filepaths must match number of models"
"if a dictionary is given."
)
if isinstance(filepath, str) and len(model_names) > 1:
assert len(os.path.basename(filepath).split(".")) == 1, (
"Filepath must be a directory if a single string is given"
"for multiple models."
)

for model_name in model_names:
if model_name not in self.trained_models:
LOGGER.warning(
"It seems you have not trained the %s model.",
model_name,
)
model_name, model = self.get_model(model_name) # noqa: PLW2901
model_path = (
filepath[model_name] if isinstance(filepath, Dict) else filepath
)
model.save_model(model_path, **kwargs)

def load_model(
self,
filepath: str,
model_name: Optional[str] = None,
**kwargs,
) -> WrappedModel:
"""Load a pretrained model.
Parameters
----------
filepath : str
Path to the save model.
model_name : Optional[str], optional
Model name, required if more than one model exists, by default Nonee
Returns
-------
WrappedModel
The loaded model.
"""
model_name, model = self.get_model(model_name)
model.load_model(filepath, **kwargs)
self.pretrained_models.append(model_name)
return model

def list_models_params(self) -> Dict[str, Any]:
"""List the parameters of the models in the task.
Returns
-------
Dict[str, Any]
Dictionary of model parameters.
"""
return {n: m.get_params() for n, m in self.models.items()}
Loading

0 comments on commit 0fc97cb

Please sign in to comment.