Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions doleus/annotations/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,26 @@ class Labels(Annotation):
"""

def __init__(
self, datapoint_number: int, labels: Tensor, scores: Optional[Tensor] = None
self, datapoint_number: int, labels: Optional[Tensor], scores: Optional[Tensor] = None
):
"""Initialize a Labels instance.

Parameters
----------
datapoint_number : int
Index for the corresponding data point.
labels : Tensor
A 1D integer tensor representing the label(s).
labels : Optional[Tensor]
A 1D integer tensor. For single-label tasks, this typically contains one class index
(e.g., `tensor([2])`). For multilabel tasks, this is typically a multi-hot encoded
tensor (e.g., `tensor([1, 0, 1])`). Can be `None` if only `scores` are provided.
scores : Optional[Tensor], optional
A float tensor containing predicted probability scores (optional).
A 1D float tensor. For single-label tasks (e.g. multiclass), this usually contains
probabilities for each class (e.g., `tensor([0.1, 0.2, 0.7])`). For multilabel
tasks, this contains independent probabilities for each label (e.g.,
`tensor([0.8, 0.1, 0.9])`). Optional.
"""
if labels is None and scores is None:
raise ValueError("Either 'labels' or 'scores' must be provided but both are None.")
super().__init__(datapoint_number)
self.labels = labels
self.scores = scores
Expand All @@ -38,7 +45,9 @@ def to_dict(self) -> dict:
dict
Dictionary with keys 'labels' and optionally 'scores'.
"""
output = {"labels": self.labels}
output = {}
if self.labels is not None:
output["labels"] = self.labels
if self.scores is not None:
output["scores"] = self.scores
return output
28 changes: 24 additions & 4 deletions doleus/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
from tqdm import tqdm

