Skip to content

Commit

Permalink
Convert task classes to more generic so they can be re-used
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Nov 14, 2023
1 parent d59d1cc commit c510b4a
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 648 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
- id: black

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.1.0'
rev: 'v0.1.5'
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -34,7 +34,7 @@ repos:
entry: python3 -m mypy --config-file pyproject.toml
language: system
types: [python]
exclude: "use_cases|tests|cyclops/(process|models|tasks|monitor|report/plot)"
exclude: "use_cases|tests|cyclops/(process|models|monitor|report/plot)"

- repo: local
hooks:
Expand Down
4 changes: 2 additions & 2 deletions cyclops/models/wrappers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import random
from collections import defaultdict
from typing import Mapping, Sequence, Union
from typing import Any, Mapping, Sequence, Union

import numpy as np
import torch
Expand Down Expand Up @@ -65,7 +65,7 @@ def to_tensor(
)


def to_numpy(X):
def to_numpy(X) -> Union[np.typing.NDArray[Any], Sequence, Mapping]:
"""Convert the input to a numpy array.
Parameters
Expand Down
2 changes: 0 additions & 2 deletions cyclops/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,3 @@
BinaryTabularClassificationTask,
MultilabelImageClassificationTask,
)
from cyclops.tasks.cxr_classification import CXRClassificationTask
from cyclops.tasks.mortality_prediction import MortalityPredictionTask
20 changes: 9 additions & 11 deletions cyclops/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,16 @@ def __init__(
"""
self.models = prepare_models(models)
self._validate_models()
self.task_features = (
[task_features] if isinstance(task_features, str) else task_features
)
self.task_features = task_features
self.task_target = (
[task_target] if isinstance(task_target, str) else task_target
)
self.device = get_device()
self.trained_models = []
self.pretrained_models = []
self.trained_models: List[str] = []
self.pretrained_models: List[str] = []

@property
def models_count(self):
def models_count(self) -> int:
"""Number of models in the task.
Returns
Expand Down Expand Up @@ -92,7 +90,7 @@ def data_type(self) -> str:
"""
raise NotImplementedError

def list_models(self):
def list_models(self) -> List[str]:
"""List the names of the models in the task.
Returns
Expand All @@ -104,14 +102,14 @@ def list_models(self):
return list(self.models.keys())

@abstractmethod
def _validate_models(self):
def _validate_models(self) -> None:
"""Validate the models for the task data type."""
raise NotImplementedError

def add_model(
self,
model: Union[str, WrappedModel, Dict[str, WrappedModel]],
):
) -> None:
"""Add a model to the task.
Parameters
Expand Down Expand Up @@ -167,7 +165,7 @@ def save_model(
self,
filepath: Union[str, Dict[str, str]],
model_names: Optional[Union[str, List[str]]] = None,
**kwargs,
**kwargs: Any,
) -> None:
"""Save the model to a specified filepath.
Expand Down Expand Up @@ -219,7 +217,7 @@ def load_model(
self,
filepath: str,
model_name: Optional[str] = None,
**kwargs,
**kwargs: Any,
) -> WrappedModel:
"""Load a pretrained model.
Expand Down
Loading

0 comments on commit c510b4a

Please sign in to comment.