Skip to content

Commit

Permalink
Convert task classes to more generic so they can be re-used (#504)
Browse files Browse the repository at this point in the history
* Convert task classes to more generic so they can be re-used

* Fix imports in tutorial notebooks, and model report desc

* Small fix to metrics var
  • Loading branch information
amrit110 authored Nov 14, 2023
1 parent d59d1cc commit 3bba828
Show file tree
Hide file tree
Showing 13 changed files with 519 additions and 672 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.1.0'
rev: 'v0.1.5'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -34,7 +34,7 @@ repos:
entry: python3 -m mypy --config-file pyproject.toml
language: system
types: [python]
exclude: "use_cases|tests|cyclops/(process|models|tasks|monitor|report/plot)"
exclude: "use_cases|tests|cyclops/(process|models|monitor|report/plot)"

- repo: local
hooks:
Expand Down
4 changes: 2 additions & 2 deletions cyclops/models/wrappers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import random
from collections import defaultdict
from typing import Mapping, Sequence, Union
from typing import Any, Mapping, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -65,7 +65,7 @@ def to_tensor(
)


def to_numpy(X):
def to_numpy(X) -> Union[np.typing.NDArray[Any], Sequence, Mapping]:
"""Convert the input to a numpy array.
Parameters
Expand Down
2 changes: 0 additions & 2 deletions cyclops/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,3 @@
BinaryTabularClassificationTask,
MultilabelImageClassificationTask,
)
from cyclops.tasks.cxr_classification import CXRClassificationTask
from cyclops.tasks.mortality_prediction import MortalityPredictionTask
20 changes: 9 additions & 11 deletions cyclops/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,16 @@ def __init__(
"""
self.models = prepare_models(models)
self._validate_models()
self.task_features = (
[task_features] if isinstance(task_features, str) else task_features
)
self.task_features = task_features
self.task_target = (
[task_target] if isinstance(task_target, str) else task_target
)
self.device = get_device()
self.trained_models = []
self.pretrained_models = []
self.trained_models: List[str] = []
self.pretrained_models: List[str] = []

@property
def models_count(self):
def models_count(self) -> int:
"""Number of models in the task.
Returns
Expand Down Expand Up @@ -92,7 +90,7 @@ def data_type(self) -> str:
"""
raise NotImplementedError

def list_models(self):
def list_models(self) -> List[str]:
"""List the names of the models in the task.
Returns
Expand All @@ -104,14 +102,14 @@ def list_models(self):
return list(self.models.keys())

@abstractmethod
def _validate_models(self):
def _validate_models(self) -> None:
"""Validate the models for the task data type."""
raise NotImplementedError

def add_model(
self,
model: Union[str, WrappedModel, Dict[str, WrappedModel]],
):
) -> None:
"""Add a model to the task.
Parameters
Expand Down Expand Up @@ -167,7 +165,7 @@ def save_model(
self,
filepath: Union[str, Dict[str, str]],
model_names: Optional[Union[str, List[str]]] = None,
**kwargs,
**kwargs: Any,
) -> None:
"""Save the model to a specified filepath.
Expand Down Expand Up @@ -219,7 +217,7 @@ def load_model(
self,
filepath: str,
model_name: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> WrappedModel:
"""Load a pretrained model.
Expand Down
Loading

0 comments on commit 3bba828

Please sign in to comment.