Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3bc16d3
[WIP] Classification Dataset refactor
AlbertvanHouten Aug 22, 2025
5ea1908
Fix image conversions
AlbertvanHouten Aug 27, 2025
87c5a31
add batching logic
AlbertvanHouten Aug 28, 2025
3df9e59
Merge branch 'develop' of https://github.com/open-edge-platform/train…
AlbertvanHouten Aug 28, 2025
1d8f332
Switch logic back to reuse OTXDataBatch which is initialized with dat…
AlbertvanHouten Aug 28, 2025
966ae56
Fix datumaro LabelGroup in _dispatch_label_info method
AlbertvanHouten Aug 28, 2025
53afc10
Several fixes to use new Datumaro Dataset class for classification tr…
AlbertvanHouten Sep 1, 2025
e9338d2
Merge branch 'develop' of https://github.com/open-edge-platform/train…
AlbertvanHouten Sep 1, 2025
3ae6745
add base OTX Sample class
AlbertvanHouten Sep 2, 2025
742f399
Add test cases
AlbertvanHouten Sep 2, 2025
856d54e
Address ruff/mypy issues
AlbertvanHouten Sep 2, 2025
785f5e6
Pin datumaro develop branch
AlbertvanHouten Sep 2, 2025
df3c776
Move get_idx_list_per_classes to dataset class and address other PR c…
AlbertvanHouten Sep 3, 2025
121bba7
Raise error if dataset isn't of the correct type.
AlbertvanHouten Sep 3, 2025
6fb07d8
Apply suggestion from @leoll2
AlbertvanHouten Sep 3, 2025
069474f
Update Copyright headers and add function doc
AlbertvanHouten Sep 3, 2025
aaddb60
Merge remote-tracking branch 'origin/albert/new-dataset-support' into…
AlbertvanHouten Sep 3, 2025
9ab6b7a
Pin numpy version
AlbertvanHouten Sep 3, 2025
543963e
Fix failing test
AlbertvanHouten Sep 3, 2025
86e4acd
Create LabelInfo class in the dataset and remove dispatching label in…
AlbertvanHouten Sep 4, 2025
d089db9
Update function doc
AlbertvanHouten Sep 4, 2025
b22fcc3
Fix failing tests after LabelInfo change
AlbertvanHouten Sep 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"datumaro==1.10.0",
"datumaro @ git+https://github.com/open-edge-platform/datumaro.git@develop",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: before the feature branch (feature/datumaro) can be merged to develop, this needs to be updated to point to an official datumaro release.

"omegaconf==2.3.0",
"rich==14.0.0",
"jsonargparse==4.35.0",
Expand All @@ -51,6 +51,7 @@ dependencies = [
"onnxconverter-common==1.14.0",
"nncf==2.17.0",
"anomalib[core]==1.1.3",
"numpy<2.0.0",
]

