Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,33 @@

# CHANGELOG

## v0.2.4 (2025-12-16)

### Features

- Added AUPRC (Area Under Precision-Recall Curve) metric support
([`a7ff4b8`](https://github.com/Doleus/doleus/commit/a7ff4b8))

### Testing

- Added AUPRC tests for binary, multiclass, and multilabel classification
([`3cdf6b0`](https://github.com/Doleus/doleus/commit/3cdf6b0))

### Bug Fixes

- Replaced numpy with PyTorch in TPR/FPR functions for better performance
([`c8013f4`](https://github.com/Doleus/doleus/commit/c8013f4))

### Build

- Removed torchvision from core dependencies
([`1b09c27`](https://github.com/Doleus/doleus/commit/1b09c27))

### Documentation

- Added CPU-only PyTorch installation instructions
([`0df9619`](https://github.com/Doleus/doleus/commit/0df9619))

## v0.2.2 (2025-10-04)

### Build
Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,23 @@ This approach surfaces hidden failure modes that aggregate metrics miss.

## Quick Start (Classification)

### Standard Installation

```sh
pip install doleus
```

### Lightweight Installation (CPU-only PyTorch)

For a smaller install footprint (~150MB vs ~2-3GB), install CPU-only PyTorch first:

```sh
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install doleus
```

This prevents pip from installing the GPU-enabled PyTorch package, which includes large CUDA libraries that aren't needed for CPU-only workflows.

### Demo

Want to try a complete working example before diving into the details?
Expand Down
58 changes: 57 additions & 1 deletion doleus/metrics/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
metric_parameters : Optional[Dict[str, Any]], optional
Optional parameters to pass directly to the corresponding torchmetrics function, by default None.
For ROC-based metrics (TPR_at_FPR, FPR_at_TPR), must include 'fpr_threshold' or 'tpr_threshold' respectively.
AUPRC requires prediction scores/logits (not labels).
target_class : Optional[Union[int, str]], optional
Optional class ID or name to compute class-specific metrics.
"""
Expand Down Expand Up @@ -73,7 +74,7 @@ def _calculate_classification(
groundtruths : List[Labels]
List of ground truth label annotations.
predictions : List[Labels]
List of predicted label annotations. For ROC-based metrics (TPR_at_FPR, FPR_at_TPR),
List of predicted label annotations. For ROC-based metrics (TPR_at_FPR, FPR_at_TPR) and AUPRC,
predictions must contain scores/logits (not labels).

Returns
Expand Down Expand Up @@ -131,6 +132,60 @@ def _calculate_classification(

return float(metric_value)

# Special handling for AUPRC (Average Precision)
if self.metric == "AUPRC":
# AUPRC requires scores/logits, not labels
pred_list = []
for ann in predictions:
if ann.scores is None:
raise ValueError(
f"{self.metric} requires prediction scores/logits, "
f"but prediction annotation has no scores. "
f"Please provide float predictions (scores/logits) instead of integer labels."
)
pred_list.append(ann.scores.squeeze())

if not pred_list:
raise ValueError("No predictions provided to compute the metric.")
pred_tensor = torch.stack(pred_list)

# Set macro averaging as the default
if "average" not in self.metric_parameters:
self.metric_parameters["average"] = "macro"

# If a specific class is requested, override averaging
if self.target_class_id is not None:
self.metric_parameters["average"] = "none"

metric_fn = METRIC_FUNCTIONS[self.metric]

# Torchmetrics expects num_labels for multilabel tasks and num_classes for other tasks
if self.dataset.task == "multilabel":
metric_value = metric_fn(
pred_tensor,
gt_tensor,
task=self.dataset.task,
num_labels=self.dataset.num_classes,
**self.metric_parameters,
)
else:
metric_value = metric_fn(
pred_tensor,
gt_tensor,
task=self.dataset.task,
num_classes=self.dataset.num_classes,
**self.metric_parameters,
)

if self.target_class_id is not None:
metric_value = metric_value[self.target_class_id]

return (
float(metric_value.item())
if hasattr(metric_value, "item")
else float(metric_value)
)

# Standard classification metrics
pred_list = []
for ann in predictions:
Expand Down Expand Up @@ -267,6 +322,7 @@ def calculate_metric(
Optional parameters to pass directly to the corresponding torchmetrics function, by default None.
For ROC-based metrics (TPR_at_FPR, FPR_at_TPR), must include 'fpr_threshold' or 'tpr_threshold' respectively.
ROC-based metrics only support binary classification and require prediction scores/logits (not labels).
AUPRC supports all classification tasks and requires prediction scores/logits (not labels).
target_class : Optional[Union[int, str]], optional
Optional class ID or name to compute class-specific metrics, by default None.

Expand Down
40 changes: 21 additions & 19 deletions doleus/metrics/metric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from typing import Optional, Union

import numpy as np
import torch
import torchmetrics

Expand Down Expand Up @@ -40,29 +39,30 @@ def _tpr_at_fpr(preds, target, task, num_classes, **kwargs):
preds, target, **kwargs
)

# Convert to numpy for easier indexing
fpr_np = fpr.cpu().numpy()
tpr_np = tpr.cpu().numpy()
# Ensure tensors are on CPU for indexing operations
fpr = fpr.cpu()
tpr = tpr.cpu()

# Handle edge cases
if len(fpr_np) == 0:
if len(fpr) == 0:
raise ValueError("ROC curve is empty")

# If threshold is exactly at a point
exact_matches = np.where(fpr_np == fpr_threshold)[0]
exact_matches = (fpr == fpr_threshold).nonzero(as_tuple=True)[0]
if len(exact_matches) > 0:
return float(tpr_np[exact_matches[0]])
return float(tpr[exact_matches[0]])

# Find the largest FPR <= threshold (next smaller value)
valid_indices = np.where(fpr_np <= fpr_threshold)[0]
valid_mask = fpr <= fpr_threshold
valid_indices = valid_mask.nonzero(as_tuple=True)[0]

if len(valid_indices) == 0:
# All FPR values are greater than threshold, return TPR at first point
return float(tpr_np[0])
return float(tpr[0])

# Get the point with largest FPR <= threshold
idx_max = valid_indices[-1]
return float(tpr_np[idx_max])
return float(tpr[idx_max])


def _fpr_at_tpr(preds, target, task, num_classes, **kwargs):
Expand Down Expand Up @@ -96,29 +96,30 @@ def _fpr_at_tpr(preds, target, task, num_classes, **kwargs):
preds, target, **kwargs
)

# Convert to numpy for easier indexing
fpr_np = fpr.cpu().numpy()
tpr_np = tpr.cpu().numpy()
# Ensure tensors are on CPU for indexing operations
fpr = fpr.cpu()
tpr = tpr.cpu()

# Handle edge cases
if len(tpr_np) == 0:
if len(tpr) == 0:
raise ValueError("ROC curve is empty")

# If threshold is exactly at a point
exact_matches = np.where(tpr_np == tpr_threshold)[0]
exact_matches = (tpr == tpr_threshold).nonzero(as_tuple=True)[0]
if len(exact_matches) > 0:
return float(fpr_np[exact_matches[0]])
return float(fpr[exact_matches[0]])

# Find the smallest TPR >= threshold (next bigger value)
valid_indices = np.where(tpr_np >= tpr_threshold)[0]
valid_mask = tpr >= tpr_threshold
valid_indices = valid_mask.nonzero(as_tuple=True)[0]

if len(valid_indices) == 0:
# All TPR values are less than threshold, return FPR at last point
return float(fpr_np[-1])
return float(fpr[-1])

# Get the point with smallest TPR >= threshold
idx_min = valid_indices[0]
return float(fpr_np[idx_min])
return float(fpr[idx_min])


METRIC_FUNCTIONS = {
Expand All @@ -129,6 +130,7 @@ def _fpr_at_tpr(preds, target, task, num_classes, **kwargs):
"HammingDistance": torchmetrics.functional.hamming_distance,
"TPR_at_FPR": _tpr_at_fpr,
"FPR_at_TPR": _fpr_at_tpr,
"AUPRC": torchmetrics.functional.average_precision,
"mAP": torchmetrics.detection.MeanAveragePrecision,
"mAP_small": torchmetrics.detection.MeanAveragePrecision,
"mAP_medium": torchmetrics.detection.MeanAveragePrecision,
Expand Down
45 changes: 1 addition & 44 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "doleus"
version = "0.2.2"
version = "0.2.4"
description = "Doleus: Test Your Image-based AI Models on Data Slices"
authors = [
{name = "Hendrik Schulze Bröring"},
Expand All @@ -10,8 +10,7 @@ license = { text = "Apache-2.0" }
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0,<1.0.0",
"torch>=2.3.0,<3.0.0",
"tqdm>=4.60.0,<5.0.0",
"torchmetrics>=1.0.0,<2.0.0",
"opencv-python-headless>=4.5.0,<5.0.0.0",
Expand Down
Loading
Loading