from doleus.annotations import BoundingBoxes, Labels
from doleus.storage import GroundTruthStore, MetadataStore, PredictionStore
from doleus.storage import (
MetadataStore,
BasePredictionStore,
BaseGroundTruthStore,
)
from doleus.utils import (
ATTRIBUTE_FUNCTIONS,
OPERATOR_DICT,
TaskType,
get_current_timestamp,
to_numpy_image,
create_filename,
Expand Down Expand Up @@ -70,8 +75,11 @@ def __init__(
self.metadata = metadata if metadata is not None else {}
self.metadata["_timestamp"] = get_current_timestamp()

self.groundtruth_store = GroundTruthStore(task_type=task_type, dataset=dataset)
self.prediction_store = PredictionStore(task_type=task_type)
# Ground truth and prediction stores are initialized to None in the base class.
# Specific instantiations will be handled by subclasses (DoleusClassification, DoleusDetection).
self.groundtruth_store: Optional[BaseGroundTruthStore] = None
self.prediction_store: Optional[BasePredictionStore] = None

self.metadata_store = MetadataStore(
num_datapoints=len(dataset), metadata=per_datapoint_metadata
)
Expand All @@ -86,7 +94,7 @@ def __getattr__(self, attr):
return getattr(self.dataset, attr)

@abstractmethod
def _create_new_instance(self, dataset, indices):
def _create_new_instance(self, dataset, indices, slice_name):
pass

def add_model_predictions(
Expand All @@ -106,9 +114,21 @@ def add_model_predictions(
model_id : str
Name of the model that generated these predictions
"""
kwargs = {}
if self.task_type == TaskType.CLASSIFICATION.value:
kwargs['task'] = self.task
# Ensure predictions is a Tensor for classification
if not isinstance(predictions, torch.Tensor):
raise TypeError("For classification tasks, predictions must be a torch.Tensor.")
elif self.task_type == TaskType.DETECTION.value:
# Ensure predictions is a List[Dict] for detection
if not isinstance(predictions, list) or not all(isinstance(p, dict) for p in predictions):
raise TypeError("For detection tasks, predictions must be a list of dictionaries.")

self.prediction_store.add_predictions(
predictions=predictions,
model_id=model_id,
**kwargs,
)

# -------------------------------------------------------------------------
Expand Down
11 changes: 9 additions & 2 deletions doleus/datasets/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from doleus.datasets.base import Doleus
from doleus.utils import TaskType
from doleus.storage import ClassificationGroundTruthStore, ClassificationPredictionStore


class DoleusClassification(Doleus):
Expand Down Expand Up @@ -49,9 +50,14 @@ def __init__(
metadata=metadata,
per_datapoint_metadata=per_datapoint_metadata,
)

# Instantiate the classification-specific stores
self.groundtruth_store = ClassificationGroundTruthStore(
dataset=self.dataset, task=self.task, num_classes=self.num_classes
)
self.prediction_store = ClassificationPredictionStore()

def _create_new_instance(self, dataset, indices, name):
# TODO: Do we need to create a new dataset instance?
subset = Subset(dataset, indices)
metadata_subset = self.metadata_store.get_subset(indices)
new_instance = DoleusClassification(
Expand All @@ -64,8 +70,9 @@ def _create_new_instance(self, dataset, indices, name):
per_datapoint_metadata=metadata_subset,
)

# Copy sliced predictions directly to the new instance
for model_id in self.prediction_store.predictions:
sliced_preds = self.prediction_store.get_subset(model_id, indices)
new_instance.prediction_store.add_predictions(sliced_preds, model_id)
new_instance.prediction_store.predictions[model_id] = sliced_preds

return new_instance
11 changes: 8 additions & 3 deletions doleus/datasets/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from doleus.datasets.base import Doleus
from doleus.utils import TaskType
from doleus.storage import DetectionGroundTruthStore, DetectionPredictionStore
from doleus.annotations import Annotations


class DoleusDetection(Doleus):
Expand Down Expand Up @@ -40,6 +42,8 @@ def __init__(
metadata=metadata,
per_datapoint_metadata=per_datapoint_metadata,
)
self.groundtruth_store = DetectionGroundTruthStore(dataset=self.dataset)
self.prediction_store = DetectionPredictionStore()

def _create_new_instance(self, dataset, indices, slice_name):
subset = Subset(dataset, indices)
Expand All @@ -52,8 +56,9 @@ def _create_new_instance(self, dataset, indices, slice_name):
per_datapoint_metadata=new_metadata,
)

for model_id in self.prediction_store.predictions:
sliced_preds = self.prediction_store.get_subset(model_id, indices)
new_instance.prediction_store.add_predictions(sliced_preds, model_id)
if self.prediction_store and self.prediction_store.predictions:
for model_id in self.prediction_store.predictions:
sliced_preds_annotations = self.prediction_store.get_subset(model_id, indices)
new_instance.prediction_store.predictions[model_id] = sliced_preds_annotations

return new_instance
5 changes: 1 addition & 4 deletions doleus/metrics/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,7 @@ def _calculate_classification(
try:
gt_tensor = torch.stack([ann.labels.squeeze() for ann in groundtruths])

pred_list = [
ann.scores if ann.scores is not None else ann.labels.squeeze()
for ann in predictions
]
pred_list = [ann.labels.squeeze() for ann in predictions]
if not pred_list:
raise ValueError("No predictions provided to compute the metric.")
pred_tensor = torch.stack(pred_list)
Expand Down
22 changes: 19 additions & 3 deletions doleus/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
from doleus.storage.ground_truth_store import GroundTruthStore
from doleus.storage.prediction_store import (
BasePredictionStore,
ClassificationPredictionStore,
DetectionPredictionStore,
)
from doleus.storage.ground_truth_store import (
BaseGroundTruthStore,
ClassificationGroundTruthStore,
DetectionGroundTruthStore,
)
from doleus.storage.metadata_store import MetadataStore
from doleus.storage.prediction_store import PredictionStore

__all__ = ["MetadataStore", "PredictionStore", "GroundTruthStore"]
__all__ = [
"BaseGroundTruthStore",
"BasePredictionStore",
"ClassificationGroundTruthStore",
"ClassificationPredictionStore",
"DetectionGroundTruthStore",
"DetectionPredictionStore",
"MetadataStore",
]
89 changes: 0 additions & 89 deletions doleus/storage/ground_truth_store.py

This file was deleted.

9 changes: 9 additions & 0 deletions doleus/storage/ground_truth_store/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from doleus.storage.ground_truth_store.base import BaseGroundTruthStore
from doleus.storage.ground_truth_store.classification import ClassificationGroundTruthStore
from doleus.storage.ground_truth_store.detection import DetectionGroundTruthStore

__all__ = [
"BaseGroundTruthStore",
"ClassificationGroundTruthStore",
"DetectionGroundTruthStore",
]
56 changes: 56 additions & 0 deletions doleus/storage/ground_truth_store/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from abc import ABC, abstractmethod
from typing import Any, Optional

from doleus.annotations import Annotation, Annotations


class BaseGroundTruthStore(ABC):
    """Base storage for ground truth data for a specific dataset instance.

    Subclasses implement ``_process_groundtruths`` to convert the raw
    dataset's targets into the standard ``Annotations`` container; the
    conversion runs once, eagerly, at construction time.
    """

    def __init__(self, dataset: Any):
        """
        Initialize the ground truth store and eagerly process ground truths.

        Parameters
        ----------
        dataset : Any
            The raw PyTorch dataset object.
        """
        self.dataset = dataset
        # Single assignment: `_process_groundtruths` is the sole source of
        # truth for this attribute (the previous `= None` pre-assignment
        # followed by an immediate overwrite was redundant).
        self.groundtruths: "Optional[Annotations]" = self._process_groundtruths()

    @abstractmethod
    def _process_groundtruths(self) -> "Annotations":
        """
        Process raw ground truth data from the dataset into the standard annotation format.
        Actual implementation will depend on the task type (classification, detection).

        Returns
        -------
        Annotations
            Processed ground truths in standard annotation format.
        """
        pass

    def get(self, datapoint_number: int) -> "Optional[Annotation]":
        """
        Get a single ground truth annotation object by datapoint number.

        Parameters
        ----------
        datapoint_number : int
            The ID of the sample in the dataset.

        Returns
        -------
        Optional[Annotation]
            The specific Annotation object (e.g., Labels, BoundingBoxes) for the datapoint,
            or None if not found.
        """
        if self.groundtruths is None:
            return None
        try:
            return self.groundtruths[datapoint_number]
        except (KeyError, IndexError):
            # `Annotations` may be mapping- or sequence-backed; either lookup
            # failure means "not found", which the contract maps to None.
            return None
Loading
Loading