[project.optional-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion lib/src/otx/backend/native/callbacks/gpu_mem_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_and_log_device_stats(
batch_size (int): batch size.
"""
device = trainer.strategy.root_device
if device.type in ["cpu", "xpu"]:
if device.type in ["cpu", "xpu", "mps"]:
return

device_stats = trainer.accelerator.get_device_stats(device)
Expand Down
23 changes: 22 additions & 1 deletion lib/src/otx/data/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from __future__ import annotations

from abc import abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union

import cv2
import numpy as np
from datumaro.components.annotation import AnnotationType
from datumaro.components.annotation import AnnotationType, LabelCategories
from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend
from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel
from torch.utils.data import Dataset
Expand Down Expand Up @@ -196,3 +197,23 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None:
def collate_fn(self) -> Callable:
    """Return the collate function used to batch items in the data loader.

    Returns:
        Callable: ``OTXDataItem.collate_fn``.
    """
    return OTXDataItem.collate_fn

def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]:
    """Group sample indices by the class labels of their annotations.

    Args:
        use_string_label (bool): If True, keys are the label names looked up in the
            subset's label categories. If False, keys are the raw ``ann.label`` values.

    Returns:
        dict[int | str, list[int]]: Mapping from class key to the positions (in iteration
            order of ``self.dm_subset``) of the items that carry that class. Every index
            appears at most once per class, in ascending order.
    """
    # The category table does not change while iterating, so resolve it once up front
    # instead of once per annotation.
    label_items = None
    if use_string_label:
        label_items = self.dm_subset.categories().get(AnnotationType.label, LabelCategories()).items

    stats: dict[int | str, list[int]] = defaultdict(list)
    for item_idx, item in enumerate(self.dm_subset):
        for ann in item.annotations:
            key = ann.label if label_items is None else label_items[ann.label].name
            bucket = stats[key]
            # Indices are visited in increasing order, so a label repeated on the same item
            # could only duplicate the last entry of its bucket; skipping it here removes
            # the need for a separate de-duplication pass.
            if not bucket or bucket[-1] != item_idx:
                bucket.append(item_idx)
    return stats
168 changes: 168 additions & 0 deletions lib/src/otx/data/dataset/base_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (C) 2023-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Base class for OTXDataset using new Datumaro experimental Dataset."""

from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Callable, Iterable, List, Union

import numpy as np
import torch
from torch.utils.data import Dataset as TorchDataset

from otx import LabelInfo, NullLabelInfo

if TYPE_CHECKING:
from datumaro.experimental import Dataset

from otx.data.entity.sample import OTXSample
from otx.data.entity.torch.torch import OTXDataBatch
from otx.data.transform_libs.torchvision import Compose
from otx.types.image import ImageColorChannel

Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]]


def _to_float32_tensor(img: np.ndarray | torch.Tensor) -> torch.Tensor:
    """Return ``img`` as a float32 tensor, converting numpy arrays via ``torch.from_numpy``."""
    if not isinstance(img, torch.Tensor):
        return torch.from_numpy(img).float()
    # Tensors that are already float32 are passed through untouched
    return img if img.dtype == torch.float32 else img.float()


def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch:
    """Collate OTXSample items into an OTXDataBatch.

    Images are converted to float32 tensors and stacked into a single tensor when they
    all share one shape; otherwise they are passed on as a list of tensors. Each optional
    field is gathered into a per-item list, or set to None when no item provides it.

    Args:
        items: List of OTXSample items to batch. An empty list yields a batch with
            ``batch_size=0``, an empty image list and all optional fields set to None.

    Returns:
        Batched OTXSample items with stacked tensors
    """

    def _per_item(attr: str) -> list | None:
        # One entry per item (None entries kept); None overall if every entry is None
        values = [getattr(item, attr) for item in items]
        return values if any(value is not None for value in values) else None

    image_tensors = [_to_float32_tensor(item.image) for item in items]

    # Stacking is only possible when every image has the same shape
    if image_tensors and all(t.shape == image_tensors[0].shape for t in image_tensors):
        images = torch.stack(image_tensors)
    else:
        images = image_tensors

    # Presence of labels is decided from the first item only; the ``items`` guard keeps an
    # empty batch from raising IndexError.
    labels = [item.label for item in items] if items and items[0].label is not None else None

    # NOTE(review): unlike the other fields, polygons of items without polygons are dropped
    # rather than kept as None, so this list can be shorter than the batch - confirm intended.
    polygons = [item.polygons for item in items if item.polygons is not None] or None

    return OTXDataBatch(
        batch_size=len(items),
        images=images,
        labels=labels,
        masks=_per_item("masks"),
        bboxes=_per_item("bboxes"),
        keypoints=_per_item("keypoints"),
        polygons=polygons,
        imgs_info=_per_item("img_info"),
    )


class OTXDataset(TorchDataset):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the old and new OTXDataset have the same API, I suggest adding some black-box functional tests that instantiate both classes, apply the same sequence of operations to each, and then assert that the respective results are identical.

"""Base OTXDataset using new Datumaro experimental Dataset.

Defines basic logic for OTX datasets.

Args:
transforms: Transforms to apply on images
image_color_channel: Color channel of images
stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
sample_type: Type of sample to use for this dataset
"""

