From 3bc16d322d90347378a615d3ca73956e77b71cad Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Fri, 22 Aug 2025 16:16:55 +0200 Subject: [PATCH 01/19] [WIP] Classification Dataset refactor --- src/otx/data/dataset/base_new.py | 96 ++++++++ src/otx/data/dataset/classification_new.py | 264 +++++++++++++++++++++ src/otx/data/entity/sample.py | 61 +++++ src/otx/data/factory.py | 43 +++- tests/unit/data/test_factory.py | 2 +- 5 files changed, 457 insertions(+), 9 deletions(-) create mode 100644 src/otx/data/dataset/base_new.py create mode 100644 src/otx/data/dataset/classification_new.py create mode 100644 src/otx/data/entity/sample.py diff --git a/src/otx/data/dataset/base_new.py b/src/otx/data/dataset/base_new.py new file mode 100644 index 00000000000..dbba7084e28 --- /dev/null +++ b/src/otx/data/dataset/base_new.py @@ -0,0 +1,96 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for OTXDataset using new Datumaro experimental Dataset.""" + +from __future__ import annotations + +from typing import Callable, List, Type, Union, Iterable + +import numpy as np +import torch +from datumaro.experimental import Dataset +from datumaro.experimental.type_registry import convert_image_type +from torch.utils.data import Dataset as TorchDataset + +from otx.data.entity.sample import ClassificationSample +from otx.data.transform_libs.torchvision import Compose +from otx.types.image import ImageColorChannel +from otx.types.label import NullLabelInfo + +Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] + + +class OTXDataset(TorchDataset): + """Base OTXDataset using new Datumaro experimental Dataset. + + Defines basic logic for OTX datasets. + + Args: + transforms: Transforms to apply on images + image_color_channel: Color channel of images + stack_images: Whether or not to stack images in collate function in OTXBatchData entity. 
+ sample_type: Type of sample to use for this dataset + """ + + def __init__( + self, + dm_subset: Dataset, + transforms: Transforms, + max_refetch: int = 1000, + image_color_channel: ImageColorChannel = ImageColorChannel.RGB, + stack_images: bool = True, + to_tv_image: bool = True, + data_format: str = "", + sample_type: Type[ClassificationSample] = ClassificationSample, + ) -> None: + self.transforms = transforms + self.image_color_channel = image_color_channel + self.stack_images = stack_images + self.to_tv_image = to_tv_image + self.sample_type = sample_type + + self.data_format = data_format + + # TODO: Properly reinit label_info + self.label_info = dm_subset.categories() + + self.dataset = dm_subset + + def __len__(self) -> int: + return len(self.dataset) + + def _sample_another_idx(self) -> int: + return np.random.randint(0, len(self)) + + def _apply_transforms(self, entity: ClassificationSample) -> ClassificationSample | None: + if isinstance(self.transforms, Compose): + if self.to_tv_image: + entity = convert_image_type(entity, torch.Tensor) + return self.transforms(entity) + if isinstance(self.transforms, Iterable): + return self._iterable_transforms(entity) + if callable(self.transforms): + return self.transforms(entity) + + def _iterable_transforms(self, item: ClassificationSample) -> ClassificationSample | None: + if not isinstance(self.transforms, list): + raise TypeError(item) + + results = item + for transform in self.transforms: + results = transform(results) + # MMCV transform can produce None. 
Please see + # https://github.com/open-mmlab/mmengine/blob/26f22ed283ae4ac3a24b756809e5961efe6f9da8/mmengine/dataset/base_dataset.py#L59-L66 + if results is None: + return None + + return results + + def __getitem__(self, index: int) -> ClassificationSample: + return self.dataset[index] + + @property + def collate_fn(self) -> Callable: + """Collection function to collect samples into a batch in data loader.""" + pass \ No newline at end of file diff --git a/src/otx/data/dataset/classification_new.py b/src/otx/data/dataset/classification_new.py new file mode 100644 index 00000000000..49526110fb5 --- /dev/null +++ b/src/otx/data/dataset/classification_new.py @@ -0,0 +1,264 @@ +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for OTXClassificationDatasets using new Datumaro experimental Dataset.""" + +from __future__ import annotations + +from .base_new import OTXDataset +from ..entity.sample import ClassificationSample + + +class OTXMulticlassClsDataset(OTXDataset): + """OTXDataset class for multi-class classification task using new Datumaro experimental Dataset.""" + + def __init__(self, **kwargs) -> None: + """Initialize OTXMulticlassClsDataset. + + Args: + **kwargs: Keyword arguments to pass to OTXDataset + """ + kwargs["sample_type"] = ClassificationSample + super().__init__(**kwargs) + + def _get_item_impl(self, index: int) -> ClassificationSample | None: + return self.dataset[index] + + +# class OTXMultilabelClsDataset(OTXDataset): +# """OTXDataset class for multi-label classification task using new Datumaro experimental Dataset.""" +# +# def __init__(self, **kwargs) -> None: +# """Initialize OTXMultilabelClsDataset. 
+# +# Args: +# **kwargs: Keyword arguments to pass to OTXDataset +# """ +# kwargs["sample_type"] = MultiLabelClassificationSample +# super().__init__(**kwargs) +# self.num_classes = len(self.dm_subset.categories()[AnnotationType.label]) +# +# def _get_item_impl(self, index: int) -> MultiLabelClassificationSample | None: +# item = self.dm_subset[index] +# img = item.media_as(Image) +# ignored_labels: list[int] = [] # This should be assigned form item +# img_data, img_shape, _ = self._get_img_data_and_shape(img) +# +# label_ids = set() +# for ann in item.annotations: +# # multilabel information stored in 'multi_label_ids' attribute when the source format is arrow +# if "multi_label_ids" in ann.attributes: +# for lbl_idx in ann.attributes["multi_label_ids"]: +# label_ids.add(lbl_idx) +# +# if isinstance(ann, Label): +# label_ids.add(ann.label) +# else: +# # If the annotation is not Label, it should be converted to Label. +# # For Chained Task: Detection (Bbox) -> Classification (Label) +# label = Label(label=ann.label) +# label_ids.add(label.label) +# labels = np.array(list(label_ids), dtype=np.int64) +# +# image_info = ImageInfo( +# width=img_data.shape[1], +# height=img_data.shape[0], +# ) +# # Create multilabel classification sample +# sample = MultiLabelClassificationSample( +# image=img_data, +# labels=self._convert_to_onehot(labels, ignored_labels), +# image_info=image_info, +# ) +# +# return self._apply_transforms(sample) +# +# def _convert_to_onehot(self, labels: np.ndarray, ignored_labels: list[int]) -> np.ndarray: +# """Convert label to one-hot vector format.""" +# # Convert to torch tensor for one_hot +# labels_tensor = torch.from_numpy(labels).long() +# # Torch's one_hot() expects the input to be of type long +# onehot = functional.one_hot(labels_tensor, self.num_classes).sum(0).clamp_max_(1).numpy() +# if ignored_labels: +# for ignore_label in ignored_labels: +# onehot[ignore_label] = -1 +# return onehot + + +# class OTXHlabelClsDataset(OTXDataset): +# 
"""OTXDataset class for H-label classification task using new Datumaro experimental Dataset.""" +# +# def __init__(self, **kwargs) -> None: +# """Initialize OTXHlabelClsDataset. +# +# Args: +# **kwargs: Keyword arguments to pass to OTXDataset +# """ +# # Set the sample type to HierarchicalClassificationSample +# kwargs["sample_type"] = HierarchicalClassificationSample +# super().__init__(**kwargs) +# self.dm_categories = self.dm_subset.categories()[AnnotationType.label] +# +# # Hlabel classification used HLabelInfo to insert the HLabelData. +# if self.data_format == "arrow": +# # arrow format stores label IDs as names, have to deal with that here +# self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories) +# else: +# self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories) +# +# self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names)) +# self.id_to_name_mapping[""] = "" +# +# if self.label_info.num_multiclass_heads == 0: +# msg = "The number of multiclass heads should be larger than 0." +# raise ValueError(msg) +# +# if self.data_format != "arrow": +# for dm_item in self.dm_subset: +# self._add_ancestors(dm_item.annotations) +# +# def _add_ancestors(self, label_anns: list[Label]) -> None: +# """Add ancestors recursively if some label miss the ancestor information. +# +# If the label tree likes below, +# object - vehicle -- car +# |- bus +# |- truck +# And annotation = ['car'], it should be ['car', 'vehicle', 'object'], to include the ancestor. +# +# This function add the ancestors to the annotation if missing. 
+# """ +# +# def _label_idx_to_name(idx: int) -> str: +# return self.dm_categories[idx].name +# +# def _label_name_to_idx(name: str) -> int: +# indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name] +# return indices[0] +# +# def _get_label_group_idx(label_name: str) -> int: +# if isinstance(self.label_info, HLabelInfo): +# if self.data_format == "arrow": +# return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0] +# return self.label_info.class_to_group_idx[label_name][0] +# msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}" +# raise ValueError(msg) +# +# def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: +# _, dm_label_category = self.dm_categories.find(label_name) +# parent_name = dm_label_category.parent if dm_label_category else "" +# +# if parent_name != "": +# ancestors.append(parent_name) +# _find_ancestor_recursively(parent_name, ancestors) +# return ancestors +# +# def _get_all_label_names_in_anns(anns: list[Label]) -> list[str]: +# return [_label_idx_to_name(ann.label) for ann in anns] +# +# all_label_names = _get_all_label_names_in_anns(label_anns) +# ancestor_dm_labels = [] +# for ann in label_anns: +# label_idx = ann.label +# label_name = _label_idx_to_name(label_idx) +# ancestors = _find_ancestor_recursively(label_name, []) +# +# for i, ancestor in enumerate(ancestors): +# if ancestor not in all_label_names: +# ancestor_dm_labels.append( +# Label( +# label=_label_name_to_idx(ancestor), +# id=len(label_anns) + i, +# group=_get_label_group_idx(ancestor), +# ), +# ) +# label_anns.extend(ancestor_dm_labels) +# +# def _get_item_impl(self, index: int) -> HierarchicalClassificationSample | None: +# item = self.dm_subset[index] +# img = item.media_as(Image) +# ignored_labels: list[int] = [] # This should be assigned form item +# img_data, img_shape, _ = self._get_img_data_and_shape(img) +# +# label_ids = set() +# for ann in item.annotations: +# 
# in h-cls scenario multilabel information stored in 'multi_label_ids' attribute +# if "multi_label_ids" in ann.attributes: +# for lbl_idx in ann.attributes["multi_label_ids"]: +# label_ids.add(lbl_idx) +# +# if isinstance(ann, Label): +# label_ids.add(ann.label) +# else: +# # If the annotation is not Label, it should be converted to Label. +# # For Chained Task: Detection (Bbox) -> Classification (Label) +# label = Label(label=ann.label) +# label_ids.add(label.label) +# +# hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels) +# +# # Create image info sample +# image_info = ImageInfo( +# width=img_data.shape[1], +# height=img_data.shape[0], +# ) +# +# # Create hierarchical classification sample +# sample = HierarchicalClassificationSample( +# image=img_data, +# labels=np.array(hlabel_labels, dtype=np.int64), +# image_info=image_info, +# ) +# +# return self._apply_transforms(sample) +# +# def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_labels: list[int]) -> list[int]: +# """Convert format of the label to the h-label. +# +# It converts the label format to h-label format. +# Total length of result is sum of number of hierarchy and number of multilabel classes. +# +# i.e. +# Let's assume that we used the same dataset with example of the definition of HLabelData +# and the original labels are ["Rigid", "Triangle", "Lion"]. +# +# Then, h-label format will be [0, 1, 1, 0]. +# The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), +# others represent the multilabel labels. 
+# +# [Multiclass Heads] +# 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head +# 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head +# +# [Multilabel Head] +# 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] +# """ +# if not isinstance(self.label_info, HLabelInfo): +# msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." +# raise TypeError(msg) +# +# num_multiclass_heads = self.label_info.num_multiclass_heads +# num_multilabel_classes = self.label_info.num_multilabel_classes +# +# class_indices = [0] * (num_multiclass_heads + num_multilabel_classes) +# for i in range(num_multiclass_heads): +# class_indices[i] = -1 +# +# for ann in label_anns: +# if self.data_format == "arrow": +# # skips unknown labels for instance, the empty one +# if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping: +# continue +# ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name] +# else: +# ann_name = self.dm_categories.items[ann.label].name +# group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] +# +# if group_idx < num_multiclass_heads: +# class_indices[group_idx] = in_group_idx +# elif ann.label not in ignored_labels: +# class_indices[num_multiclass_heads + in_group_idx] = 1 +# else: +# class_indices[num_multiclass_heads + in_group_idx] = -1 +# +# return class_indices \ No newline at end of file diff --git a/src/otx/data/entity/sample.py b/src/otx/data/entity/sample.py new file mode 100644 index 00000000000..67b93aa7cff --- /dev/null +++ b/src/otx/data/entity/sample.py @@ -0,0 +1,61 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Sample classes for OTX data entities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import polars as pl +import torch +from datumaro import Mask +from datumaro.components.media import Image +from 
datumaro.experimental.dataset import Sample +from datumaro.experimental.fields import ImageInfo, label_field, image_field + +if TYPE_CHECKING: + from datumaro import Polygon, DatasetItem + from torchvision.tv_tensors import BoundingBoxes, Mask + + +class ClassificationSample(Sample): + """OTXDataItemSample is a base class for OTX data items.""" + label: int = label_field(pl.Int32()) + image: torch.Tensor | np.ndarray = image_field(dtype=pl.UInt8) + + @classmethod + def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": + """ + Create a ClassificationSample from a Datumaro DatasetItem. + + Args: + item: Datumaro DatasetItem containing image and label + + Returns: + ClassificationSample: Instance with image and label set + """ + image = item.media_as(Image).data + label = item.annotations[0].label if item.annotations else None + return cls(image=image, label=label) + + @property + def masks(self) -> Mask | None: + return None + + @property + def bboxes(self) -> BoundingBoxes | None: + return None + + @property + def keypoints(self) -> torch.Tensor | None: + return None + + @property + def polygons(self) -> list[Polygon] | None: + return None + + @property + def img_info(self) -> ImageInfo | None: + return None diff --git a/src/otx/data/factory.py b/src/otx/data/factory.py index 7f601c4e69d..1dcbe8254ac 100644 --- a/src/otx/data/factory.py +++ b/src/otx/data/factory.py @@ -7,15 +7,22 @@ from typing import TYPE_CHECKING +from datumaro.experimental.categories import LabelCategories, LabelCategory, LabelGroup + from otx.types.image import ImageColorChannel from otx.types.task import OTXTaskType from otx.types.transformer_libs import TransformLibType from .dataset.base import OTXDataset, Transforms +from .dataset.base_new import OTXDataset as OTXDatasetNew +from datumaro.experimental import Dataset as DatasetNew -if TYPE_CHECKING: - from datumaro import Dataset as DmDataset +from otx import LabelInfo, NullLabelInfo +from datumaro.components.dataset import 
Dataset as DmDataset +from datumaro.components.annotation import AnnotationType + +if TYPE_CHECKING: from otx.config.data import SubsetConfig @@ -41,15 +48,15 @@ class OTXDatasetFactory: @classmethod def create( - cls: type[OTXDatasetFactory], + cls, task: OTXTaskType, - dm_subset: DmDataset, + dm_subset: DmDataset | DatasetNew, cfg_subset: SubsetConfig, data_format: str, image_color_channel: ImageColorChannel = ImageColorChannel.RGB, include_polygons: bool = False, ignore_index: int = 255, - ) -> OTXDataset: + ) -> OTXDataset | OTXDatasetNew: """Create OTXDataset.""" transforms = TransformLibFactory.generate(cfg_subset) common_kwargs = { @@ -71,12 +78,17 @@ def create( return OTXAnomalyDataset(task_type=task, **common_kwargs) if task == OTXTaskType.MULTI_CLASS_CLS: - from .dataset.classification import OTXMulticlassClsDataset - + from .dataset.classification_new import OTXMulticlassClsDataset, ClassificationSample + if isinstance(dm_subset, DmDataset): + categories = cls._get_label_categories(dm_subset, data_format) + dataset = DatasetNew(ClassificationSample, categories={"label": categories}) + for item in dm_subset: + dataset.append(ClassificationSample.from_dm_item(item)) + common_kwargs["dm_subset"] = dataset return OTXMulticlassClsDataset(**common_kwargs) if task == OTXTaskType.MULTI_LABEL_CLS: - from .dataset.classification import OTXMultilabelClsDataset + from otx.data.dataset.classification import OTXMultilabelClsDataset return OTXMultilabelClsDataset(**common_kwargs) @@ -106,3 +118,18 @@ def create( return OTXKeypointDetectionDataset(**common_kwargs) raise NotImplementedError(task) + + @staticmethod + def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelCategories: + # TODO: Support hierarchical labels + + if dm_subset.categories() and data_format == "arrow": + label_info = LabelInfo.from_dm_label_groups_arrow(dm_subset.categories()[AnnotationType.label]) + elif dm_subset.categories(): + label_info = 
LabelInfo.from_dm_label_groups(dm_subset.categories()[AnnotationType.label]) + else: + label_info = NullLabelInfo() + + label_categories = [LabelCategory(name=label_name) for label_name in label_info.label_names] + label_group = LabelGroup(name="default", labels=[name for name in label_info.label_names]) + return LabelCategories(items=label_categories, label_groups=[label_group]) \ No newline at end of file diff --git a/tests/unit/data/test_factory.py b/tests/unit/data/test_factory.py index 3c24b1c774b..f28c97d52c7 100644 --- a/tests/unit/data/test_factory.py +++ b/tests/unit/data/test_factory.py @@ -10,9 +10,9 @@ from otx.data.dataset.classification import ( HLabelInfo, OTXHlabelClsDataset, - OTXMulticlassClsDataset, OTXMultilabelClsDataset, ) +from otx.data.dataset.classification_new import OTXMulticlassClsDataset from otx.data.dataset.detection import OTXDetectionDataset from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset from otx.data.dataset.segmentation import OTXSegmentationDataset From 5ea1908ad475cfd318d722458509d3a62abf921a Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 27 Aug 2025 15:40:39 +0200 Subject: [PATCH 02/19] Fix image conversions --- src/otx/data/dataset/base_new.py | 24 +++++++++++++++------- src/otx/data/dataset/classification_new.py | 4 ---- src/otx/data/entity/sample.py | 14 ++++++++++--- 3 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/otx/data/dataset/base_new.py b/src/otx/data/dataset/base_new.py index dbba7084e28..f02a57c62b2 100644 --- a/src/otx/data/dataset/base_new.py +++ b/src/otx/data/dataset/base_new.py @@ -8,15 +8,12 @@ from typing import Callable, List, Type, Union, Iterable import numpy as np -import torch from datumaro.experimental import Dataset -from datumaro.experimental.type_registry import convert_image_type from torch.utils.data import Dataset as TorchDataset from otx.data.entity.sample import ClassificationSample from otx.data.transform_libs.torchvision import Compose 
from otx.types.image import ImageColorChannel -from otx.types.label import NullLabelInfo Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] @@ -49,11 +46,11 @@ def __init__( self.stack_images = stack_images self.to_tv_image = to_tv_image self.sample_type = sample_type - + self.max_refetch = max_refetch self.data_format = data_format # TODO: Properly reinit label_info - self.label_info = dm_subset.categories() + self.label_info = dm_subset.categories self.dataset = dm_subset @@ -66,7 +63,7 @@ def _sample_another_idx(self) -> int: def _apply_transforms(self, entity: ClassificationSample) -> ClassificationSample | None: if isinstance(self.transforms, Compose): if self.to_tv_image: - entity = convert_image_type(entity, torch.Tensor) + entity.as_tv_image() return self.transforms(entity) if isinstance(self.transforms, Iterable): return self._iterable_transforms(entity) @@ -88,7 +85,20 @@ def _iterable_transforms(self, item: ClassificationSample) -> ClassificationSamp return results def __getitem__(self, index: int) -> ClassificationSample: - return self.dataset[index] + for _ in range(self.max_refetch): + results = self._get_item_impl(index) + + if results is not None: + return results + + index = self._sample_another_idx() + + msg = f"Reach the maximum refetch number ({self.max_refetch})" + raise RuntimeError(msg) + + def _get_item_impl(self, index: int) -> ClassificationSample | None: + return self._apply_transforms(self.dataset[index]) + @property def collate_fn(self) -> Callable: diff --git a/src/otx/data/dataset/classification_new.py b/src/otx/data/dataset/classification_new.py index 49526110fb5..0a80c6488b8 100644 --- a/src/otx/data/dataset/classification_new.py +++ b/src/otx/data/dataset/classification_new.py @@ -21,10 +21,6 @@ def __init__(self, **kwargs) -> None: kwargs["sample_type"] = ClassificationSample super().__init__(**kwargs) - def _get_item_impl(self, index: int) -> ClassificationSample | None: - return 
self.dataset[index] - - # class OTXMultilabelClsDataset(OTXDataset): # """OTXDataset class for multi-label classification task using new Datumaro experimental Dataset.""" # diff --git a/src/otx/data/entity/sample.py b/src/otx/data/entity/sample.py index 67b93aa7cff..76cf35c51d4 100644 --- a/src/otx/data/entity/sample.py +++ b/src/otx/data/entity/sample.py @@ -14,6 +14,7 @@ from datumaro.components.media import Image from datumaro.experimental.dataset import Sample from datumaro.experimental.fields import ImageInfo, label_field, image_field +from torchvision import tv_tensors if TYPE_CHECKING: from datumaro import Polygon, DatasetItem @@ -22,8 +23,8 @@ class ClassificationSample(Sample): """OTXDataItemSample is a base class for OTX data items.""" - label: int = label_field(pl.Int32()) - image: torch.Tensor | np.ndarray = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32()) + image: torch.Tensor | np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) @classmethod def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": @@ -38,7 +39,14 @@ def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": """ image = item.media_as(Image).data label = item.annotations[0].label if item.annotations else None - return cls(image=image, label=label) + return cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) + + def as_tv_image(self): + """Convert image to torchvision tv_tensors Image format.""" + if isinstance(self.image, np.ndarray): + self.image = tv_tensors.Image(self.image) + elif isinstance(self.image, torch.Tensor): + self.image = tv_tensors.Image(self.image.numpy()) @property def masks(self) -> Mask | None: From 87c5a316d7d194e2532ed0574720e1af8cc0f1d3 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 28 Aug 2025 09:56:56 +0200 Subject: [PATCH 03/19] add batching logic --- pyproject.toml | 2 +- src/otx/data/dataset/base_new.py | 107 ++++++++++++++++-- src/otx/data/entity/sample.py | 2 +- 
src/otx/data/factory.py | 26 ++--- tests/unit/data/conftest.py | 2 + .../unit/data/dataset/test_classification.py | 7 +- 6 files changed, 118 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0d78932d307..cbbc3ecfb2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "datumaro==1.10.0", + "datumaro @ git+https://github.com/open-edge-platform/datumaro.git@albert/datumaro-otx-integration", "omegaconf==2.3.0", "rich==14.0.0", "jsonargparse==4.35.0", diff --git a/src/otx/data/dataset/base_new.py b/src/otx/data/dataset/base_new.py index f02a57c62b2..1745bcf4de1 100644 --- a/src/otx/data/dataset/base_new.py +++ b/src/otx/data/dataset/base_new.py @@ -5,19 +5,84 @@ from __future__ import annotations -from typing import Callable, List, Type, Union, Iterable +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Sequence, Union import numpy as np +import torch from datumaro.experimental import Dataset from torch.utils.data import Dataset as TorchDataset +from torchvision import tv_tensors from otx.data.entity.sample import ClassificationSample +from otx.data.entity.torch.validations import ValidateBatchMixin from otx.data.transform_libs.torchvision import Compose from otx.types.image import ImageColorChannel +if TYPE_CHECKING: + from datumaro import Polygon + from torchvision.tv_tensors import BoundingBoxes, Mask + + from otx.data.entity.base import ImageInfo + Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] +@dataclass +class ClassificationBatch(ValidateBatchMixin): + """Classification batch implementation based on ClassificationSample.""" + + batch_size: int + images: torch.Tensor | list[torch.Tensor] + labels: list[torch.Tensor] | None = None + masks: list[Mask] | None = None + bboxes: list[BoundingBoxes] | None = None + keypoints: 
list[torch.Tensor] | None = None + polygons: list[list[Polygon]] | None = None + imgs_info: Sequence[ImageInfo | None] | None = None + + def pin_memory(self) -> "ClassificationBatch": + """Pin memory for member tensor variables.""" + kwargs = {} + + def maybe_pin(x: Any) -> Any: # noqa: ANN401 + if isinstance(x, torch.Tensor): + return x.pin_memory() + return x + + def maybe_wrap_tv(x: Any) -> Any: # noqa: ANN401 + if isinstance(x, tv_tensors.TVTensor): + return tv_tensors.wrap(x.pin_memory(), like=x) + return maybe_pin(x) + + # Handle images separately because of tv_tensors wrapping + if self.images is not None: + if isinstance(self.images, list): + kwargs["images"] = [maybe_wrap_tv(img) for img in self.images] + else: + kwargs["images"] = maybe_wrap_tv(self.images) + + # Generic handler for all other fields + for field in ["labels", "bboxes", "keypoints", "masks"]: + value = getattr(self, field) + if value is not None: + kwargs[field] = [maybe_wrap_tv(v) if v is not None else None for v in value] + + return self.wrap(**kwargs) + + def wrap(self, **kwargs) -> "ClassificationBatch": + """Wrap this dataclass with the given keyword arguments. + + Args: + **kwargs: Keyword arguments to be overwritten on top of this dataclass + Returns: + Updated dataclass + """ + updated_kwargs = asdict(self) + updated_kwargs.update(**kwargs) + return self.__class__(**updated_kwargs) + + class OTXDataset(TorchDataset): """Base OTXDataset using new Datumaro experimental Dataset. 
@@ -39,7 +104,7 @@ def __init__( stack_images: bool = True, to_tv_image: bool = True, data_format: str = "", - sample_type: Type[ClassificationSample] = ClassificationSample, + sample_type: type[ClassificationSample] = ClassificationSample, ) -> None: self.transforms = transforms self.image_color_channel = image_color_channel @@ -50,7 +115,7 @@ def __init__( self.data_format = data_format # TODO: Properly reinit label_info - self.label_info = dm_subset.categories + self.label_info = dm_subset.label_group self.dataset = dm_subset @@ -68,7 +133,7 @@ def _apply_transforms(self, entity: ClassificationSample) -> ClassificationSampl if isinstance(self.transforms, Iterable): return self._iterable_transforms(entity) if callable(self.transforms): - return self.transforms(entity) + return self.transforms(entity) def _iterable_transforms(self, item: ClassificationSample) -> ClassificationSample | None: if not isinstance(self.transforms, list): @@ -97,10 +162,38 @@ def __getitem__(self, index: int) -> ClassificationSample: raise RuntimeError(msg) def _get_item_impl(self, index: int) -> ClassificationSample | None: - return self._apply_transforms(self.dataset[index]) - + dm_item = self.dataset[index] + sample = self.sample_type.from_dm_item(dm_item) + return self._apply_transforms(sample) @property def collate_fn(self) -> Callable: """Collection function to collect samples into a batch in data loader.""" - pass \ No newline at end of file + def _collate_fn(items: list[ClassificationSample]) -> ClassificationBatch: + """Collate ClassificationSample items into a ClassificationBatch. 
+ + Args: + items: List of ClassificationSample items to batch + Returns: + Batched ClassificationSample items with stacked tensors + """ + # Check if all images have the same size for stacking + if all(item.image.shape == items[0].image.shape for item in items): + images = torch.stack([item.image for item in items]) + else: + # Keep as list if shapes differ (e.g., for OV inference) + images = [item.image for item in items] + + return ClassificationBatch( + batch_size=len(items), + images=images, + labels=[item.label for item in items] if items[0].label is not None else None, + masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, + bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, + keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, + polygons=[item.polygons for item in items] if any(item.polygons is not None for item in items) else None, + imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, + ) + + return _collate_fn + \ No newline at end of file diff --git a/src/otx/data/entity/sample.py b/src/otx/data/entity/sample.py index 76cf35c51d4..3b5ab3352bf 100644 --- a/src/otx/data/entity/sample.py +++ b/src/otx/data/entity/sample.py @@ -41,7 +41,7 @@ def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": label = item.annotations[0].label if item.annotations else None return cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) - def as_tv_image(self): + def as_tv_image(self) -> None: """Convert image to torchvision tv_tensors Image format.""" if isinstance(self.image, np.ndarray): self.image = tv_tensors.Image(self.image) diff --git a/src/otx/data/factory.py b/src/otx/data/factory.py index 1dcbe8254ac..b4d9956ef9e 100644 --- a/src/otx/data/factory.py +++ b/src/otx/data/factory.py @@ -7,20 +7,18 @@ from typing import TYPE_CHECKING -from 
datumaro.experimental.categories import LabelCategories, LabelCategory, LabelGroup +from datumaro.components.annotation import AnnotationType +from datumaro.components.dataset import Dataset as DmDataset +from datumaro.experimental import Dataset as DatasetNew +from datumaro.experimental.categories import LabelGroup +from otx import LabelInfo, NullLabelInfo from otx.types.image import ImageColorChannel from otx.types.task import OTXTaskType from otx.types.transformer_libs import TransformLibType from .dataset.base import OTXDataset, Transforms from .dataset.base_new import OTXDataset as OTXDatasetNew -from datumaro.experimental import Dataset as DatasetNew - -from otx import LabelInfo, NullLabelInfo - -from datumaro.components.dataset import Dataset as DmDataset -from datumaro.components.annotation import AnnotationType if TYPE_CHECKING: from otx.config.data import SubsetConfig @@ -78,10 +76,11 @@ def create( return OTXAnomalyDataset(task_type=task, **common_kwargs) if task == OTXTaskType.MULTI_CLASS_CLS: - from .dataset.classification_new import OTXMulticlassClsDataset, ClassificationSample + from .dataset.classification_new import ClassificationSample, OTXMulticlassClsDataset + if isinstance(dm_subset, DmDataset): categories = cls._get_label_categories(dm_subset, data_format) - dataset = DatasetNew(ClassificationSample, categories={"label": categories}) + dataset = DatasetNew(ClassificationSample, label_group=categories) for item in dm_subset: dataset.append(ClassificationSample.from_dm_item(item)) common_kwargs["dm_subset"] = dataset @@ -120,16 +119,11 @@ def create( raise NotImplementedError(task) @staticmethod - def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelCategories: - # TODO: Support hierarchical labels - + def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelGroup: if dm_subset.categories() and data_format == "arrow": label_info = 
LabelInfo.from_dm_label_groups_arrow(dm_subset.categories()[AnnotationType.label]) elif dm_subset.categories(): label_info = LabelInfo.from_dm_label_groups(dm_subset.categories()[AnnotationType.label]) else: label_info = NullLabelInfo() - - label_categories = [LabelCategory(name=label_name) for label_name in label_info.label_names] - label_group = LabelGroup(name="default", labels=[name for name in label_info.label_names]) - return LabelCategories(items=label_categories, label_groups=[label_group]) \ No newline at end of file + return LabelGroup(labels=label_info.label_names) diff --git a/tests/unit/data/conftest.py b/tests/unit/data/conftest.py index 2932bc055d1..af023346462 100644 --- a/tests/unit/data/conftest.py +++ b/tests/unit/data/conftest.py @@ -115,6 +115,7 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic AnnotationType.mask, AnnotationType.polygon, ] + mock_dm_subset.label_group = mocker.MagicMock() return mock_dm_subset @@ -125,6 +126,7 @@ def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: Dataset mock_dm_subset.__len__.return_value = 1 mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) mock_dm_subset.ann_types.return_value = [AnnotationType.bbox] + mock_dm_subset.label_group = mocker.MagicMock() return mock_dm_subset diff --git a/tests/unit/data/dataset/test_classification.py b/tests/unit/data/dataset/test_classification.py index c6a62ecea9f..bb564d06cc5 100644 --- a/tests/unit/data/dataset/test_classification.py +++ b/tests/unit/data/dataset/test_classification.py @@ -8,9 +8,10 @@ from otx.data.dataset.classification import ( HLabelInfo, OTXHlabelClsDataset, - OTXMulticlassClsDataset, OTXMultilabelClsDataset, ) +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.entity.sample import ClassificationSample from otx.data.entity.torch import OTXDataItem @@ -24,7 +25,7 @@ def test_get_item( transforms=[lambda x: x], 
max_refetch=3, ) - assert isinstance(dataset[0], OTXDataItem) + assert isinstance(dataset[0], ClassificationSample) def test_get_item_from_bbox_dataset( self, @@ -35,7 +36,7 @@ def test_get_item_from_bbox_dataset( transforms=[lambda x: x], max_refetch=3, ) - assert isinstance(dataset[0], OTXDataItem) + assert isinstance(dataset[0], ClassificationSample) class TestOTXMultilabelClsDataset: From 1d8f3320a55cb1eb2c976c7bc7d05a2810321550 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 28 Aug 2025 13:28:09 +0200 Subject: [PATCH 04/19] Switch logic back to reuse OTXDataBatch which is initialized with data from Sample --- src/otx/data/dataset/base_new.py | 91 ++++++-------------------------- 1 file changed, 17 insertions(+), 74 deletions(-) diff --git a/src/otx/data/dataset/base_new.py b/src/otx/data/dataset/base_new.py index 1745bcf4de1..3696483da17 100644 --- a/src/otx/data/dataset/base_new.py +++ b/src/otx/data/dataset/base_new.py @@ -5,84 +5,21 @@ from __future__ import annotations -from dataclasses import asdict, dataclass -from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Sequence, Union +from typing import Callable, Iterable, List, Union import numpy as np import torch from datumaro.experimental import Dataset from torch.utils.data import Dataset as TorchDataset -from torchvision import tv_tensors from otx.data.entity.sample import ClassificationSample -from otx.data.entity.torch.validations import ValidateBatchMixin +from otx.data.entity.torch.torch import OTXDataBatch from otx.data.transform_libs.torchvision import Compose from otx.types.image import ImageColorChannel -if TYPE_CHECKING: - from datumaro import Polygon - from torchvision.tv_tensors import BoundingBoxes, Mask - - from otx.data.entity.base import ImageInfo - Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] -@dataclass -class ClassificationBatch(ValidateBatchMixin): - """Classification batch implementation based on 
ClassificationSample.""" - - batch_size: int - images: torch.Tensor | list[torch.Tensor] - labels: list[torch.Tensor] | None = None - masks: list[Mask] | None = None - bboxes: list[BoundingBoxes] | None = None - keypoints: list[torch.Tensor] | None = None - polygons: list[list[Polygon]] | None = None - imgs_info: Sequence[ImageInfo | None] | None = None - - def pin_memory(self) -> "ClassificationBatch": - """Pin memory for member tensor variables.""" - kwargs = {} - - def maybe_pin(x: Any) -> Any: # noqa: ANN401 - if isinstance(x, torch.Tensor): - return x.pin_memory() - return x - - def maybe_wrap_tv(x: Any) -> Any: # noqa: ANN401 - if isinstance(x, tv_tensors.TVTensor): - return tv_tensors.wrap(x.pin_memory(), like=x) - return maybe_pin(x) - - # Handle images separately because of tv_tensors wrapping - if self.images is not None: - if isinstance(self.images, list): - kwargs["images"] = [maybe_wrap_tv(img) for img in self.images] - else: - kwargs["images"] = maybe_wrap_tv(self.images) - - # Generic handler for all other fields - for field in ["labels", "bboxes", "keypoints", "masks"]: - value = getattr(self, field) - if value is not None: - kwargs[field] = [maybe_wrap_tv(v) if v is not None else None for v in value] - - return self.wrap(**kwargs) - - def wrap(self, **kwargs) -> "ClassificationBatch": - """Wrap this dataclass with the given keyword arguments. - - Args: - **kwargs: Keyword arguments to be overwritten on top of this dataclass - Returns: - Updated dataclass - """ - updated_kwargs = asdict(self) - updated_kwargs.update(**kwargs) - return self.__class__(**updated_kwargs) - - class OTXDataset(TorchDataset): """Base OTXDataset using new Datumaro experimental Dataset. 
@@ -169,9 +106,10 @@ def _get_item_impl(self, index: int) -> ClassificationSample | None: @property def collate_fn(self) -> Callable: """Collection function to collect samples into a batch in data loader.""" - def _collate_fn(items: list[ClassificationSample]) -> ClassificationBatch: - """Collate ClassificationSample items into a ClassificationBatch. - + + def _collate_fn(items: list[ClassificationSample]) -> OTXDataBatch: + """Collate ClassificationSample items into an OTXDataBatch. + Args: items: List of ClassificationSample items to batch Returns: @@ -184,16 +122,21 @@ def _collate_fn(items: list[ClassificationSample]) -> ClassificationBatch: # Keep as list if shapes differ (e.g., for OV inference) images = [item.image for item in items] - return ClassificationBatch( + return OTXDataBatch( batch_size=len(items), images=images, labels=[item.label for item in items] if items[0].label is not None else None, masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, - keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, - polygons=[item.polygons for item in items] if any(item.polygons is not None for item in items) else None, - imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, + keypoints=[item.keypoints for item in items] + if any(item.keypoints is not None for item in items) + else None, + polygons=[item.polygons for item in items] + if any(item.polygons is not None for item in items) + else None, + imgs_info=[item.img_info for item in items] + if any(item.img_info is not None for item in items) + else None, ) - + return _collate_fn - \ No newline at end of file From 966ae5604554f5889d31ea08c51d0f52ff48256d Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 28 Aug 2025 15:47:55 +0200 Subject: [PATCH 05/19] 
Fix datumaro LabelGroup in _dispatch_label_info method --- src/otx/backend/native/models/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/otx/backend/native/models/base.py b/src/otx/backend/native/models/base.py index 8cc10512fe6..54372f468c5 100644 --- a/src/otx/backend/native/models/base.py +++ b/src/otx/backend/native/models/base.py @@ -16,6 +16,7 @@ import torch from datumaro import LabelCategories +from datumaro.experimental.categories import LabelGroup from lightning import LightningModule, Trainer from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR @@ -884,6 +885,14 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: label_groups=[label_info], label_ids=[str(i) for i in range(len(label_info))], ) + if isinstance(label_info, LabelGroup): + # Handle LabelGroup objects from datumaro + labels = label_info.labels + return LabelInfo( + label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) if isinstance(label_info, LabelInfo): if not hasattr(label_info, "label_ids"): # NOTE: This is for backward compatibility From 53afc105db4c263a326e8efd083c61746006d422 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Mon, 1 Sep 2025 15:42:19 +0200 Subject: [PATCH 06/19] Several fixes to use new Datumaro Dataset class for classification training --- .../native/callbacks/gpu_mem_monitor.py | 2 +- src/otx/backend/native/models/base.py | 6 +- src/otx/data/dataset/base_new.py | 102 +++++++++++------- src/otx/data/dataset/classification_new.py | 16 ++- src/otx/data/entity/sample.py | 49 +++++++-- src/otx/data/factory.py | 11 +- src/otx/data/samplers/balanced_sampler.py | 11 +- src/otx/data/transform_libs/utils.py | 1 + src/otx/data/utils/utils.py | 2 +- 9 files changed, 134 insertions(+), 66 deletions(-) diff --git a/src/otx/backend/native/callbacks/gpu_mem_monitor.py b/src/otx/backend/native/callbacks/gpu_mem_monitor.py index 4d7d6388107..dcea0d5b36c 100644 --- 
a/src/otx/backend/native/callbacks/gpu_mem_monitor.py +++ b/src/otx/backend/native/callbacks/gpu_mem_monitor.py @@ -29,7 +29,7 @@ def _get_and_log_device_stats( batch_size (int): batch size. """ device = trainer.strategy.root_device - if device.type in ["cpu", "xpu"]: + if device.type in ["cpu", "xpu", "mps"]: return device_stats = trainer.accelerator.get_device_stats(device) diff --git a/src/otx/backend/native/models/base.py b/src/otx/backend/native/models/base.py index 54372f468c5..2efb1552577 100644 --- a/src/otx/backend/native/models/base.py +++ b/src/otx/backend/native/models/base.py @@ -16,7 +16,7 @@ import torch from datumaro import LabelCategories -from datumaro.experimental.categories import LabelGroup +from datumaro.experimental.categories import LabelCategories as NewLabelCategories from lightning import LightningModule, Trainer from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR @@ -885,8 +885,8 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: label_groups=[label_info], label_ids=[str(i) for i in range(len(label_info))], ) - if isinstance(label_info, LabelGroup): - # Handle LabelGroup objects from datumaro + if isinstance(label_info, NewLabelCategories): + # Handle LabelCategories objects from datumaro labels = label_info.labels return LabelInfo( label_names=labels, diff --git a/src/otx/data/dataset/base_new.py b/src/otx/data/dataset/base_new.py index 3696483da17..e03f1fa253d 100644 --- a/src/otx/data/dataset/base_new.py +++ b/src/otx/data/dataset/base_new.py @@ -5,6 +5,7 @@ from __future__ import annotations +import abc from typing import Callable, Iterable, List, Union import numpy as np @@ -20,6 +21,45 @@ Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] +def _default_collate_fn(items: list[ClassificationSample]) -> OTXDataBatch: + """Collate ClassificationSample items into an OTXDataBatch. 
+ + Args: + items: List of ClassificationSample items to batch + Returns: + Batched ClassificationSample items with stacked tensors + """ + # Convert images to float32 tensors before stacking + image_tensors = [] + for item in items: + img = item.image + if isinstance(img, torch.Tensor): + # Convert to float32 if not already + if img.dtype != torch.float32: + img = img.float() + else: + # Convert numpy array to float32 tensor + img = torch.from_numpy(img).float() + image_tensors.append(img) + + # Try to stack images if they have the same shape + if len(image_tensors) > 0 and all(t.shape == image_tensors[0].shape for t in image_tensors): + images = torch.stack(image_tensors) + else: + images = image_tensors + + return OTXDataBatch( + batch_size=len(items), + images=images, + labels=[item.label for item in items] if items[0].label is not None else None, + masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, + bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, + keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, + polygons=[item.polygons for item in items] if any(item.polygons is not None for item in items) else None, + imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, + ) + + class OTXDataset(TorchDataset): """Base OTXDataset using new Datumaro experimental Dataset. 
@@ -50,14 +90,18 @@ def __init__( self.sample_type = sample_type self.max_refetch = max_refetch self.data_format = data_format - - # TODO: Properly reinit label_info - self.label_info = dm_subset.label_group - - self.dataset = dm_subset + if ( + hasattr(dm_subset, "schema") + and hasattr(dm_subset.schema, "attributes") + and "label" in dm_subset.schema.attributes + ): + self.label_info = dm_subset.schema.attributes["label"].categories + else: + self.label_info = None + self.dm_subset = dm_subset def __len__(self) -> int: - return len(self.dataset) + return len(self.dm_subset) def _sample_another_idx(self) -> int: return np.random.randint(0, len(self)) @@ -99,44 +143,20 @@ def __getitem__(self, index: int) -> ClassificationSample: raise RuntimeError(msg) def _get_item_impl(self, index: int) -> ClassificationSample | None: - dm_item = self.dataset[index] - sample = self.sample_type.from_dm_item(dm_item) + dm_item = self.dm_subset[index] + # Check if dm_item is already a sample of the expected type + if isinstance(dm_item, self.sample_type): + sample = dm_item + else: + # dm_item is a DatasetItem, convert it using from_dm_item + sample = self.sample_type.from_dm_item(dm_item) return self._apply_transforms(sample) @property def collate_fn(self) -> Callable: """Collection function to collect samples into a batch in data loader.""" + return _default_collate_fn - def _collate_fn(items: list[ClassificationSample]) -> OTXDataBatch: - """Collate ClassificationSample items into an OTXDataBatch. 
- - Args: - items: List of ClassificationSample items to batch - Returns: - Batched ClassificationSample items with stacked tensors - """ - # Check if all images have the same size for stacking - if all(item.image.shape == items[0].image.shape for item in items): - images = torch.stack([item.image for item in items]) - else: - # Keep as list if shapes differ (e.g., for OV inference) - images = [item.image for item in items] - - return OTXDataBatch( - batch_size=len(items), - images=images, - labels=[item.label for item in items] if items[0].label is not None else None, - masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, - bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, - keypoints=[item.keypoints for item in items] - if any(item.keypoints is not None for item in items) - else None, - polygons=[item.polygons for item in items] - if any(item.polygons is not None for item in items) - else None, - imgs_info=[item.img_info for item in items] - if any(item.img_info is not None for item in items) - else None, - ) - - return _collate_fn + @abc.abstractmethod + def get_idx_list_per_classes(self) -> dict[int, list[int]]: + """Get a dictionary with class labels as keys and lists of corresponding sample indices as values.""" diff --git a/src/otx/data/dataset/classification_new.py b/src/otx/data/dataset/classification_new.py index 0a80c6488b8..7e2befda7aa 100644 --- a/src/otx/data/dataset/classification_new.py +++ b/src/otx/data/dataset/classification_new.py @@ -5,8 +5,8 @@ from __future__ import annotations -from .base_new import OTXDataset from ..entity.sample import ClassificationSample +from .base_new import OTXDataset class OTXMulticlassClsDataset(OTXDataset): @@ -21,6 +21,18 @@ def __init__(self, **kwargs) -> None: kwargs["sample_type"] = ClassificationSample super().__init__(**kwargs) + def get_idx_list_per_classes(self) -> dict[int, list[int]]: + """Get index list per 
class.""" + idx_list_per_classes = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + label_id = item.label.item() + if label_id not in idx_list_per_classes: + idx_list_per_classes[label_id] = [] + idx_list_per_classes[label_id].append(idx) + return idx_list_per_classes + + # class OTXMultilabelClsDataset(OTXDataset): # """OTXDataset class for multi-label classification task using new Datumaro experimental Dataset.""" # @@ -257,4 +269,4 @@ def __init__(self, **kwargs) -> None: # else: # class_indices[num_multiclass_heads + in_group_idx] = -1 # -# return class_indices \ No newline at end of file +# return class_indices diff --git a/src/otx/data/entity/sample.py b/src/otx/data/entity/sample.py index 3b5ab3352bf..78a7bc05ac1 100644 --- a/src/otx/data/entity/sample.py +++ b/src/otx/data/entity/sample.py @@ -13,23 +13,25 @@ from datumaro import Mask from datumaro.components.media import Image from datumaro.experimental.dataset import Sample -from datumaro.experimental.fields import ImageInfo, label_field, image_field +from datumaro.experimental.fields import image_field, label_field from torchvision import tv_tensors +from otx.data.entity.base import ImageInfo + if TYPE_CHECKING: - from datumaro import Polygon, DatasetItem + from datumaro import DatasetItem, Polygon from torchvision.tv_tensors import BoundingBoxes, Mask class ClassificationSample(Sample): """OTXDataItemSample is a base class for OTX data items.""" + label: torch.Tensor = label_field(pl.Int32()) - image: torch.Tensor | np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + image: np.ndarray | torch.Tensor | tv_tensors.Image = image_field(dtype=pl.UInt8) @classmethod - def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": - """ - Create a ClassificationSample from a Datumaro DatasetItem. + def from_dm_item(cls, item: DatasetItem) -> ClassificationSample: + """Create a ClassificationSample from a Datumaro DatasetItem. 
Args: item: Datumaro DatasetItem containing image and label @@ -39,14 +41,24 @@ def from_dm_item(cls, item: DatasetItem) -> "ClassificationSample": """ image = item.media_as(Image).data label = item.annotations[0].label if item.annotations else None - return cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) + + img_shape = image.shape[:2] + img_info = ImageInfo( + img_idx=0, + img_shape=img_shape, + ori_shape=img_shape, + ) + + sample = cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) + sample._img_info = img_info + return sample def as_tv_image(self) -> None: """Convert image to torchvision tv_tensors Image format.""" - if isinstance(self.image, np.ndarray): + if isinstance(self.image, tv_tensors.Image): + return + if isinstance(self.image, (np.ndarray, torch.Tensor)): self.image = tv_tensors.Image(self.image) - elif isinstance(self.image, torch.Tensor): - self.image = tv_tensors.Image(self.image.numpy()) @property def masks(self) -> Mask | None: @@ -66,4 +78,19 @@ def polygons(self) -> list[Polygon] | None: @property def img_info(self) -> ImageInfo | None: - return None + if getattr(self, "_img_info", None) is None: + image = getattr(self, "image", None) + if image is not None and hasattr(image, "shape") and len(image.shape) == 3: + img_shape = image.shape[:2] + else: + return None + self._img_info = ImageInfo( + img_idx=0, + img_shape=img_shape, + ori_shape=img_shape, + ) + return self._img_info + + @img_info.setter + def img_info(self, value: ImageInfo | None) -> None: + self._img_info = value diff --git a/src/otx/data/factory.py b/src/otx/data/factory.py index b4d9956ef9e..d854a6a4eb6 100644 --- a/src/otx/data/factory.py +++ b/src/otx/data/factory.py @@ -10,7 +10,7 @@ from datumaro.components.annotation import AnnotationType from datumaro.components.dataset import Dataset as DmDataset from datumaro.experimental import Dataset as DatasetNew -from datumaro.experimental.categories import LabelGroup +from 
datumaro.experimental.categories import LabelCategories from otx import LabelInfo, NullLabelInfo from otx.types.image import ImageColorChannel @@ -80,9 +80,10 @@ def create( if isinstance(dm_subset, DmDataset): categories = cls._get_label_categories(dm_subset, data_format) - dataset = DatasetNew(ClassificationSample, label_group=categories) + dataset = DatasetNew(ClassificationSample, categories={"label": categories}) for item in dm_subset: - dataset.append(ClassificationSample.from_dm_item(item)) + if len(item.media.data.shape) == 3: # TODO: Account for grayscale images + dataset.append(ClassificationSample.from_dm_item(item)) common_kwargs["dm_subset"] = dataset return OTXMulticlassClsDataset(**common_kwargs) @@ -119,11 +120,11 @@ def create( raise NotImplementedError(task) @staticmethod - def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelGroup: + def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelCategories: if dm_subset.categories() and data_format == "arrow": label_info = LabelInfo.from_dm_label_groups_arrow(dm_subset.categories()[AnnotationType.label]) elif dm_subset.categories(): label_info = LabelInfo.from_dm_label_groups(dm_subset.categories()[AnnotationType.label]) else: label_info = NullLabelInfo() - return LabelGroup(labels=label_info.label_names) + return LabelCategories(labels=label_info.label_names) diff --git a/src/otx/data/samplers/balanced_sampler.py b/src/otx/data/samplers/balanced_sampler.py index 43bc11fae0b..033716e3e21 100644 --- a/src/otx/data/samplers/balanced_sampler.py +++ b/src/otx/data/samplers/balanced_sampler.py @@ -9,12 +9,15 @@ from typing import TYPE_CHECKING import torch +from datumaro import DatasetSubset +from datumaro.experimental import Dataset as NewDataset from torch.utils.data import Sampler from otx.data.utils import get_idx_list_per_classes if TYPE_CHECKING: from otx.data.dataset.base import OTXDataset + from otx.data.dataset.base_new import OTXDataset as OTXDatasetNew class 
BalancedSampler(Sampler): @@ -43,7 +46,7 @@ class BalancedSampler(Sampler): def __init__( self, - dataset: OTXDataset, + dataset: OTXDataset | OTXDatasetNew, efficient_mode: bool = False, num_replicas: int = 1, rank: int = 0, @@ -61,7 +64,11 @@ def __init__( super().__init__(dataset) # img_indices: dict[label: list[idx]] - ann_stats = get_idx_list_per_classes(dataset.dm_subset) + if isinstance(dataset.dm_subset, DatasetSubset): + ann_stats = get_idx_list_per_classes(dataset.dm_subset) + elif isinstance(dataset.dm_subset, NewDataset): + ann_stats = dataset.get_idx_list_per_classes() + self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0} self.num_cls = len(self.img_indices.keys()) self.data_length = len(self.dataset) diff --git a/src/otx/data/transform_libs/utils.py b/src/otx/data/transform_libs/utils.py index adae5fb7c61..acbf3827c9b 100644 --- a/src/otx/data/transform_libs/utils.py +++ b/src/otx/data/transform_libs/utils.py @@ -129,6 +129,7 @@ def to_np_image(img: np.ndarray | Tensor | list) -> np.ndarray | list[np.ndarray return img if isinstance(img, list): return [to_np_image(im) for im in img] + return np.ascontiguousarray(img.numpy().transpose(1, 2, 0)) diff --git a/src/otx/data/utils/utils.py b/src/otx/data/utils/utils.py index 769fc3ec7f8..94332be8fd6 100644 --- a/src/otx/data/utils/utils.py +++ b/src/otx/data/utils/utils.py @@ -325,10 +325,10 @@ def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None: def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]: """Compute class statistics.""" stats: dict[int | str, list[int]] = defaultdict(list) - labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories()) for item_idx, item in enumerate(dm_dataset): for ann in item.annotations: if use_string_label: + labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories()) 
stats[labels.items[ann.label].name].append(item_idx) else: stats[ann.label].append(item_idx) From 3ae67453af1b44e705d00d7192acafbb12ae6cf3 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Tue, 2 Sep 2025 09:42:26 +0200 Subject: [PATCH 07/19] add base OTX Sample class --- lib/src/otx/data/dataset/base_new.py | 28 +- .../otx/data/dataset/classification_new.py | 239 ------------------ lib/src/otx/data/entity/sample.py | 68 ++--- 3 files changed, 50 insertions(+), 285 deletions(-) diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py index e03f1fa253d..a4502f2c146 100644 --- a/lib/src/otx/data/dataset/base_new.py +++ b/lib/src/otx/data/dataset/base_new.py @@ -13,7 +13,7 @@ from datumaro.experimental import Dataset from torch.utils.data import Dataset as TorchDataset -from otx.data.entity.sample import ClassificationSample +from otx.data.entity.sample import OTXSample from otx.data.entity.torch.torch import OTXDataBatch from otx.data.transform_libs.torchvision import Compose from otx.types.image import ImageColorChannel @@ -21,13 +21,13 @@ Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] -def _default_collate_fn(items: list[ClassificationSample]) -> OTXDataBatch: - """Collate ClassificationSample items into an OTXDataBatch. +def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch: + """Collate OTXSample items into an OTXDataBatch. 
Args: - items: List of ClassificationSample items to batch + items: List of OTXSample items to batch Returns: - Batched ClassificationSample items with stacked tensors + Batched OTXSample items with stacked tensors """ # Convert images to float32 tensors before stacking image_tensors = [] @@ -81,7 +81,7 @@ def __init__( stack_images: bool = True, to_tv_image: bool = True, data_format: str = "", - sample_type: type[ClassificationSample] = ClassificationSample, + sample_type: type[OTXSample] = OTXSample, ) -> None: self.transforms = transforms self.image_color_channel = image_color_channel @@ -106,7 +106,7 @@ def __len__(self) -> int: def _sample_another_idx(self) -> int: return np.random.randint(0, len(self)) - def _apply_transforms(self, entity: ClassificationSample) -> ClassificationSample | None: + def _apply_transforms(self, entity: OTXSample) -> OTXSample | None: if isinstance(self.transforms, Compose): if self.to_tv_image: entity.as_tv_image() @@ -116,7 +116,7 @@ def _apply_transforms(self, entity: ClassificationSample) -> ClassificationSampl if callable(self.transforms): return self.transforms(entity) - def _iterable_transforms(self, item: ClassificationSample) -> ClassificationSample | None: + def _iterable_transforms(self, item: OTXSample) -> OTXSample | None: if not isinstance(self.transforms, list): raise TypeError(item) @@ -130,7 +130,7 @@ def _iterable_transforms(self, item: ClassificationSample) -> ClassificationSamp return results - def __getitem__(self, index: int) -> ClassificationSample: + def __getitem__(self, index: int) -> OTXSample: for _ in range(self.max_refetch): results = self._get_item_impl(index) @@ -142,15 +142,9 @@ def __getitem__(self, index: int) -> ClassificationSample: msg = f"Reach the maximum refetch number ({self.max_refetch})" raise RuntimeError(msg) - def _get_item_impl(self, index: int) -> ClassificationSample | None: + def _get_item_impl(self, index: int) -> OTXSample | None: dm_item = self.dm_subset[index] - # Check if 
dm_item is already a sample of the expected type - if isinstance(dm_item, self.sample_type): - sample = dm_item - else: - # dm_item is a DatasetItem, convert it using from_dm_item - sample = self.sample_type.from_dm_item(dm_item) - return self._apply_transforms(sample) + return self._apply_transforms(dm_item) @property def collate_fn(self) -> Callable: diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index 7e2befda7aa..122b117b5bd 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -31,242 +31,3 @@ def get_idx_list_per_classes(self) -> dict[int, list[int]]: idx_list_per_classes[label_id] = [] idx_list_per_classes[label_id].append(idx) return idx_list_per_classes - - -# class OTXMultilabelClsDataset(OTXDataset): -# """OTXDataset class for multi-label classification task using new Datumaro experimental Dataset.""" -# -# def __init__(self, **kwargs) -> None: -# """Initialize OTXMultilabelClsDataset. -# -# Args: -# **kwargs: Keyword arguments to pass to OTXDataset -# """ -# kwargs["sample_type"] = MultiLabelClassificationSample -# super().__init__(**kwargs) -# self.num_classes = len(self.dm_subset.categories()[AnnotationType.label]) -# -# def _get_item_impl(self, index: int) -> MultiLabelClassificationSample | None: -# item = self.dm_subset[index] -# img = item.media_as(Image) -# ignored_labels: list[int] = [] # This should be assigned form item -# img_data, img_shape, _ = self._get_img_data_and_shape(img) -# -# label_ids = set() -# for ann in item.annotations: -# # multilabel information stored in 'multi_label_ids' attribute when the source format is arrow -# if "multi_label_ids" in ann.attributes: -# for lbl_idx in ann.attributes["multi_label_ids"]: -# label_ids.add(lbl_idx) -# -# if isinstance(ann, Label): -# label_ids.add(ann.label) -# else: -# # If the annotation is not Label, it should be converted to Label. 
-# # For Chained Task: Detection (Bbox) -> Classification (Label) -# label = Label(label=ann.label) -# label_ids.add(label.label) -# labels = np.array(list(label_ids), dtype=np.int64) -# -# image_info = ImageInfo( -# width=img_data.shape[1], -# height=img_data.shape[0], -# ) -# # Create multilabel classification sample -# sample = MultiLabelClassificationSample( -# image=img_data, -# labels=self._convert_to_onehot(labels, ignored_labels), -# image_info=image_info, -# ) -# -# return self._apply_transforms(sample) -# -# def _convert_to_onehot(self, labels: np.ndarray, ignored_labels: list[int]) -> np.ndarray: -# """Convert label to one-hot vector format.""" -# # Convert to torch tensor for one_hot -# labels_tensor = torch.from_numpy(labels).long() -# # Torch's one_hot() expects the input to be of type long -# onehot = functional.one_hot(labels_tensor, self.num_classes).sum(0).clamp_max_(1).numpy() -# if ignored_labels: -# for ignore_label in ignored_labels: -# onehot[ignore_label] = -1 -# return onehot - - -# class OTXHlabelClsDataset(OTXDataset): -# """OTXDataset class for H-label classification task using new Datumaro experimental Dataset.""" -# -# def __init__(self, **kwargs) -> None: -# """Initialize OTXHlabelClsDataset. -# -# Args: -# **kwargs: Keyword arguments to pass to OTXDataset -# """ -# # Set the sample type to HierarchicalClassificationSample -# kwargs["sample_type"] = HierarchicalClassificationSample -# super().__init__(**kwargs) -# self.dm_categories = self.dm_subset.categories()[AnnotationType.label] -# -# # Hlabel classification used HLabelInfo to insert the HLabelData. 
-# if self.data_format == "arrow": -# # arrow format stores label IDs as names, have to deal with that here -# self.label_info = HLabelInfo.from_dm_label_groups_arrow(self.dm_categories) -# else: -# self.label_info = HLabelInfo.from_dm_label_groups(self.dm_categories) -# -# self.id_to_name_mapping = dict(zip(self.label_info.label_ids, self.label_info.label_names)) -# self.id_to_name_mapping[""] = "" -# -# if self.label_info.num_multiclass_heads == 0: -# msg = "The number of multiclass heads should be larger than 0." -# raise ValueError(msg) -# -# if self.data_format != "arrow": -# for dm_item in self.dm_subset: -# self._add_ancestors(dm_item.annotations) -# -# def _add_ancestors(self, label_anns: list[Label]) -> None: -# """Add ancestors recursively if some label miss the ancestor information. -# -# If the label tree likes below, -# object - vehicle -- car -# |- bus -# |- truck -# And annotation = ['car'], it should be ['car', 'vehicle', 'object'], to include the ancestor. -# -# This function add the ancestors to the annotation if missing. 
-# """ -# -# def _label_idx_to_name(idx: int) -> str: -# return self.dm_categories[idx].name -# -# def _label_name_to_idx(name: str) -> int: -# indices = [idx for idx, val in enumerate(self.label_info.label_names) if val == name] -# return indices[0] -# -# def _get_label_group_idx(label_name: str) -> int: -# if isinstance(self.label_info, HLabelInfo): -# if self.data_format == "arrow": -# return self.label_info.class_to_group_idx[self.id_to_name_mapping[label_name]][0] -# return self.label_info.class_to_group_idx[label_name][0] -# msg = f"self.label_info should have HLabelInfo type, got {type(self.label_info)}" -# raise ValueError(msg) -# -# def _find_ancestor_recursively(label_name: str, ancestors: list) -> list[str]: -# _, dm_label_category = self.dm_categories.find(label_name) -# parent_name = dm_label_category.parent if dm_label_category else "" -# -# if parent_name != "": -# ancestors.append(parent_name) -# _find_ancestor_recursively(parent_name, ancestors) -# return ancestors -# -# def _get_all_label_names_in_anns(anns: list[Label]) -> list[str]: -# return [_label_idx_to_name(ann.label) for ann in anns] -# -# all_label_names = _get_all_label_names_in_anns(label_anns) -# ancestor_dm_labels = [] -# for ann in label_anns: -# label_idx = ann.label -# label_name = _label_idx_to_name(label_idx) -# ancestors = _find_ancestor_recursively(label_name, []) -# -# for i, ancestor in enumerate(ancestors): -# if ancestor not in all_label_names: -# ancestor_dm_labels.append( -# Label( -# label=_label_name_to_idx(ancestor), -# id=len(label_anns) + i, -# group=_get_label_group_idx(ancestor), -# ), -# ) -# label_anns.extend(ancestor_dm_labels) -# -# def _get_item_impl(self, index: int) -> HierarchicalClassificationSample | None: -# item = self.dm_subset[index] -# img = item.media_as(Image) -# ignored_labels: list[int] = [] # This should be assigned form item -# img_data, img_shape, _ = self._get_img_data_and_shape(img) -# -# label_ids = set() -# for ann in item.annotations: -# 
# in h-cls scenario multilabel information stored in 'multi_label_ids' attribute -# if "multi_label_ids" in ann.attributes: -# for lbl_idx in ann.attributes["multi_label_ids"]: -# label_ids.add(lbl_idx) -# -# if isinstance(ann, Label): -# label_ids.add(ann.label) -# else: -# # If the annotation is not Label, it should be converted to Label. -# # For Chained Task: Detection (Bbox) -> Classification (Label) -# label = Label(label=ann.label) -# label_ids.add(label.label) -# -# hlabel_labels = self._convert_label_to_hlabel_format([Label(label=idx) for idx in label_ids], ignored_labels) -# -# # Create image info sample -# image_info = ImageInfo( -# width=img_data.shape[1], -# height=img_data.shape[0], -# ) -# -# # Create hierarchical classification sample -# sample = HierarchicalClassificationSample( -# image=img_data, -# labels=np.array(hlabel_labels, dtype=np.int64), -# image_info=image_info, -# ) -# -# return self._apply_transforms(sample) -# -# def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_labels: list[int]) -> list[int]: -# """Convert format of the label to the h-label. -# -# It converts the label format to h-label format. -# Total length of result is sum of number of hierarchy and number of multilabel classes. -# -# i.e. -# Let's assume that we used the same dataset with example of the definition of HLabelData -# and the original labels are ["Rigid", "Triangle", "Lion"]. -# -# Then, h-label format will be [0, 1, 1, 0]. -# The first N-th indices represent the label index of multiclass heads (N=num_multiclass_heads), -# others represent the multilabel labels. 
-# -# [Multiclass Heads] -# 0-th index = 0 -> ["Rigid"(O), "Non-Rigid"(X)] <- First multiclass head -# 1-st index = 1 -> ["Rectangle"(O), "Triangle"(X), "Circle"(X)] <- Second multiclass head -# -# [Multilabel Head] -# 2, 3 indices = [1, 0] -> ["Lion"(O), "Panda"(X)] -# """ -# if not isinstance(self.label_info, HLabelInfo): -# msg = f"The type of label_info should be HLabelInfo, got {type(self.label_info)}." -# raise TypeError(msg) -# -# num_multiclass_heads = self.label_info.num_multiclass_heads -# num_multilabel_classes = self.label_info.num_multilabel_classes -# -# class_indices = [0] * (num_multiclass_heads + num_multilabel_classes) -# for i in range(num_multiclass_heads): -# class_indices[i] = -1 -# -# for ann in label_anns: -# if self.data_format == "arrow": -# # skips unknown labels for instance, the empty one -# if self.dm_categories.items[ann.label].name not in self.id_to_name_mapping: -# continue -# ann_name = self.id_to_name_mapping[self.dm_categories.items[ann.label].name] -# else: -# ann_name = self.dm_categories.items[ann.label].name -# group_idx, in_group_idx = self.label_info.class_to_group_idx[ann_name] -# -# if group_idx < num_multiclass_heads: -# class_indices[group_idx] = in_group_idx -# elif ann.label not in ignored_labels: -# class_indices[num_multiclass_heads + in_group_idx] = 1 -# else: -# class_indices[num_multiclass_heads + in_group_idx] = -1 -# -# return class_indices diff --git a/lib/src/otx/data/entity/sample.py b/lib/src/otx/data/entity/sample.py index 78a7bc05ac1..2259af8a6d4 100644 --- a/lib/src/otx/data/entity/sample.py +++ b/lib/src/otx/data/entity/sample.py @@ -23,35 +23,8 @@ from torchvision.tv_tensors import BoundingBoxes, Mask -class ClassificationSample(Sample): - """OTXDataItemSample is a base class for OTX data items.""" - - label: torch.Tensor = label_field(pl.Int32()) - image: np.ndarray | torch.Tensor | tv_tensors.Image = image_field(dtype=pl.UInt8) - - @classmethod - def from_dm_item(cls, item: DatasetItem) -> 
ClassificationSample: - """Create a ClassificationSample from a Datumaro DatasetItem. - - Args: - item: Datumaro DatasetItem containing image and label - - Returns: - ClassificationSample: Instance with image and label set - """ - image = item.media_as(Image).data - label = item.annotations[0].label if item.annotations else None - - img_shape = image.shape[:2] - img_info = ImageInfo( - img_idx=0, - img_shape=img_shape, - ori_shape=img_shape, - ) - - sample = cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) - sample._img_info = img_info - return sample +class OTXSample(Sample): + """Base class for OTX data samples.""" def as_tv_image(self) -> None: """Convert image to torchvision tv_tensors Image format.""" @@ -59,6 +32,7 @@ def as_tv_image(self) -> None: return if isinstance(self.image, (np.ndarray, torch.Tensor)): self.image = tv_tensors.Image(self.image) + raise ValueError("OTXSample must have an image") @property def masks(self) -> Mask | None: @@ -76,6 +50,11 @@ def keypoints(self) -> torch.Tensor | None: def polygons(self) -> list[Polygon] | None: return None + @property + def label(self) -> torch.Tensor | None: + """Optional label property that returns None by default.""" + return None + @property def img_info(self) -> ImageInfo | None: if getattr(self, "_img_info", None) is None: @@ -94,3 +73,34 @@ def img_info(self) -> ImageInfo | None: @img_info.setter def img_info(self, value: ImageInfo | None) -> None: self._img_info = value + + +class ClassificationSample(OTXSample): + """OTXDataItemSample is a base class for OTX data items.""" + + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32()) + + @classmethod + def from_dm_item(cls, item: DatasetItem) -> ClassificationSample: + """Create a ClassificationSample from a Datumaro DatasetItem. 
+ + Args: + item: Datumaro DatasetItem containing image and label + + Returns: + ClassificationSample: Instance with image and label set + """ + image = item.media_as(Image).data + label = item.annotations[0].label if item.annotations else None + + img_shape = image.shape[:2] + img_info = ImageInfo( + img_idx=0, + img_shape=img_shape, + ori_shape=img_shape, + ) + + sample = cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) + sample._img_info = img_info + return sample From 742f399f0d81b7d190eee2cf09b7d2449cc1b81d Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Tue, 2 Sep 2025 16:02:32 +0200 Subject: [PATCH 08/19] Add test cases --- tests/unit/data/dataset/test_base_new.py | 261 ++++++++++++++++++ .../data/dataset/test_classification_new.py | 66 +++++ tests/unit/data/entity/test_sample.py | 174 ++++++++++++ 3 files changed, 501 insertions(+) create mode 100644 tests/unit/data/dataset/test_base_new.py create mode 100644 tests/unit/data/dataset/test_classification_new.py create mode 100644 tests/unit/data/entity/test_sample.py diff --git a/tests/unit/data/dataset/test_base_new.py b/tests/unit/data/dataset/test_base_new.py new file mode 100644 index 00000000000..3ab2fcce494 --- /dev/null +++ b/tests/unit/data/dataset/test_base_new.py @@ -0,0 +1,261 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for base_new OTXDataset.""" + +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +import torch +from datumaro.experimental import Dataset + +from otx.data.dataset.base_new import OTXDataset, _default_collate_fn +from otx.data.entity.sample import OTXSample +from otx.data.entity.torch.torch import OTXDataBatch + + +class TestDefaultCollateFn: + """Test _default_collate_fn function.""" + + def test_collate_with_torch_tensors(self): + """Test collating items with torch tensor images.""" + # Create mock samples with torch tensor images + sample1 = 
Mock(spec=OTXSample) + sample1.image = torch.randn(3, 224, 224) + sample1.label = torch.tensor(0) + sample1.masks = None + sample1.bboxes = None + sample1.keypoints = None + sample1.polygons = None + sample1.img_info = None + + sample2 = Mock(spec=OTXSample) + sample2.image = torch.randn(3, 224, 224) + sample2.label = torch.tensor(1) + sample2.masks = None + sample2.bboxes = None + sample2.keypoints = None + sample2.polygons = None + sample2.img_info = None + + items = [sample1, sample2] + result = _default_collate_fn(items) + + assert isinstance(result, OTXDataBatch) + assert result.batch_size == 2 + assert isinstance(result.images, torch.Tensor) + assert result.images.shape == (2, 3, 224, 224) + assert result.images.dtype == torch.float32 + assert result.labels == [torch.tensor(0), torch.tensor(1)] + + def test_collate_with_different_image_shapes(self): + """Test collating items with different image shapes.""" + sample1 = Mock(spec=OTXSample) + sample1.image = torch.randn(3, 224, 224) + sample1.label = None + sample1.masks = None + sample1.bboxes = None + sample1.keypoints = None + sample1.polygons = None + sample1.img_info = None + + sample2 = Mock(spec=OTXSample) + sample2.image = torch.randn(3, 256, 256) + sample2.label = None + sample2.masks = None + sample2.bboxes = None + sample2.keypoints = None + sample2.polygons = None + sample2.img_info = None + + items = [sample1, sample2] + result = _default_collate_fn(items) + + # When shapes are different, should return list instead of stacked tensor + assert isinstance(result.images, list) + assert len(result.images) == 2 + assert result.labels is None + + +class TestOTXDataset: + """Test OTXDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=100) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + 
mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_sample_another_idx(self): + """Test _sample_another_idx method.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + with patch("numpy.random.randint", return_value=42): + idx = dataset._sample_another_idx() + assert idx == 42 + + def test_apply_transforms_with_compose(self): + """Test _apply_transforms with Compose transforms.""" + from otx.data.transform_libs.torchvision import Compose + + mock_compose = Mock(spec=Compose) + mock_entity = Mock(spec=OTXSample) + mock_result = Mock() + mock_compose.return_value = mock_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=mock_compose, + data_format="arrow", + to_tv_image=True, + ) + + result = dataset._apply_transforms(mock_entity) + + mock_entity.as_tv_image.assert_called_once() + mock_compose.assert_called_once_with(mock_entity) + assert result == mock_result + + def test_apply_transforms_with_callable(self): + """Test _apply_transforms with callable transform.""" + mock_transform = Mock() + mock_entity = Mock(spec=OTXSample) + mock_result = Mock() + mock_transform.return_value = mock_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=mock_transform, + data_format="arrow", + ) + + result = dataset._apply_transforms(mock_entity) + + mock_transform.assert_called_once_with(mock_entity) + assert result == mock_result + + def test_apply_transforms_with_list(self): + """Test _apply_transforms with list of transforms.""" + transform1 = Mock() + transform2 = Mock() + + mock_entity = Mock(spec=OTXSample) + intermediate_result = Mock() + final_result = Mock() + + transform1.return_value = intermediate_result + transform2.return_value = final_result + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=[transform1, transform2], + data_format="arrow", + ) 
+ + result = dataset._apply_transforms(mock_entity) + + transform1.assert_called_once_with(mock_entity) + transform2.assert_called_once_with(intermediate_result) + assert result == final_result + + def test_apply_transforms_with_list_returns_none(self): + """Test _apply_transforms with list that returns None.""" + transform1 = Mock() + transform2 = Mock() + + mock_entity = Mock(spec=OTXSample) + transform1.return_value = None # First transform returns None + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=[transform1, transform2], + data_format="arrow", + ) + + result = dataset._apply_transforms(mock_entity) + + transform1.assert_called_once_with(mock_entity) + transform2.assert_not_called() # Should not be called since first returned None + assert result is None + + def test_iterable_transforms_with_non_list(self): + """Test _iterable_transforms with non-list iterable raises TypeError.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + mock_entity = Mock(spec=OTXSample) + dataset.transforms = "not_a_list" # String is iterable but not a list + + with pytest.raises(TypeError): + dataset._iterable_transforms(mock_entity) + + def test_getitem_success(self): + """Test __getitem__ with successful retrieval.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + mock_transformed_item = Mock(spec=OTXSample) + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + with patch.object( + dataset, "_apply_transforms", return_value=mock_transformed_item + ): + result = dataset[5] + + self.mock_dm_subset.__getitem__.assert_called_once_with(5) + assert result == mock_transformed_item + + def test_getitem_with_refetch(self): + """Test __getitem__ with failed first attempt requiring refetch.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + 
dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + max_refetch=2, + ) + + mock_transformed_item = Mock(spec=OTXSample) + + # First call returns None, second returns valid item + with patch.object( + dataset, "_apply_transforms", side_effect=[None, mock_transformed_item] + ), patch.object(dataset, "_sample_another_idx", return_value=10): + result = dataset[5] + + assert result == mock_transformed_item + assert dataset._apply_transforms.call_count == 2 + + def test_collate_fn_property(self): + """Test collate_fn property returns _default_collate_fn.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.collate_fn == _default_collate_fn diff --git a/tests/unit/data/dataset/test_classification_new.py b/tests/unit/data/dataset/test_classification_new.py new file mode 100644 index 00000000000..2394c966d17 --- /dev/null +++ b/tests/unit/data/dataset/test_classification_new.py @@ -0,0 +1,66 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for classification_new dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from datumaro.experimental import Dataset + +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.entity.sample import ClassificationSample + + +class TestOTXMulticlassClsDataset: + """Test OTXMulticlassClsDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=10) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_init_sets_sample_type(self): + """Test 
that initialization sets sample_type to ClassificationSample.""" + dataset = OTXMulticlassClsDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.sample_type == ClassificationSample + + def test_get_idx_list_per_classes_single_class(self): + """Test get_idx_list_per_classes with single class.""" + # Mock dataset items with labels + mock_items = [] + for i in range(5): + mock_item = Mock() + mock_item.label.item.return_value = 0 # All items have label 0 + mock_items.append(mock_item) + + self.mock_dm_subset.__getitem__ = Mock(side_effect=mock_items) + + dataset = OTXMulticlassClsDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + # Override length for this test + dataset.dm_subset.__len__ = Mock(return_value=5) + + result = dataset.get_idx_list_per_classes() + + expected = {0: [0, 1, 2, 3, 4]} + assert result == expected diff --git a/tests/unit/data/entity/test_sample.py b/tests/unit/data/entity/test_sample.py new file mode 100644 index 00000000000..3a9d2baf9e0 --- /dev/null +++ b/tests/unit/data/entity/test_sample.py @@ -0,0 +1,174 @@ +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for sample entity classes.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import numpy as np +import pytest +import torch +from datumaro import DatasetItem +from datumaro.components.annotation import Label +from datumaro.components.media import Image +from torchvision import tv_tensors + +from otx.data.entity.base import ImageInfo +from otx.data.entity.sample import ClassificationSample, OTXSample + + +class TestOTXSample: + """Test OTXSample base class.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create a mock sample for testing + self.sample = OTXSample() + + def test_as_tv_image_with_tv_image(self): + """Test as_tv_image when image is already 
tv_tensors.Image.""" + tv_image = tv_tensors.Image(torch.randn(3, 224, 224)) + self.sample.image = tv_image + + # Should not change anything + self.sample.as_tv_image() + assert isinstance(self.sample.image, tv_tensors.Image) + assert torch.equal(self.sample.image, tv_image) + + def test_as_tv_image_with_numpy_array(self): + """Test as_tv_image with numpy array.""" + np_image = np.random.rand(3, 224, 224).astype(np.float32) + self.sample.image = np_image + + self.sample.as_tv_image() + + assert isinstance(self.sample.image, tv_tensors.Image) + assert torch.allclose(self.sample.image, torch.from_numpy(np_image)) + + def test_as_tv_image_with_torch_tensor(self): + """Test as_tv_image with torch.Tensor.""" + tensor_image = torch.randn(3, 224, 224) + self.sample.image = tensor_image + + self.sample.as_tv_image() + + assert isinstance(self.sample.image, tv_tensors.Image) + assert torch.equal(self.sample.image, tensor_image) + + def test_as_tv_image_with_invalid_type(self): + """Test as_tv_image with invalid image type raises ValueError.""" + self.sample.image = "invalid_image" + + with pytest.raises(ValueError, match="OTXSample must have an image"): + self.sample.as_tv_image() + + def test_img_info_property_with_image(self): + """Test img_info property creates ImageInfo from image.""" + self.sample.image = torch.randn(3, 224, 224) + + img_info = self.sample.img_info + + assert isinstance(img_info, ImageInfo) + assert img_info.img_idx == 0 + assert img_info.img_shape == (3, 224) # First two dimensions + assert img_info.ori_shape == (3, 224) + + def test_img_info_setter(self): + """Test setting img_info manually.""" + custom_info = ImageInfo(img_idx=5, img_shape=(100, 200), ori_shape=(100, 200)) + + self.sample.img_info = custom_info + + assert self.sample.img_info is custom_info + assert self.sample.img_info.img_idx == 5 + + +class TestClassificationSample: + """Test ClassificationSample class.""" + + def test_inheritance(self): + """Test that ClassificationSample 
inherits from OTXSample.""" + sample = ClassificationSample( + image=np.random.rand(3, 224, 224).astype(np.uint8), label=torch.tensor(1) + ) + + assert isinstance(sample, OTXSample) + + def test_init_with_numpy_image_and_tensor_label(self): + """Test initialization with numpy image and tensor label.""" + image = np.random.rand(3, 224, 224).astype(np.uint8) + label = torch.tensor(1) + + sample = ClassificationSample(image=image, label=label) + + assert np.array_equal(sample.image, image) + assert torch.equal(sample.label, label) + + def test_init_with_tv_image(self): + """Test initialization with tv_tensors.Image.""" + image = tv_tensors.Image(torch.randn(3, 224, 224)) + label = torch.tensor(0) + + sample = ClassificationSample(image=image, label=label) + + assert torch.equal(sample.image, image) + assert torch.equal(sample.label, label) + + def test_from_dm_item_with_image_and_annotation(self): + """Test from_dm_item with image and annotation.""" + # Mock DatasetItem + mock_item = Mock(spec=DatasetItem) + + # Mock image + mock_media = Mock(spec=Image) + mock_media.data = np.random.rand(224, 224, 3).astype(np.uint8) + mock_item.media_as.return_value = mock_media + + # Mock annotation + mock_annotation = Mock(spec=Label) + mock_annotation.label = 2 + mock_item.annotations = [mock_annotation] + + sample = ClassificationSample.from_dm_item(mock_item) + + assert isinstance(sample, ClassificationSample) + assert np.array_equal(sample.image, mock_media.data) + assert torch.equal(sample.label, torch.tensor(2, dtype=torch.long)) + + # Check img_info + assert isinstance(sample._img_info, ImageInfo) + assert sample._img_info.img_idx == 0 + assert sample._img_info.img_shape == (224, 224) + assert sample._img_info.ori_shape == (224, 224) + + def test_from_dm_item_without_annotation(self): + """Test from_dm_item without annotations.""" + # Mock DatasetItem without annotations + mock_item = Mock(spec=DatasetItem) + + # Mock image + mock_media = Mock(spec=Image) + mock_media.data 
= np.random.rand(100, 100, 3).astype(np.uint8) + mock_item.media_as.return_value = mock_media + + # No annotations + mock_item.annotations = [] + + sample = ClassificationSample.from_dm_item(mock_item) + + assert isinstance(sample, ClassificationSample) + assert np.array_equal(sample.image, mock_media.data) + assert sample.label is None or torch.equal( + sample.label, torch.tensor(None, dtype=torch.long) + ) + + def test_label_property_override(self): + """Test that ClassificationSample has actual label property (not None).""" + sample = ClassificationSample( + image=np.random.rand(3, 224, 224).astype(np.uint8), label=torch.tensor(42) + ) + + assert sample.label is not None + assert torch.equal(sample.label, torch.tensor(42)) From 856d54e0e3bcc5d9cc7ba1a67d39d4e783ecaaf8 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Tue, 2 Sep 2025 16:35:13 +0200 Subject: [PATCH 09/19] Address ruff/mypy issues Signed-off-by: Albert van Houten --- lib/src/otx/data/dataset/base_new.py | 11 +++++++--- .../otx/data/dataset/classification_new.py | 6 ++--- lib/src/otx/data/entity/sample.py | 22 ++++++++++++++----- lib/src/otx/data/samplers/balanced_sampler.py | 3 ++- lib/tests/unit/data/conftest.py | 4 +--- tests/unit/data/entity/test_sample.py | 5 ++--- 6 files changed, 33 insertions(+), 18 deletions(-) diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py index a4502f2c146..51b8e8d1b5e 100644 --- a/lib/src/otx/data/dataset/base_new.py +++ b/lib/src/otx/data/dataset/base_new.py @@ -6,13 +6,15 @@ from __future__ import annotations import abc -from typing import Callable, Iterable, List, Union +from typing import TYPE_CHECKING, Callable, Iterable, List, Union import numpy as np import torch -from datumaro.experimental import Dataset from torch.utils.data import Dataset as TorchDataset +if TYPE_CHECKING: + from datumaro.experimental import Dataset + from otx.data.entity.sample import OTXSample from otx.data.entity.torch.torch import OTXDataBatch 
from otx.data.transform_libs.torchvision import Compose @@ -55,7 +57,9 @@ def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch: masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, - polygons=[item.polygons for item in items] if any(item.polygons is not None for item in items) else None, + polygons=[item.polygons for item in items if item.polygons is not None] + if any(item.polygons is not None for item in items) + else None, imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, ) @@ -115,6 +119,7 @@ def _apply_transforms(self, entity: OTXSample) -> OTXSample | None: return self._iterable_transforms(entity) if callable(self.transforms): return self.transforms(entity) + return None def _iterable_transforms(self, item: OTXSample) -> OTXSample | None: if not isinstance(self.transforms, list): diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index 122b117b5bd..98771f12075 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -5,8 +5,8 @@ from __future__ import annotations -from ..entity.sample import ClassificationSample -from .base_new import OTXDataset +from otx.data.dataset.base_new import OTXDataset +from otx.data.entity.sample import ClassificationSample class OTXMulticlassClsDataset(OTXDataset): @@ -23,7 +23,7 @@ def __init__(self, **kwargs) -> None: def get_idx_list_per_classes(self) -> dict[int, list[int]]: """Get index list per class.""" - idx_list_per_classes = {} + idx_list_per_classes: dict[int, list[int]] = {} for idx in range(len(self)): item = self.dm_subset[idx] label_id = item.label.item() diff --git 
a/lib/src/otx/data/entity/sample.py b/lib/src/otx/data/entity/sample.py index 2259af8a6d4..079763ca565 100644 --- a/lib/src/otx/data/entity/sample.py +++ b/lib/src/otx/data/entity/sample.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import polars as pl @@ -26,28 +26,36 @@ class OTXSample(Sample): """Base class for OTX data samples.""" + image: np.ndarray | torch.Tensor | tv_tensors.Image | Any + def as_tv_image(self) -> None: """Convert image to torchvision tv_tensors Image format.""" if isinstance(self.image, tv_tensors.Image): return if isinstance(self.image, (np.ndarray, torch.Tensor)): self.image = tv_tensors.Image(self.image) - raise ValueError("OTXSample must have an image") + return + msg = "OTXSample must have an image" + raise ValueError(msg) @property def masks(self) -> Mask | None: + """Get masks for the sample.""" return None @property def bboxes(self) -> BoundingBoxes | None: + """Get bounding boxes for the sample.""" return None @property def keypoints(self) -> torch.Tensor | None: + """Get keypoints for the sample.""" return None @property def polygons(self) -> list[Polygon] | None: + """Get polygons for the sample.""" return None @property @@ -57,6 +65,7 @@ def label(self) -> torch.Tensor | None: @property def img_info(self) -> ImageInfo | None: + """Get image information for the sample.""" if getattr(self, "_img_info", None) is None: image = getattr(self, "image", None) if image is not None and hasattr(image, "shape") and len(image.shape) == 3: @@ -71,7 +80,7 @@ def img_info(self) -> ImageInfo | None: return self._img_info @img_info.setter - def img_info(self, value: ImageInfo | None) -> None: + def img_info(self, value: ImageInfo) -> None: self._img_info = value @@ -101,6 +110,9 @@ def from_dm_item(cls, item: DatasetItem) -> ClassificationSample: ori_shape=img_shape, ) - sample = cls(image=image, label=torch.as_tensor(label, dtype=torch.long)) - 
sample._img_info = img_info + sample = cls( + image=image, + label=torch.as_tensor(label, dtype=torch.long) if label is not None else torch.tensor(-1, dtype=torch.long), + ) + sample.img_info = img_info return sample diff --git a/lib/src/otx/data/samplers/balanced_sampler.py b/lib/src/otx/data/samplers/balanced_sampler.py index 033716e3e21..12513ca1baf 100644 --- a/lib/src/otx/data/samplers/balanced_sampler.py +++ b/lib/src/otx/data/samplers/balanced_sampler.py @@ -64,10 +64,11 @@ def __init__( super().__init__(dataset) # img_indices: dict[label: list[idx]] + ann_stats: dict[int | str, list[int]] if isinstance(dataset.dm_subset, DatasetSubset): ann_stats = get_idx_list_per_classes(dataset.dm_subset) elif isinstance(dataset.dm_subset, NewDataset): - ann_stats = dataset.get_idx_list_per_classes() + ann_stats = dataset.get_idx_list_per_classes() # type: ignore[attr-defined] self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0} self.num_cls = len(self.img_indices.keys()) diff --git a/lib/tests/unit/data/conftest.py b/lib/tests/unit/data/conftest.py index af023346462..6ef37c9fede 100644 --- a/lib/tests/unit/data/conftest.py +++ b/lib/tests/unit/data/conftest.py @@ -41,7 +41,7 @@ @pytest.fixture(params=["bytes", "file"]) -def fxt_dm_item(request, tmpdir) -> DatasetItem: +def fxt_dm_item(requeset, tmpdir) -> DatasetItem: np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) np_img[:, :, 0] = 0 # Set 0 for B channel np_img[:, :, 1] = 1 # Set 1 for G channel @@ -115,7 +115,6 @@ def fxt_mock_dm_subset(mocker: MockerFixture, fxt_dm_item: DatasetItem) -> Magic AnnotationType.mask, AnnotationType.polygon, ] - mock_dm_subset.label_group = mocker.MagicMock() return mock_dm_subset @@ -126,7 +125,6 @@ def fxt_mock_det_dm_subset(mocker: MockerFixture, fxt_dm_item_bbox_only: Dataset mock_dm_subset.__len__.return_value = 1 mock_dm_subset.categories().__getitem__.return_value = LabelCategories.from_iterable(_LABEL_NAMES) 
mock_dm_subset.ann_types.return_value = [AnnotationType.bbox] - mock_dm_subset.label_group = mocker.MagicMock() return mock_dm_subset diff --git a/tests/unit/data/entity/test_sample.py b/tests/unit/data/entity/test_sample.py index 3a9d2baf9e0..4c28d57bbf5 100644 --- a/tests/unit/data/entity/test_sample.py +++ b/tests/unit/data/entity/test_sample.py @@ -160,9 +160,8 @@ def test_from_dm_item_without_annotation(self): assert isinstance(sample, ClassificationSample) assert np.array_equal(sample.image, mock_media.data) - assert sample.label is None or torch.equal( - sample.label, torch.tensor(None, dtype=torch.long) - ) + # When no annotation, from_dm_item should return tensor(-1) as default + assert torch.equal(sample.label, torch.tensor(-1, dtype=torch.long)) def test_label_property_override(self): """Test that ClassificationSample has actual label property (not None).""" From 785f5e6e27a46bd35c8733a1064184b434151899 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Tue, 2 Sep 2025 16:40:05 +0200 Subject: [PATCH 10/19] Pin datumaro develop branch Signed-off-by: Albert van Houten --- lib/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/pyproject.toml b/lib/pyproject.toml index 50c9ce19a1a..2897650ac6e 100644 --- a/lib/pyproject.toml +++ b/lib/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ - "datumaro @ git+https://github.com/open-edge-platform/datumaro.git@albert/datumaro-otx-integration", + "datumaro @ git+https://github.com/open-edge-platform/datumaro.git@develop", "omegaconf==2.3.0", "rich==14.0.0", "jsonargparse==4.35.0", From df3c776d3794fb7fb5f579a267ef7714f1d11fe3 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 09:06:25 +0200 Subject: [PATCH 11/19] Move get_idx_list_per_classes to dataset class and address other PR comments. 
Signed-off-by: Albert van Houten --- lib/src/otx/data/dataset/base.py | 18 +++++++++++- lib/src/otx/data/dataset/base_new.py | 2 +- .../otx/data/dataset/classification_new.py | 4 ++- lib/src/otx/data/factory.py | 4 +-- lib/src/otx/data/samplers/balanced_sampler.py | 10 +------ .../samplers/class_incremental_sampler.py | 3 +- lib/src/otx/data/utils/__init__.py | 2 -- lib/src/otx/data/utils/utils.py | 19 +----------- lib/tests/unit/data/conftest.py | 2 +- .../data/samplers/test_balanced_sampler.py | 5 ++-- .../test_class_incremental_sampler.py | 3 +- lib/tests/unit/data/utils/test_utils.py | 29 ------------------- 12 files changed, 30 insertions(+), 71 deletions(-) diff --git a/lib/src/otx/data/dataset/base.py b/lib/src/otx/data/dataset/base.py index 501114f4fc6..5b6366ec64b 100644 --- a/lib/src/otx/data/dataset/base.py +++ b/lib/src/otx/data/dataset/base.py @@ -6,13 +6,14 @@ from __future__ import annotations from abc import abstractmethod +from collections import defaultdict from collections.abc import Iterable from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union import cv2 import numpy as np -from datumaro.components.annotation import AnnotationType +from datumaro.components.annotation import AnnotationType, LabelCategories from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel from torch.utils.data import Dataset @@ -196,3 +197,18 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None: def collate_fn(self) -> Callable: """Collection function to collect KeypointDetDataEntity into KeypointDetBatchDataEntity in data loader.""" return OTXDataItem.collate_fn + + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]: + """Compute class statistics.""" + stats: dict[int | str, list[int]] = defaultdict(list) + for item_idx, item in enumerate(self.dm_subset): + for 
ann in item.annotations: + if use_string_label: + labels = self.dm_subset.categories().get(AnnotationType.label, LabelCategories()) + stats[labels.items[ann.label].name].append(item_idx) + else: + stats[ann.label].append(item_idx) + # Remove duplicates in label stats idx: O(n) + for k in stats: + stats[k] = list(dict.fromkeys(stats[k])) + return stats diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py index 51b8e8d1b5e..8026aaeeb98 100644 --- a/lib/src/otx/data/dataset/base_new.py +++ b/lib/src/otx/data/dataset/base_new.py @@ -157,5 +157,5 @@ def collate_fn(self) -> Callable: return _default_collate_fn @abc.abstractmethod - def get_idx_list_per_classes(self) -> dict[int, list[int]]: + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: """Get a dictionary with class labels as keys and lists of corresponding sample indices as values.""" diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index 98771f12075..370929a1b3e 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -21,12 +21,14 @@ def __init__(self, **kwargs) -> None: kwargs["sample_type"] = ClassificationSample super().__init__(**kwargs) - def get_idx_list_per_classes(self) -> dict[int, list[int]]: + def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: """Get index list per class.""" idx_list_per_classes: dict[int, list[int]] = {} for idx in range(len(self)): item = self.dm_subset[idx] label_id = item.label.item() + if use_string_label: + label_id = self.label_info.labels[label_id] if label_id not in idx_list_per_classes: idx_list_per_classes[label_id] = [] idx_list_per_classes[label_id].append(idx) diff --git a/lib/src/otx/data/factory.py b/lib/src/otx/data/factory.py index d854a6a4eb6..e5d1b9f9315 100644 --- a/lib/src/otx/data/factory.py +++ b/lib/src/otx/data/factory.py @@ 
-82,13 +82,13 @@ def create( categories = cls._get_label_categories(dm_subset, data_format) dataset = DatasetNew(ClassificationSample, categories={"label": categories}) for item in dm_subset: - if len(item.media.data.shape) == 3: # TODO: Account for grayscale images + if len(item.media.data.shape) == 3: # TODO(albert): Account for grayscale images dataset.append(ClassificationSample.from_dm_item(item)) common_kwargs["dm_subset"] = dataset return OTXMulticlassClsDataset(**common_kwargs) if task == OTXTaskType.MULTI_LABEL_CLS: - from otx.data.dataset.classification import OTXMultilabelClsDataset + from .dataset.classification import OTXMultilabelClsDataset return OTXMultilabelClsDataset(**common_kwargs) diff --git a/lib/src/otx/data/samplers/balanced_sampler.py b/lib/src/otx/data/samplers/balanced_sampler.py index 12513ca1baf..1cef96ac694 100644 --- a/lib/src/otx/data/samplers/balanced_sampler.py +++ b/lib/src/otx/data/samplers/balanced_sampler.py @@ -9,12 +9,8 @@ from typing import TYPE_CHECKING import torch -from datumaro import DatasetSubset -from datumaro.experimental import Dataset as NewDataset from torch.utils.data import Sampler -from otx.data.utils import get_idx_list_per_classes - if TYPE_CHECKING: from otx.data.dataset.base import OTXDataset from otx.data.dataset.base_new import OTXDataset as OTXDatasetNew @@ -64,11 +60,7 @@ def __init__( super().__init__(dataset) # img_indices: dict[label: list[idx]] - ann_stats: dict[int | str, list[int]] - if isinstance(dataset.dm_subset, DatasetSubset): - ann_stats = get_idx_list_per_classes(dataset.dm_subset) - elif isinstance(dataset.dm_subset, NewDataset): - ann_stats = dataset.get_idx_list_per_classes() # type: ignore[attr-defined] + ann_stats = dataset.get_idx_list_per_classes() self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0} self.num_cls = len(self.img_indices.keys()) diff --git a/lib/src/otx/data/samplers/class_incremental_sampler.py 
b/lib/src/otx/data/samplers/class_incremental_sampler.py index 05e6f653754..12322943cc6 100644 --- a/lib/src/otx/data/samplers/class_incremental_sampler.py +++ b/lib/src/otx/data/samplers/class_incremental_sampler.py @@ -12,7 +12,6 @@ from torch.utils.data import Sampler from otx.data.dataset.base import OTXDataset -from otx.data.utils import get_idx_list_per_classes class ClassIncrementalSampler(Sampler): @@ -65,7 +64,7 @@ def __init__( super().__init__(dataset) # Need to split new classes dataset indices & old classses dataset indices - ann_stats = get_idx_list_per_classes(dataset.dm_subset, True) + ann_stats = dataset.get_idx_list_per_classes(True) new_indices, old_indices = [], [] for cls in new_classes: new_indices.extend(ann_stats[cls]) diff --git a/lib/src/otx/data/utils/__init__.py b/lib/src/otx/data/utils/__init__.py index bc2ed250b89..31242128b20 100644 --- a/lib/src/otx/data/utils/__init__.py +++ b/lib/src/otx/data/utils/__init__.py @@ -7,7 +7,6 @@ adapt_input_size_to_dataset, adapt_tile_config, get_adaptive_num_workers, - get_idx_list_per_classes, import_object_from_module, instantiate_sampler, ) @@ -17,6 +16,5 @@ "adapt_input_size_to_dataset", "instantiate_sampler", "get_adaptive_num_workers", - "get_idx_list_per_classes", "import_object_from_module", ] diff --git a/lib/src/otx/data/utils/utils.py b/lib/src/otx/data/utils/utils.py index 94332be8fd6..1d2c5eeb2f4 100644 --- a/lib/src/otx/data/utils/utils.py +++ b/lib/src/otx/data/utils/utils.py @@ -15,14 +15,13 @@ import cv2 import numpy as np import torch -from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon +from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, Polygon from datumaro.components.annotation import Shape as _Shape from otx.types import OTXTaskType from otx.utils.device import is_xpu_available if TYPE_CHECKING: - from datumaro import Dataset as DmDataset from datumaro import DatasetSubset from torch.utils.data 
import Dataset, Sampler @@ -322,22 +321,6 @@ def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None: return min(cpu_count() // (num_dataloader * num_devices), 8) # max available num_workers is 8 -def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]: - """Compute class statistics.""" - stats: dict[int | str, list[int]] = defaultdict(list) - for item_idx, item in enumerate(dm_dataset): - for ann in item.annotations: - if use_string_label: - labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories()) - stats[labels.items[ann.label].name].append(item_idx) - else: - stats[ann.label].append(item_idx) - # Remove duplicates in label stats idx: O(n) - for k in stats: - stats[k] = list(dict.fromkeys(stats[k])) - return stats - - def import_object_from_module(obj_path: str) -> Any: # noqa: ANN401 """Get object from import format string.""" module_name, obj_name = obj_path.rsplit(".", 1) diff --git a/lib/tests/unit/data/conftest.py b/lib/tests/unit/data/conftest.py index 6ef37c9fede..2932bc055d1 100644 --- a/lib/tests/unit/data/conftest.py +++ b/lib/tests/unit/data/conftest.py @@ -41,7 +41,7 @@ @pytest.fixture(params=["bytes", "file"]) -def fxt_dm_item(requeset, tmpdir) -> DatasetItem: +def fxt_dm_item(request, tmpdir) -> DatasetItem: np_img = np.zeros(shape=(10, 10, 3), dtype=np.uint8) np_img[:, :, 0] = 0 # Set 0 for B channel np_img[:, :, 1] = 1 # Set 1 for G channel diff --git a/lib/tests/unit/data/samplers/test_balanced_sampler.py b/lib/tests/unit/data/samplers/test_balanced_sampler.py index 43b8810c3bf..768fad8ef48 100644 --- a/lib/tests/unit/data/samplers/test_balanced_sampler.py +++ b/lib/tests/unit/data/samplers/test_balanced_sampler.py @@ -12,7 +12,6 @@ from otx.data.dataset.base import OTXDataset from otx.data.samplers.balanced_sampler import BalancedSampler -from otx.data.utils import get_idx_list_per_classes @pytest.fixture() @@ -81,7 +80,7 @@ def 
test_sampler_iter_with_multiple_replicas(self, fxt_imbalanced_dataset): def test_compute_class_statistics(self, fxt_imbalanced_dataset): # Compute class statistics - stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset) + stats = fxt_imbalanced_dataset.get_idx_list_per_classes() # Check the expected results assert stats == {0: list(range(100)), 1: list(range(100, 108))} @@ -90,7 +89,7 @@ def test_sampler_iter_per_class(self, fxt_imbalanced_dataset): batch_size = 4 sampler = BalancedSampler(fxt_imbalanced_dataset) - stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset) + stats = fxt_imbalanced_dataset.get_idx_list_per_classes() class_0_idx = stats[0] class_1_idx = stats[1] list_iter = list(iter(sampler)) diff --git a/lib/tests/unit/data/samplers/test_class_incremental_sampler.py b/lib/tests/unit/data/samplers/test_class_incremental_sampler.py index cd2f34b8e53..f031f58265b 100644 --- a/lib/tests/unit/data/samplers/test_class_incremental_sampler.py +++ b/lib/tests/unit/data/samplers/test_class_incremental_sampler.py @@ -10,7 +10,6 @@ from otx.data.dataset.base import OTXDataset from otx.data.samplers.class_incremental_sampler import ClassIncrementalSampler -from otx.data.utils import get_idx_list_per_classes @pytest.fixture() @@ -107,7 +106,7 @@ def test_sampler_iter_per_class(self, fxt_old_new_dataset): new_classes=["2"], ) - stats = get_idx_list_per_classes(fxt_old_new_dataset.dm_subset, True) + stats = fxt_old_new_dataset.get_idx_list_per_classes(True) old_idx = stats["0"] + stats["1"] new_idx = stats["2"] list_iter = list(iter(sampler)) diff --git a/lib/tests/unit/data/utils/test_utils.py b/lib/tests/unit/data/utils/test_utils.py index 69d2b837f37..79cfaefe199 100644 --- a/lib/tests/unit/data/utils/test_utils.py +++ b/lib/tests/unit/data/utils/test_utils.py @@ -5,7 +5,6 @@ from __future__ import annotations -from collections import defaultdict from unittest.mock import MagicMock import cv2 @@ -23,8 +22,6 @@ 
compute_robust_scale_statistics, compute_robust_statistics, get_adaptive_num_workers, - get_idx_list_per_classes, - import_object_from_module, ) @@ -239,29 +236,3 @@ def fxt_dm_dataset() -> DmDataset: ] return DmDataset.from_iterable(dataset_items, categories=["0", "1"]) - - -def test_get_idx_list_per_classes(fxt_dm_dataset): - # Call the function under test - result = get_idx_list_per_classes(fxt_dm_dataset) - - # Assert the expected output - expected_result = defaultdict(list) - expected_result[0] = list(range(100)) - expected_result[1] = list(range(100, 108)) - assert result == expected_result - - # Call the function under test with use_string_label - result = get_idx_list_per_classes(fxt_dm_dataset, use_string_label=True) - - # Assert the expected output - expected_result = defaultdict(list) - expected_result["0"] = list(range(100)) - expected_result["1"] = list(range(100, 108)) - assert result == expected_result - - -def test_import_object_from_module(): - obj_path = "otx.data.utils.get_idx_list_per_classes" - obj = import_object_from_module(obj_path) - assert obj == get_idx_list_per_classes From 121bba7523300bc4dad894a9285c2708366e38b0 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 10:19:34 +0200 Subject: [PATCH 12/19] Raise error if dataset isn't of correct type. Signed-off-by: Albert van Houten --- lib/src/otx/data/factory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lib/src/otx/data/factory.py b/lib/src/otx/data/factory.py index e5d1b9f9315..89daf65c2e1 100644 --- a/lib/src/otx/data/factory.py +++ b/lib/src/otx/data/factory.py @@ -85,6 +85,9 @@ def create( if len(item.media.data.shape) == 3: # TODO(albert): Account for grayscale images dataset.append(ClassificationSample.from_dm_item(item)) common_kwargs["dm_subset"] = dataset + else: + msg = "Dataset must be of type DmDataset."
+ raise RuntimeError(msg) return OTXMulticlassClsDataset(**common_kwargs) if task == OTXTaskType.MULTI_LABEL_CLS: From 6fb07d896a5ec17182fee46e7dfe017654b740b6 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 11:35:23 +0200 Subject: [PATCH 13/19] Apply suggestion from @leoll2 Co-authored-by: Leonardo Lai --- lib/src/otx/data/samplers/class_incremental_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/src/otx/data/samplers/class_incremental_sampler.py b/lib/src/otx/data/samplers/class_incremental_sampler.py index 12322943cc6..68d0f2ee8d0 100644 --- a/lib/src/otx/data/samplers/class_incremental_sampler.py +++ b/lib/src/otx/data/samplers/class_incremental_sampler.py @@ -64,7 +64,7 @@ def __init__( super().__init__(dataset) # Need to split new classes dataset indices & old classses dataset indices - ann_stats = dataset.get_idx_list_per_classes(True) + ann_stats = dataset.get_idx_list_per_classes(use_string_label=True) new_indices, old_indices = [], [] for cls in new_classes: new_indices.extend(ann_stats[cls]) From 069474fb52f90f99a1aa8e8ea4c6c6589b111ec2 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 11:49:12 +0200 Subject: [PATCH 14/19] Update Copyright headers and add function doc Signed-off-by: Albert van Houten --- lib/src/otx/data/dataset/base.py | 2 +- lib/src/otx/data/dataset/base_new.py | 2 +- lib/src/otx/data/dataset/classification_new.py | 2 +- lib/src/otx/data/entity/sample.py | 2 +- tests/unit/data/dataset/test_base_new.py | 2 +- tests/unit/data/dataset/test_classification_new.py | 2 +- tests/unit/data/entity/test_sample.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/src/otx/data/dataset/base.py b/lib/src/otx/data/dataset/base.py index 5b6366ec64b..1536e187cbd 100644 --- a/lib/src/otx/data/dataset/base.py +++ b/lib/src/otx/data/dataset/base.py @@ -199,7 +199,7 @@ def collate_fn(self) -> Callable: return OTXDataItem.collate_fn def 
get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]: - """Compute class statistics.""" + """Get a dictionary with class labels (string/int) as keys and lists of sample indices as values.""" stats: dict[int | str, list[int]] = defaultdict(list) for item_idx, item in enumerate(self.dm_subset): for ann in item.annotations: diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py index 8026aaeeb98..d7583aab3df 100644 --- a/lib/src/otx/data/dataset/base_new.py +++ b/lib/src/otx/data/dataset/base_new.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Base class for OTXDataset using new Datumaro experimental Dataset.""" diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index 370929a1b3e..e9c6c03447c 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Module for OTXClassificationDatasets using new Datumaro experimental Dataset.""" diff --git a/lib/src/otx/data/entity/sample.py b/lib/src/otx/data/entity/sample.py index 079763ca565..d11ee417e85 100644 --- a/lib/src/otx/data/entity/sample.py +++ b/lib/src/otx/data/entity/sample.py @@ -1,4 +1,4 @@ -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Sample classes for OTX data entities.""" diff --git a/tests/unit/data/dataset/test_base_new.py b/tests/unit/data/dataset/test_base_new.py index 3ab2fcce494..f9efa7481ae 100644 --- a/tests/unit/data/dataset/test_base_new.py +++ b/tests/unit/data/dataset/test_base_new.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: 
Apache-2.0 """Unit tests for base_new OTXDataset.""" diff --git a/tests/unit/data/dataset/test_classification_new.py b/tests/unit/data/dataset/test_classification_new.py index 2394c966d17..9224e7bd7ac 100644 --- a/tests/unit/data/dataset/test_classification_new.py +++ b/tests/unit/data/dataset/test_classification_new.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Unit tests for classification_new dataset.""" diff --git a/tests/unit/data/entity/test_sample.py b/tests/unit/data/entity/test_sample.py index 4c28d57bbf5..51b075e9a5e 100644 --- a/tests/unit/data/entity/test_sample.py +++ b/tests/unit/data/entity/test_sample.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 """Unit tests for sample entity classes.""" From 9ab6b7aa94f36454772b90c717fd01e5899d9e88 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 13:42:24 +0200 Subject: [PATCH 15/19] Pin numpy version Signed-off-by: Albert van Houten --- lib/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/pyproject.toml b/lib/pyproject.toml index 2897650ac6e..5cec04064c2 100644 --- a/lib/pyproject.toml +++ b/lib/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "onnxconverter-common==1.14.0", "nncf==2.17.0", "anomalib[core]==1.1.3", + "numpy<2.0.0", ] [project.optional-dependencies] From 543963ebcd6112cb3cf02dc588ef4dd0d7bfb012 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Wed, 3 Sep 2025 14:00:12 +0200 Subject: [PATCH 16/19] Fix failing test Signed-off-by: Albert van Houten --- lib/tests/test_helpers.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lib/tests/test_helpers.py b/lib/tests/test_helpers.py index 313b6f06665..faed389f873 100644 --- a/lib/tests/test_helpers.py +++ b/lib/tests/test_helpers.py @@ -17,9 +17,6 @@ from datumaro.components.errors import 
MediaTypeError from datumaro.components.exporter import Exporter from datumaro.components.media import Image -from datumaro.plugins.data_formats.common_semantic_segmentation import ( - CommonSemanticSegmentationPath, -) from datumaro.util.definitions import DEFAULT_SUBSET_NAME from datumaro.util.image import save_image from datumaro.util.meta_file_util import save_meta_file @@ -122,8 +119,8 @@ def _apply_impl(self) -> None: subset_dir = Path(save_dir, _subset_name) subset_dir.mkdir(parents=True, exist_ok=True) - mask_dir = subset_dir / CommonSemanticSegmentationPath.MASKS_DIR - img_dir = subset_dir / CommonSemanticSegmentationPath.IMAGES_DIR + mask_dir = subset_dir / "masks" + img_dir = subset_dir / "images" for item in subset: self._export_item_annotation(item, mask_dir) if self._save_media: From 86e4acd2d6cafb45264dc2e4f99e193a4e1bd7ee Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 4 Sep 2025 08:49:28 +0200 Subject: [PATCH 17/19] Create LabelInfo class in the dataset and remove dispatching label info step for the new categories class. 
Signed-off-by: Albert van Houten --- lib/src/otx/backend/native/models/base.py | 9 --------- lib/src/otx/data/dataset/base_new.py | 11 +++++++++-- lib/src/otx/data/dataset/classification_new.py | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/lib/src/otx/backend/native/models/base.py b/lib/src/otx/backend/native/models/base.py index 2efb1552577..8cc10512fe6 100644 --- a/lib/src/otx/backend/native/models/base.py +++ b/lib/src/otx/backend/native/models/base.py @@ -16,7 +16,6 @@ import torch from datumaro import LabelCategories -from datumaro.experimental.categories import LabelCategories as NewLabelCategories from lightning import LightningModule, Trainer from torch import Tensor, nn from torch.optim.lr_scheduler import ConstantLR @@ -885,14 +884,6 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo: label_groups=[label_info], label_ids=[str(i) for i in range(len(label_info))], ) - if isinstance(label_info, NewLabelCategories): - # Handle LabelCategories objects from datumaro - labels = label_info.labels - return LabelInfo( - label_names=labels, - label_groups=[labels], - label_ids=[str(i) for i in range(len(labels))], - ) if isinstance(label_info, LabelInfo): if not hasattr(label_info, "label_ids"): # NOTE: This is for backward compatibility diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py index d7583aab3df..c91b55929bb 100644 --- a/lib/src/otx/data/dataset/base_new.py +++ b/lib/src/otx/data/dataset/base_new.py @@ -12,6 +12,8 @@ import torch from torch.utils.data import Dataset as TorchDataset +from otx import LabelInfo, NullLabelInfo + if TYPE_CHECKING: from datumaro.experimental import Dataset @@ -99,9 +101,14 @@ def __init__( and hasattr(dm_subset.schema, "attributes") and "label" in dm_subset.schema.attributes ): - self.label_info = dm_subset.schema.attributes["label"].categories + labels = dm_subset.schema.attributes["label"].categories.labels + self.label_info = LabelInfo( + 
label_names=labels, + label_groups=[labels], + label_ids=[str(i) for i in range(len(labels))], + ) else: - self.label_info = None + self.label_info = NullLabelInfo() self.dm_subset = dm_subset def __len__(self) -> int: diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index e9c6c03447c..a39f3092f81 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -28,7 +28,7 @@ def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, item = self.dm_subset[idx] label_id = item.label.item() if use_string_label: - label_id = self.label_info.labels[label_id] + label_id = self.label_info.label_names[label_id] if label_id not in idx_list_per_classes: idx_list_per_classes[label_id] = [] idx_list_per_classes[label_id].append(idx) From d089db913bf7afeda1726c4d350c1c3a801c82b4 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 4 Sep 2025 09:00:13 +0200 Subject: [PATCH 18/19] Update function doc Signed-off-by: Albert van Houten --- lib/src/otx/data/dataset/base.py | 7 ++++++- lib/src/otx/data/dataset/classification_new.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/lib/src/otx/data/dataset/base.py b/lib/src/otx/data/dataset/base.py index 1536e187cbd..4f7146583b9 100644 --- a/lib/src/otx/data/dataset/base.py +++ b/lib/src/otx/data/dataset/base.py @@ -199,7 +199,12 @@ def collate_fn(self) -> Callable: return OTXDataItem.collate_fn def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]: - """Get a dictionary with class labels (string/int) as keys and lists of sample indices as values.""" + """Get a dictionary mapping class labels (string or int) to lists of samples. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. 
+ """ stats: dict[int | str, list[int]] = defaultdict(list) for item_idx, item in enumerate(self.dm_subset): for ann in item.annotations: diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py index a39f3092f81..cb1c6569151 100644 --- a/lib/src/otx/data/dataset/classification_new.py +++ b/lib/src/otx/data/dataset/classification_new.py @@ -22,7 +22,12 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]: - """Get index list per class.""" + """Get a dictionary mapping class labels (string or int) to lists of sample indices. + + Args: + use_string_label (bool): If True, use string class labels as keys. + If False, use integer indices as keys. + """ idx_list_per_classes: dict[int, list[int]] = {} for idx in range(len(self)): item = self.dm_subset[idx] From b22fcc376b320d3fe955c5363eb145b43a060ea7 Mon Sep 17 00:00:00 2001 From: Albert van Houten Date: Thu, 4 Sep 2025 09:29:27 +0200 Subject: [PATCH 19/19] Fix failing tests after LabelInfo change Signed-off-by: Albert van Houten --- tests/unit/data/dataset/test_base_new.py | 2 ++ tests/unit/data/dataset/test_classification_new.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/unit/data/dataset/test_base_new.py b/tests/unit/data/dataset/test_base_new.py index f9efa7481ae..e3e18a9e846 100644 --- a/tests/unit/data/dataset/test_base_new.py +++ b/tests/unit/data/dataset/test_base_new.py @@ -91,6 +91,8 @@ def setup_method(self): mock_schema = Mock() mock_attributes = {"label": Mock()} mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] mock_schema.attributes = mock_attributes self.mock_dm_subset.schema = mock_schema diff --git a/tests/unit/data/dataset/test_classification_new.py b/tests/unit/data/dataset/test_classification_new.py
index 9224e7bd7ac..eaf7b30373d 100644 --- a/tests/unit/data/dataset/test_classification_new.py +++ b/tests/unit/data/dataset/test_classification_new.py @@ -25,6 +25,8 @@ def setup_method(self): mock_schema = Mock() mock_attributes = {"label": Mock()} mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] mock_schema.attributes = mock_attributes self.mock_dm_subset.schema = mock_schema