def __init__(
    self,
    dm_subset: Dataset,
    transforms: Transforms,
    max_refetch: int = 1000,
    image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
    stack_images: bool = True,
    to_tv_image: bool = True,
    data_format: str = "",
    sample_type: type[OTXSample] = OTXSample,
) -> None:
    """Store the dataset configuration and derive ``label_info`` from the subset schema.

    Args:
        dm_subset: Datumaro dataset that backs this dataset.
        transforms: Transforms to apply on images.
        max_refetch: Maximum number of fetch attempts ``__getitem__`` makes before failing.
        image_color_channel: Color channel of images.
        stack_images: Whether or not to stack images in the collate function.
        to_tv_image: Whether the sample image is converted with ``as_tv_image`` before
            ``Compose`` transforms run.
        data_format: Stored as-is on the instance.
        sample_type: Type of sample to use for this dataset.
    """
    self.transforms = transforms
    self.image_color_channel = image_color_channel
    self.stack_images = stack_images
    self.to_tv_image = to_tv_image
    self.sample_type = sample_type
    self.max_refetch = max_refetch
    self.data_format = data_format

    # Label metadata is only available when the subset schema declares a "label" attribute
    schema_declares_label = (
        hasattr(dm_subset, "schema")
        and hasattr(dm_subset.schema, "attributes")
        and "label" in dm_subset.schema.attributes
    )
    if not schema_declares_label:
        self.label_info = NullLabelInfo()
    else:
        names = dm_subset.schema.attributes["label"].categories.labels
        self.label_info = LabelInfo(
            label_names=names,
            label_groups=[names],
            label_ids=list(map(str, range(len(names)))),
        )
    self.dm_subset = dm_subset

def __len__(self) -> int:
    """Return the number of items in the wrapped ``dm_subset``."""
    return len(self.dm_subset)

def _sample_another_idx(self) -> int:
    """Draw a random index in ``[0, len(self))`` with ``np.random.randint``."""
    return np.random.randint(0, len(self))

def _apply_transforms(self, entity: OTXSample) -> OTXSample | None:
    """Run the configured transforms on a sample.

    A ``Compose`` pipeline is called directly (after converting the image with
    ``as_tv_image`` when ``to_tv_image`` is set), any other iterable is delegated to
    ``_iterable_transforms``, and a plain callable is simply invoked.

    Args:
        entity: Sample to transform.

    Returns:
        The transformed sample, or None when the transforms produce None or are of an
        unsupported kind.
    """
    pipeline = self.transforms

    if isinstance(pipeline, Compose):
        if self.to_tv_image:
            entity.as_tv_image()
        return pipeline(entity)

    if isinstance(pipeline, Iterable):
        return self._iterable_transforms(entity)

    return pipeline(entity) if callable(pipeline) else None

def _iterable_transforms(self, item: OTXSample) -> OTXSample | None:
    """Apply a list of transforms to a sample one after another.

    Args:
        item: Sample to transform.

    Returns:
        The output of the last transform, or None as soon as any transform returns None.

    Raises:
        TypeError: If ``self.transforms`` is not a list.
    """
    if not isinstance(self.transforms, list):
        # Report the offending transforms type rather than the sample being processed
        msg = f"Iterable transforms must be given as a list, got {type(self.transforms).__name__}"
        raise TypeError(msg)

    results = item
    for transform in self.transforms:
        results = transform(results)
        # MMCV transform can produce None. Please see
        # https://github.com/open-mmlab/mmengine/blob/26f22ed283ae4ac3a24b756809e5961efe6f9da8/mmengine/dataset/base_dataset.py#L59-L66
        if results is None:
            return None

    return results

def __getitem__(self, index: int) -> OTXSample:
    """Fetch a sample, retrying with random indices while the fetch yields None.

    Args:
        index: Index of the sample requested first.

    Returns:
        The first non-None result of ``_get_item_impl``.

    Raises:
        RuntimeError: If no sample was obtained within ``max_refetch`` attempts.
    """
    for _attempt in range(self.max_refetch):
        if (sample := self._get_item_impl(index)) is not None:
            return sample
        # The requested item could not be produced; try a different, random one
        index = self._sample_another_idx()

    msg = f"Reach the maximum refetch number ({self.max_refetch})"
    raise RuntimeError(msg)

def _get_item_impl(self, index: int) -> OTXSample | None:
    """Load the item at ``index`` from ``dm_subset`` and pass it through the transforms.

    Returns:
        The transformed sample, or None if the transforms yield None.
    """
    return self._apply_transforms(self.dm_subset[index])

@property
def collate_fn(self) -> Callable:
    """Collection function to collect samples into a batch in data loader.

    Returns:
        Callable: The module-level ``_default_collate_fn``.
    """
    return _default_collate_fn

@abc.abstractmethod
def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]:
    """Get a dictionary with class labels as keys and lists of corresponding sample indices as values.

    Args:
        use_string_label (bool): Presumably selects string label names instead of integer
            label indices as keys - confirm against the concrete subclasses.

    NOTE(review): the class statement is ``class OTXDataset(TorchDataset)`` with no ABC
    metaclass visible here; verify that ``abc.abstractmethod`` is actually enforced.
    """
40 changes: 40 additions & 0 deletions lib/src/otx/data/dataset/classification_new.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for OTXClassificationDatasets using new Datumaro experimental Dataset."""

from __future__ import annotations

from otx.data.dataset.base_new import OTXDataset
from otx.data.entity.sample import ClassificationSample


class OTXMulticlassClsDataset(OTXDataset):
"""OTXDataset class for multi-class classification task using new Datumaro experimental Dataset."""

def __init__(self, **kwargs) -> None:
    """Initialize OTXMulticlassClsDataset.

    The ``sample_type`` is always forced to ``ClassificationSample``, overriding any
    value supplied by the caller.

    Args:
        **kwargs: Keyword arguments to pass to OTXDataset
    """
    super().__init__(**{**kwargs, "sample_type": ClassificationSample})

def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]:
    """Get a dictionary mapping class labels (string or int) to lists of sample indices.

    This lives on the task-specific dataset because how a label is read depends on the
    sample type; here every sample exposes a single ``label`` whose ``item()`` gives the
    class index.

    Args:
        use_string_label (bool): If True, use string class labels as keys.
            If False, use integer indices as keys.

    Returns:
        dict[int | str, list[int]]: Class key mapped to the ascending list of indices of
            the samples belonging to that class. Keys appear in order of first occurrence.
    """
    # Resolve the name table once; it is only needed for string keys
    label_names = self.label_info.label_names if use_string_label else None

    idx_list_per_classes: dict[int | str, list[int]] = {}
    for idx in range(len(self)):
        label_id = self.dm_subset[idx].label.item()
        key = label_id if label_names is None else label_names[label_id]
        idx_list_per_classes.setdefault(key, []).append(idx)
    return idx_list_per_classes
118 changes: 118 additions & 0 deletions lib/src/otx/data/entity/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Sample classes for OTX data entities."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
import polars as pl
import torch
from datumaro import Mask
from datumaro.components.media import Image
from datumaro.experimental.dataset import Sample
from datumaro.experimental.fields import image_field, label_field
from torchvision import tv_tensors

from otx.data.entity.base import ImageInfo

if TYPE_CHECKING:
from datumaro import DatasetItem, Polygon
from torchvision.tv_tensors import BoundingBoxes, Mask


class OTXSample(Sample):
"""Base class for OTX data samples."""

image: np.ndarray | torch.Tensor | tv_tensors.Image | Any

def as_tv_image(self) -> None:
    """Convert ``self.image`` in place to a torchvision ``tv_tensors.Image``.

    An image that already is a ``tv_tensors.Image`` is left untouched.

    Raises:
        ValueError: If the image is neither a numpy array nor a torch tensor.
    """
    current = self.image
    if isinstance(current, tv_tensors.Image):
        return
    if not isinstance(current, (np.ndarray, torch.Tensor)):
        msg = "OTXSample must have an image"
        raise ValueError(msg)
    self.image = tv_tensors.Image(current)

@property
def masks(self) -> Mask | None:
    """Get masks for the sample.

    The base sample has no masks, so this always returns None.

    NOTE(review): ``Mask`` is imported from both ``datumaro`` and
    ``torchvision.tv_tensors`` in this module - confirm which one the annotation intends.
    """
    return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the semantics of these properties masks, bboxes, ... ? Why do they return None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A large portion of the OTX codebase depends on the OTXDataItem. For example, all the transforms in https://github.com/open-edge-platform/training_extensions/blob/8418221b4e065d14b6f2c223f1ee297b80fd64c8/lib/src/otx/data/transform_libs/torchvision.py.

For now, we will mimic that functionality with the OTXSample to not break any logic. In the future, OTXDataItem should be replaced with OTXSample but this will require a larger refactor, including updating all the transforms to not take OTXDataItems as input.


@property
def bboxes(self) -> BoundingBoxes | None:
    """Get bounding boxes for the sample.

    The base sample has no bounding boxes, so this always returns None.
    """
    return None

@property
def keypoints(self) -> torch.Tensor | None:
    """Get keypoints for the sample.

    The base sample has no keypoints, so this always returns None.
    """
    return None

@property
def polygons(self) -> list[Polygon] | None:
    """Get polygons for the sample.

    The base sample has no polygons, so this always returns None.
    """
    return None

@property
def label(self) -> torch.Tensor | None:
    """Optional label property that returns None by default.

    Subclasses presumably replace this with a real label field - confirm per sample type.
    """
    return None

@property
def img_info(self) -> ImageInfo | None:
    """Get image information for the sample.

    A previously stored or computed ``ImageInfo`` is returned as-is. Otherwise it is
    built from the first two dimensions of a 3-dimensional ``image`` and cached.

    Returns:
        The image information, or None when there is no cached value and the sample has
        no image with a 3-dimensional ``shape``. Returning None instead of raising is kept
        deliberately to match how existing consumers of ``img_info`` are written.
    """
    cached = getattr(self, "_img_info", None)
    if cached is not None:
        return cached

    image = getattr(self, "image", None)
    if image is None or not hasattr(image, "shape") or len(image.shape) != 3:
        return None

    # NOTE(review): assumes the first two dimensions are the spatial ones - confirm for
    # channel-first tensors.
    leading_dims = image.shape[:2]
    self._img_info = ImageInfo(
        img_idx=0,
        img_shape=leading_dims,
        ori_shape=leading_dims,
    )
    return self._img_info

@img_info.setter
def img_info(self, value: ImageInfo) -> None:
    """Store ``value`` as the cached image information."""
    self._img_info = value


class ClassificationSample(OTXSample):
    """Sample for classification: an image together with its class label."""

    # Image declared through datumaro's ``image_field`` with a polars UInt8 dtype
    image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8)
    # Label declared through datumaro's ``label_field`` with a polars Int32 dtype
    label: torch.Tensor = label_field(pl.Int32())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the label is always a torch type, why can image be a numpy type? The double type np.ndarray | tv_tensors.Image may be problematic for the code that consumes this class, because it needs to handle both cases (image as np.ndarray and image as tv_tensors.Image).

Copy link
Contributor Author

@AlbertvanHouten AlbertvanHouten Sep 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is again existing OTX logic from OTXDataItem. The OTX codebase is already taking the different types into account.


@classmethod
def from_dm_item(cls, item: DatasetItem) -> ClassificationSample:
    """Build a ClassificationSample out of a Datumaro DatasetItem.

    The label is taken from the item's first annotation; when the item has no
    annotations (or that label is None) a ``-1`` label tensor is used instead.

    Args:
        item: Datumaro DatasetItem containing image and label

    Returns:
        ClassificationSample: Instance with image, label and image info set
    """
    pixels = item.media_as(Image).data

    raw_label = item.annotations[0].label if item.annotations else None
    if raw_label is None:
        label_tensor = torch.tensor(-1, dtype=torch.long)
    else:
        label_tensor = torch.as_tensor(raw_label, dtype=torch.long)

    # Image info is derived from the first two dimensions of the loaded image data
    leading_dims = pixels.shape[:2]
    info = ImageInfo(
        img_idx=0,
        img_shape=leading_dims,
        ori_shape=leading_dims,
    )

    sample = cls(image=pixels, label=label_tensor)
    sample.img_info = info
    return sample
Loading
Loading