diff --git a/lib/pyproject.toml b/lib/pyproject.toml
index a2d497e13b4..5cec04064c2 100644
--- a/lib/pyproject.toml
+++ b/lib/pyproject.toml
@@ -27,7 +27,7 @@ classifiers = [
     "Programming Language :: Python :: 3.12",
 ]
 dependencies = [
-    "datumaro==1.10.0",
+    "datumaro @ git+https://github.com/open-edge-platform/datumaro.git@develop",
     "omegaconf==2.3.0",
     "rich==14.0.0",
     "jsonargparse==4.35.0",
@@ -51,6 +51,7 @@ dependencies = [
     "onnxconverter-common==1.14.0",
     "nncf==2.17.0",
     "anomalib[core]==1.1.3",
+    "numpy<2.0.0",
 ]
 
 [project.optional-dependencies]
diff --git a/lib/src/otx/backend/native/callbacks/gpu_mem_monitor.py b/lib/src/otx/backend/native/callbacks/gpu_mem_monitor.py
index 4d7d6388107..dcea0d5b36c 100644
--- a/lib/src/otx/backend/native/callbacks/gpu_mem_monitor.py
+++ b/lib/src/otx/backend/native/callbacks/gpu_mem_monitor.py
@@ -29,7 +29,7 @@ def _get_and_log_device_stats(
             batch_size (int): batch size.
         """
         device = trainer.strategy.root_device
-        if device.type in ["cpu", "xpu"]:
+        if device.type in ["cpu", "xpu", "mps"]:
             return
 
         device_stats = trainer.accelerator.get_device_stats(device)
diff --git a/lib/src/otx/data/dataset/base.py b/lib/src/otx/data/dataset/base.py
index 501114f4fc6..4f7146583b9 100644
--- a/lib/src/otx/data/dataset/base.py
+++ b/lib/src/otx/data/dataset/base.py
@@ -6,13 +6,14 @@
 from __future__ import annotations
 
 from abc import abstractmethod
+from collections import defaultdict
 from collections.abc import Iterable
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Union
 
 import cv2
 import numpy as np
-from datumaro.components.annotation import AnnotationType
+from datumaro.components.annotation import AnnotationType, LabelCategories
 from datumaro.util.image import IMAGE_BACKEND, IMAGE_COLOR_CHANNEL, ImageBackend
 from datumaro.util.image import ImageColorChannel as DatumaroImageColorChannel
 from torch.utils.data import Dataset
@@ -196,3 +197,23 @@ def _get_item_impl(self, idx: int) -> OTXDataItem | None:
     def collate_fn(self) -> Callable:
         """Collection function to collect KeypointDetDataEntity into KeypointDetBatchDataEntity in data loader."""
         return OTXDataItem.collate_fn
+
+    def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int | str, list[int]]:
+        """Get a dictionary mapping class labels (string or int) to lists of sample indices.
+
+        Args:
+            use_string_label (bool): If True, use string class labels as keys.
+                If False, use integer indices as keys.
+ """ + stats: dict[int | str, list[int]] = defaultdict(list) + for item_idx, item in enumerate(self.dm_subset): + for ann in item.annotations: + if use_string_label: + labels = self.dm_subset.categories().get(AnnotationType.label, LabelCategories()) + stats[labels.items[ann.label].name].append(item_idx) + else: + stats[ann.label].append(item_idx) + # Remove duplicates in label stats idx: O(n) + for k in stats: + stats[k] = list(dict.fromkeys(stats[k])) + return stats diff --git a/lib/src/otx/data/dataset/base_new.py b/lib/src/otx/data/dataset/base_new.py new file mode 100644 index 00000000000..c91b55929bb --- /dev/null +++ b/lib/src/otx/data/dataset/base_new.py @@ -0,0 +1,168 @@ +# Copyright (C) 2023-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Base class for OTXDataset using new Datumaro experimental Dataset.""" + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Callable, Iterable, List, Union + +import numpy as np +import torch +from torch.utils.data import Dataset as TorchDataset + +from otx import LabelInfo, NullLabelInfo + +if TYPE_CHECKING: + from datumaro.experimental import Dataset + +from otx.data.entity.sample import OTXSample +from otx.data.entity.torch.torch import OTXDataBatch +from otx.data.transform_libs.torchvision import Compose +from otx.types.image import ImageColorChannel + +Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] + + +def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch: + """Collate OTXSample items into an OTXDataBatch. + + Args: + items: List of OTXSample items to batch + Returns: + Batched OTXSample items with stacked tensors + """ + # Convert images to float32 tensors before stacking + image_tensors = [] + for item in items: + img = item.image + if isinstance(img, torch.Tensor): + # Convert to float32 if not already + if img.dtype != torch.float32: + img = img.float() + else: + # Convert numpy array to float32 tensor + img = torch.from_numpy(img).float() + image_tensors.append(img) + + # Try to stack images if they have the same shape + if len(image_tensors) > 0 and all(t.shape == image_tensors[0].shape for t in image_tensors): + images = torch.stack(image_tensors) + else: + images = image_tensors + + return OTXDataBatch( + batch_size=len(items), + images=images, + labels=[item.label for item in items] if items[0].label is not None else None, + masks=[item.masks for item in items] if any(item.masks is not None for item in items) else None, + bboxes=[item.bboxes for item in items] if any(item.bboxes is not None for item in items) else None, + keypoints=[item.keypoints for item in items] if any(item.keypoints is not None for item in items) else None, + polygons=[item.polygons for item in items if item.polygons is not None] + if any(item.polygons is not None for item in items) + else None, + imgs_info=[item.img_info for item in items] if any(item.img_info is not None for item in items) else None, + ) + + +class OTXDataset(TorchDataset): + """Base OTXDataset using new Datumaro experimental Dataset. + + Defines basic logic for OTX datasets. + + Args: + transforms: Transforms to apply on images + image_color_channel: Color channel of images + stack_images: Whether or not to stack images in collate function in OTXBatchData entity. 
+        sample_type: Type of sample to use for this dataset
+    """
+
+    def __init__(
+        self,
+        dm_subset: Dataset,
+        transforms: Transforms,
+        max_refetch: int = 1000,
+        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
+        stack_images: bool = True,
+        to_tv_image: bool = True,
+        data_format: str = "",
+        sample_type: type[OTXSample] = OTXSample,
+    ) -> None:
+        self.transforms = transforms
+        self.image_color_channel = image_color_channel
+        self.stack_images = stack_images
+        self.to_tv_image = to_tv_image
+        self.sample_type = sample_type
+        self.max_refetch = max_refetch
+        self.data_format = data_format
+        if (
+            hasattr(dm_subset, "schema")
+            and hasattr(dm_subset.schema, "attributes")
+            and "label" in dm_subset.schema.attributes
+        ):
+            labels = dm_subset.schema.attributes["label"].categories.labels
+            self.label_info = LabelInfo(
+                label_names=labels,
+                label_groups=[labels],
+                label_ids=[str(i) for i in range(len(labels))],
+            )
+        else:
+            self.label_info = NullLabelInfo()
+        self.dm_subset = dm_subset
+
+    def __len__(self) -> int:
+        return len(self.dm_subset)
+
+    def _sample_another_idx(self) -> int:
+        return np.random.randint(0, len(self))
+
+    def _apply_transforms(self, entity: OTXSample) -> OTXSample | None:
+        if isinstance(self.transforms, Compose):
+            if self.to_tv_image:
+                entity.as_tv_image()
+            return self.transforms(entity)
+        if isinstance(self.transforms, Iterable):
+            return self._iterable_transforms(entity)
+        if callable(self.transforms):
+            return self.transforms(entity)
+        return None
+
+    def _iterable_transforms(self, item: OTXSample) -> OTXSample | None:
+        if not isinstance(self.transforms, list):
+            raise TypeError(item)
+
+        results = item
+        for transform in self.transforms:
+            results = transform(results)
+            # MMCV transform can produce None. Please see
+            # https://github.com/open-mmlab/mmengine/blob/26f22ed283ae4ac3a24b756809e5961efe6f9da8/mmengine/dataset/base_dataset.py#L59-L66
+            if results is None:
+                return None
+
+        return results
+
+    def __getitem__(self, index: int) -> OTXSample:
+        for _ in range(self.max_refetch):
+            results = self._get_item_impl(index)
+
+            if results is not None:
+                return results
+
+            index = self._sample_another_idx()
+
+        msg = f"Reach the maximum refetch number ({self.max_refetch})"
+        raise RuntimeError(msg)
+
+    def _get_item_impl(self, index: int) -> OTXSample | None:
+        dm_item = self.dm_subset[index]
+        return self._apply_transforms(dm_item)
+
+    @property
+    def collate_fn(self) -> Callable:
+        """Collection function to collect samples into a batch in data loader."""
+        return _default_collate_fn
+
+    @abc.abstractmethod
+    def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
+        """Get a dictionary with class labels as keys and lists of corresponding sample indices as values."""
diff --git a/lib/src/otx/data/dataset/classification_new.py b/lib/src/otx/data/dataset/classification_new.py
new file mode 100644
index 00000000000..cb1c6569151
--- /dev/null
+++ b/lib/src/otx/data/dataset/classification_new.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Module for OTXClassificationDatasets using new Datumaro experimental Dataset."""
+
+from __future__ import annotations
+
+from otx.data.dataset.base_new import OTXDataset
+from otx.data.entity.sample import ClassificationSample
+
+
+class OTXMulticlassClsDataset(OTXDataset):
+    """OTXDataset class for multi-class classification task using new Datumaro experimental Dataset."""
+
+    def __init__(self, **kwargs) -> None:
+        """Initialize OTXMulticlassClsDataset.
+
+        Args:
+            **kwargs: Keyword arguments to pass to OTXDataset
+        """
+        kwargs["sample_type"] = ClassificationSample
+        super().__init__(**kwargs)
+
+    def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
+        """Get a dictionary mapping class labels (string or int) to lists of sample indices.
+
+        Args:
+            use_string_label (bool): If True, use string class labels as keys.
+                If False, use integer indices as keys.
+ """ + idx_list_per_classes: dict[int, list[int]] = {} + for idx in range(len(self)): + item = self.dm_subset[idx] + label_id = item.label.item() + if use_string_label: + label_id = self.label_info.label_names[label_id] + if label_id not in idx_list_per_classes: + idx_list_per_classes[label_id] = [] + idx_list_per_classes[label_id].append(idx) + return idx_list_per_classes diff --git a/lib/src/otx/data/entity/sample.py b/lib/src/otx/data/entity/sample.py new file mode 100644 index 00000000000..d11ee417e85 --- /dev/null +++ b/lib/src/otx/data/entity/sample.py @@ -0,0 +1,118 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Sample classes for OTX data entities.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +import polars as pl +import torch +from datumaro import Mask +from datumaro.components.media import Image +from datumaro.experimental.dataset import Sample +from datumaro.experimental.fields import image_field, label_field +from torchvision import tv_tensors + +from otx.data.entity.base import ImageInfo + +if TYPE_CHECKING: + from datumaro import DatasetItem, Polygon + from torchvision.tv_tensors import BoundingBoxes, Mask + + +class OTXSample(Sample): + """Base class for OTX data samples.""" + + image: np.ndarray | torch.Tensor | tv_tensors.Image | Any + + def as_tv_image(self) -> None: + """Convert image to torchvision tv_tensors Image format.""" + if isinstance(self.image, tv_tensors.Image): + return + if isinstance(self.image, (np.ndarray, torch.Tensor)): + self.image = tv_tensors.Image(self.image) + return + msg = "OTXSample must have an image" + raise ValueError(msg) + + @property + def masks(self) -> Mask | None: + """Get masks for the sample.""" + return None + + @property + def bboxes(self) -> BoundingBoxes | None: + """Get bounding boxes for the sample.""" + return None + + @property + def keypoints(self) -> torch.Tensor | None: + """Get keypoints for the sample.""" + return None + + @property + def polygons(self) -> list[Polygon] | None: + """Get polygons for the sample.""" + return None + + @property + def label(self) -> torch.Tensor | None: + """Optional label property that returns None by default.""" + return None + + @property + def img_info(self) -> ImageInfo | None: + """Get image information for the sample.""" + if getattr(self, "_img_info", None) is None: + image = getattr(self, "image", None) + if image is not None and hasattr(image, "shape") and len(image.shape) == 3: + img_shape = image.shape[:2] + else: + return None + self._img_info = ImageInfo( + img_idx=0, + img_shape=img_shape, + ori_shape=img_shape, + ) + return self._img_info + + @img_info.setter + def img_info(self, value: ImageInfo) -> None: + self._img_info = value + + +class ClassificationSample(OTXSample): + """OTXDataItemSample is a base class for OTX data items.""" + + image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8) + label: torch.Tensor = label_field(pl.Int32()) + + @classmethod + def from_dm_item(cls, item: DatasetItem) -> ClassificationSample: + """Create a ClassificationSample from a Datumaro DatasetItem. 
+
+        Args:
+            item: Datumaro DatasetItem containing image and label
+
+        Returns:
+            ClassificationSample: Instance with image and label set
+        """
+        image = item.media_as(Image).data
+        label = item.annotations[0].label if item.annotations else None
+
+        img_shape = image.shape[:2]
+        img_info = ImageInfo(
+            img_idx=0,
+            img_shape=img_shape,
+            ori_shape=img_shape,
+        )
+
+        sample = cls(
+            image=image,
+            label=torch.as_tensor(label, dtype=torch.long) if label is not None else torch.tensor(-1, dtype=torch.long),
+        )
+        sample.img_info = img_info
+        return sample
diff --git a/lib/src/otx/data/factory.py b/lib/src/otx/data/factory.py
index 7f601c4e69d..89daf65c2e1 100644
--- a/lib/src/otx/data/factory.py
+++ b/lib/src/otx/data/factory.py
@@ -7,15 +7,20 @@
 
 from typing import TYPE_CHECKING
 
+from datumaro.components.annotation import AnnotationType
+from datumaro.components.dataset import Dataset as DmDataset
+from datumaro.experimental import Dataset as DatasetNew
+from datumaro.experimental.categories import LabelCategories
+
+from otx import LabelInfo, NullLabelInfo
 from otx.types.image import ImageColorChannel
 from otx.types.task import OTXTaskType
 from otx.types.transformer_libs import TransformLibType
 
 from .dataset.base import OTXDataset, Transforms
+from .dataset.base_new import OTXDataset as OTXDatasetNew
 
 if TYPE_CHECKING:
-    from datumaro import Dataset as DmDataset
-
     from otx.config.data import SubsetConfig
 
 
@@ -41,15 +46,15 @@ class OTXDatasetFactory:
 
     @classmethod
     def create(
-        cls: type[OTXDatasetFactory],
+        cls,
         task: OTXTaskType,
-        dm_subset: DmDataset,
+        dm_subset: DmDataset | DatasetNew,
         cfg_subset: SubsetConfig,
         data_format: str,
         image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
         include_polygons: bool = False,
         ignore_index: int = 255,
-    ) -> OTXDataset:
+    ) -> OTXDataset | OTXDatasetNew:
         """Create OTXDataset."""
         transforms = TransformLibFactory.generate(cfg_subset)
         common_kwargs = {
@@ -71,8 +76,18 @@ def create(
             return OTXAnomalyDataset(task_type=task, **common_kwargs)
 
         if task == OTXTaskType.MULTI_CLASS_CLS:
-            from .dataset.classification import OTXMulticlassClsDataset
-
+            from .dataset.classification_new import ClassificationSample, OTXMulticlassClsDataset
+
+            if isinstance(dm_subset, DmDataset):
+                categories = cls._get_label_categories(dm_subset, data_format)
+                dataset = DatasetNew(ClassificationSample, categories={"label": categories})
+                for item in dm_subset:
+                    if len(item.media.data.shape) == 3:  # TODO(albert): Account for grayscale images
+                        dataset.append(ClassificationSample.from_dm_item(item))
+                common_kwargs["dm_subset"] = dataset
+            else:
+                msg = "Dataset must be of type DmDataset."
+                raise RuntimeError(msg)
             return OTXMulticlassClsDataset(**common_kwargs)
 
         if task == OTXTaskType.MULTI_LABEL_CLS:
@@ -106,3 +121,13 @@ def create(
             return OTXKeypointDetectionDataset(**common_kwargs)
 
         raise NotImplementedError(task)
+
+    @staticmethod
+    def _get_label_categories(dm_subset: DmDataset, data_format: str) -> LabelCategories:
+        if dm_subset.categories() and data_format == "arrow":
+            label_info = LabelInfo.from_dm_label_groups_arrow(dm_subset.categories()[AnnotationType.label])
+        elif dm_subset.categories():
+            label_info = LabelInfo.from_dm_label_groups(dm_subset.categories()[AnnotationType.label])
+        else:
+            label_info = NullLabelInfo()
+        return LabelCategories(labels=label_info.label_names)
diff --git a/lib/src/otx/data/samplers/balanced_sampler.py b/lib/src/otx/data/samplers/balanced_sampler.py
index 43bc11fae0b..1cef96ac694 100644
--- a/lib/src/otx/data/samplers/balanced_sampler.py
+++ b/lib/src/otx/data/samplers/balanced_sampler.py
@@ -11,10 +11,9 @@
 import torch
 from torch.utils.data import Sampler
 
-from otx.data.utils import get_idx_list_per_classes
-
 if TYPE_CHECKING:
     from otx.data.dataset.base import OTXDataset
+    from otx.data.dataset.base_new import OTXDataset as OTXDatasetNew
 
 
 class BalancedSampler(Sampler):
@@ -43,7 +42,7 @@ class BalancedSampler(Sampler):
 
     def __init__(
        self,
-        dataset: OTXDataset,
+        dataset: OTXDataset | OTXDatasetNew,
         efficient_mode: bool = False,
         num_replicas: int = 1,
         rank: int = 0,
@@ -61,7 +60,8 @@ def __init__(
         super().__init__(dataset)
 
         # img_indices: dict[label: list[idx]]
-        ann_stats = get_idx_list_per_classes(dataset.dm_subset)
+        ann_stats = dataset.get_idx_list_per_classes()
+
         self.img_indices = {k: torch.tensor(v, dtype=torch.int64) for k, v in ann_stats.items() if len(v) > 0}
         self.num_cls = len(self.img_indices.keys())
         self.data_length = len(self.dataset)
diff --git a/lib/src/otx/data/samplers/class_incremental_sampler.py b/lib/src/otx/data/samplers/class_incremental_sampler.py
index 05e6f653754..68d0f2ee8d0 100644
--- a/lib/src/otx/data/samplers/class_incremental_sampler.py
+++ b/lib/src/otx/data/samplers/class_incremental_sampler.py
@@ -12,7 +12,6 @@
 from torch.utils.data import Sampler
 
 from otx.data.dataset.base import OTXDataset
-from otx.data.utils import get_idx_list_per_classes
 
 
 class ClassIncrementalSampler(Sampler):
@@ -65,7 +64,7 @@ def __init__(
         super().__init__(dataset)
 
         # Need to split new classes dataset indices & old classses dataset indices
-        ann_stats = get_idx_list_per_classes(dataset.dm_subset, True)
+        ann_stats = dataset.get_idx_list_per_classes(use_string_label=True)
         new_indices, old_indices = [], []
         for cls in new_classes:
             new_indices.extend(ann_stats[cls])
diff --git a/lib/src/otx/data/transform_libs/utils.py b/lib/src/otx/data/transform_libs/utils.py
index adae5fb7c61..acbf3827c9b 100644
--- a/lib/src/otx/data/transform_libs/utils.py
+++ b/lib/src/otx/data/transform_libs/utils.py
@@ -129,6 +129,7 @@ def to_np_image(img: np.ndarray | Tensor | list) -> np.ndarray | list[np.ndarray]:
         return img
     if isinstance(img, list):
         return [to_np_image(im) for im in img]
+
     return np.ascontiguousarray(img.numpy().transpose(1, 2, 0))
 
 
diff --git a/lib/src/otx/data/utils/__init__.py b/lib/src/otx/data/utils/__init__.py
index bc2ed250b89..31242128b20 100644
--- a/lib/src/otx/data/utils/__init__.py
+++ b/lib/src/otx/data/utils/__init__.py
@@ -7,7 +7,6 @@
     adapt_input_size_to_dataset,
     adapt_tile_config,
     get_adaptive_num_workers,
-    get_idx_list_per_classes,
     import_object_from_module,
     instantiate_sampler,
 )
@@ -17,6 +16,5 @@
"adapt_input_size_to_dataset", "instantiate_sampler", "get_adaptive_num_workers", - "get_idx_list_per_classes", "import_object_from_module", ] diff --git a/lib/src/otx/data/utils/utils.py b/lib/src/otx/data/utils/utils.py index 769fc3ec7f8..1d2c5eeb2f4 100644 --- a/lib/src/otx/data/utils/utils.py +++ b/lib/src/otx/data/utils/utils.py @@ -15,14 +15,13 @@ import cv2 import numpy as np import torch -from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon +from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, Polygon from datumaro.components.annotation import Shape as _Shape from otx.types import OTXTaskType from otx.utils.device import is_xpu_available if TYPE_CHECKING: - from datumaro import Dataset as DmDataset from datumaro import DatasetSubset from torch.utils.data import Dataset, Sampler @@ -322,22 +321,6 @@ def get_adaptive_num_workers(num_dataloader: int = 1) -> int | None: return min(cpu_count() // (num_dataloader * num_devices), 8) # max available num_workers is 8 -def get_idx_list_per_classes(dm_dataset: DmDataset, use_string_label: bool = False) -> dict[int | str, list[int]]: - """Compute class statistics.""" - stats: dict[int | str, list[int]] = defaultdict(list) - labels = dm_dataset.categories().get(AnnotationType.label, LabelCategories()) - for item_idx, item in enumerate(dm_dataset): - for ann in item.annotations: - if use_string_label: - stats[labels.items[ann.label].name].append(item_idx) - else: - stats[ann.label].append(item_idx) - # Remove duplicates in label stats idx: O(n) - for k in stats: - stats[k] = list(dict.fromkeys(stats[k])) - return stats - - def import_object_from_module(obj_path: str) -> Any: # noqa: ANN401 """Get object from import format string.""" module_name, obj_name = obj_path.rsplit(".", 1) diff --git a/lib/tests/test_helpers.py b/lib/tests/test_helpers.py index 313b6f06665..faed389f873 100644 --- a/lib/tests/test_helpers.py +++ b/lib/tests/test_helpers.py @@ -17,9 +17,6 @@ from datumaro.components.errors import MediaTypeError from datumaro.components.exporter import Exporter from datumaro.components.media import Image -from datumaro.plugins.data_formats.common_semantic_segmentation import ( - CommonSemanticSegmentationPath, -) from datumaro.util.definitions import DEFAULT_SUBSET_NAME from datumaro.util.image import save_image from datumaro.util.meta_file_util import save_meta_file @@ -122,8 +119,8 @@ def _apply_impl(self) -> None: subset_dir = Path(save_dir, _subset_name) subset_dir.mkdir(parents=True, exist_ok=True) - mask_dir = subset_dir / CommonSemanticSegmentationPath.MASKS_DIR - img_dir = subset_dir / CommonSemanticSegmentationPath.IMAGES_DIR + mask_dir = subset_dir / "masks" + img_dir = subset_dir / "images" for item in subset: self._export_item_annotation(item, mask_dir) if self._save_media: diff --git a/lib/tests/unit/data/dataset/test_classification.py b/lib/tests/unit/data/dataset/test_classification.py index c6a62ecea9f..bb564d06cc5 100644 --- a/lib/tests/unit/data/dataset/test_classification.py +++ b/lib/tests/unit/data/dataset/test_classification.py @@ -8,9 +8,10 @@ from otx.data.dataset.classification import ( HLabelInfo, OTXHlabelClsDataset, - OTXMulticlassClsDataset, OTXMultilabelClsDataset, ) +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.entity.sample import ClassificationSample from otx.data.entity.torch import OTXDataItem @@ -24,7 +25,7 @@ def test_get_item( transforms=[lambda x: x], max_refetch=3, ) - assert 
-        assert isinstance(dataset[0], OTXDataItem)
+        assert isinstance(dataset[0], ClassificationSample)
 
     def test_get_item_from_bbox_dataset(
         self,
@@ -35,7 +36,7 @@ def test_get_item_from_bbox_dataset(
             transforms=[lambda x: x],
             max_refetch=3,
         )
-        assert isinstance(dataset[0], OTXDataItem)
+        assert isinstance(dataset[0], ClassificationSample)
 
 
 class TestOTXMultilabelClsDataset:
diff --git a/lib/tests/unit/data/samplers/test_balanced_sampler.py b/lib/tests/unit/data/samplers/test_balanced_sampler.py
index 43b8810c3bf..768fad8ef48 100644
--- a/lib/tests/unit/data/samplers/test_balanced_sampler.py
+++ b/lib/tests/unit/data/samplers/test_balanced_sampler.py
@@ -12,7 +12,6 @@
 
 from otx.data.dataset.base import OTXDataset
 from otx.data.samplers.balanced_sampler import BalancedSampler
-from otx.data.utils import get_idx_list_per_classes
 
 
 @pytest.fixture()
@@ -81,7 +80,7 @@ def test_sampler_iter_with_multiple_replicas(self, fxt_imbalanced_dataset):
 
     def test_compute_class_statistics(self, fxt_imbalanced_dataset):
         # Compute class statistics
-        stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset)
+        stats = fxt_imbalanced_dataset.get_idx_list_per_classes()
 
         # Check the expected results
         assert stats == {0: list(range(100)), 1: list(range(100, 108))}
@@ -90,7 +89,7 @@ def test_sampler_iter_per_class(self, fxt_imbalanced_dataset):
         batch_size = 4
         sampler = BalancedSampler(fxt_imbalanced_dataset)
 
-        stats = get_idx_list_per_classes(fxt_imbalanced_dataset.dm_subset)
+        stats = fxt_imbalanced_dataset.get_idx_list_per_classes()
         class_0_idx = stats[0]
         class_1_idx = stats[1]
         list_iter = list(iter(sampler))
diff --git a/lib/tests/unit/data/samplers/test_class_incremental_sampler.py b/lib/tests/unit/data/samplers/test_class_incremental_sampler.py
index cd2f34b8e53..f031f58265b 100644
--- a/lib/tests/unit/data/samplers/test_class_incremental_sampler.py
+++ b/lib/tests/unit/data/samplers/test_class_incremental_sampler.py
@@ -10,7 +10,6 @@
 
 from otx.data.dataset.base import OTXDataset
 from otx.data.samplers.class_incremental_sampler import ClassIncrementalSampler
-from otx.data.utils import get_idx_list_per_classes
 
 
 @pytest.fixture()
@@ -107,7 +106,7 @@ def test_sampler_iter_per_class(self, fxt_old_new_dataset):
             new_classes=["2"],
         )
 
-        stats = get_idx_list_per_classes(fxt_old_new_dataset.dm_subset, True)
+        stats = fxt_old_new_dataset.get_idx_list_per_classes(True)
         old_idx = stats["0"] + stats["1"]
         new_idx = stats["2"]
         list_iter = list(iter(sampler))
diff --git a/lib/tests/unit/data/test_factory.py b/lib/tests/unit/data/test_factory.py
index 3c24b1c774b..f28c97d52c7 100644
--- a/lib/tests/unit/data/test_factory.py
+++ b/lib/tests/unit/data/test_factory.py
@@ -10,9 +10,9 @@
 from otx.data.dataset.classification import (
     HLabelInfo,
     OTXHlabelClsDataset,
-    OTXMulticlassClsDataset,
     OTXMultilabelClsDataset,
 )
+from otx.data.dataset.classification_new import OTXMulticlassClsDataset
 from otx.data.dataset.detection import OTXDetectionDataset
 from otx.data.dataset.instance_segmentation import OTXInstanceSegDataset
 from otx.data.dataset.segmentation import OTXSegmentationDataset
diff --git a/lib/tests/unit/data/utils/test_utils.py b/lib/tests/unit/data/utils/test_utils.py
index 69d2b837f37..79cfaefe199 100644
--- a/lib/tests/unit/data/utils/test_utils.py
+++ b/lib/tests/unit/data/utils/test_utils.py
@@ -5,7 +5,6 @@
 
 from __future__ import annotations
 
-from collections import defaultdict
 from unittest.mock import MagicMock
 
 import cv2
@@ -23,8 +22,6 @@
     compute_robust_scale_statistics,
     compute_robust_statistics,
     get_adaptive_num_workers,
-    get_idx_list_per_classes,
-    import_object_from_module,
 )
 
 
@@ -239,29 +236,3 @@ def fxt_dm_dataset() -> DmDataset:
     ]
 
     return DmDataset.from_iterable(dataset_items, categories=["0", "1"])
-
-
-def test_get_idx_list_per_classes(fxt_dm_dataset):
-    # Call the function under test
-    result = get_idx_list_per_classes(fxt_dm_dataset)
-
-    # Assert the expected output
-    expected_result = defaultdict(list)
-    expected_result[0] = list(range(100))
-    expected_result[1] = list(range(100, 108))
-    assert result == expected_result
-
-    # Call the function under test with use_string_label
-    result = get_idx_list_per_classes(fxt_dm_dataset, use_string_label=True)
-
-    # Assert the expected output
-    expected_result = defaultdict(list)
-    expected_result["0"] = list(range(100))
-    expected_result["1"] = list(range(100, 108))
-    assert result == expected_result
-
-
-def test_import_object_from_module():
-    obj_path = "otx.data.utils.get_idx_list_per_classes"
-    obj = import_object_from_module(obj_path)
-    assert obj == get_idx_list_per_classes
diff --git a/tests/unit/data/dataset/test_base_new.py b/tests/unit/data/dataset/test_base_new.py
new file mode 100644
index 00000000000..e3e18a9e846
--- /dev/null
+++ b/tests/unit/data/dataset/test_base_new.py
@@ -0,0 +1,263 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for base_new OTXDataset."""
+
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+from datumaro.experimental import Dataset
+
+from otx.data.dataset.base_new import OTXDataset, _default_collate_fn
+from otx.data.entity.sample import OTXSample
+from otx.data.entity.torch.torch import OTXDataBatch
+
+
+class TestDefaultCollateFn:
+    """Test _default_collate_fn function."""
+
+    def test_collate_with_torch_tensors(self):
+        """Test collating items with torch tensor images."""
+        # Create mock samples with torch tensor images
+        sample1 = Mock(spec=OTXSample)
+        sample1.image = torch.randn(3, 224, 224)
+        sample1.label = torch.tensor(0)
+        sample1.masks = None
+        sample1.bboxes = None
+        sample1.keypoints = None
+        sample1.polygons = None
+        sample1.img_info = None
+
+        sample2 = Mock(spec=OTXSample)
+        sample2.image = torch.randn(3, 224, 224)
+        sample2.label = torch.tensor(1)
+        sample2.masks = None
+        sample2.bboxes = None
+        sample2.keypoints = None
+        sample2.polygons = None
+        sample2.img_info = None
+
+        items = [sample1, sample2]
+        result = _default_collate_fn(items)
+
+        assert isinstance(result, OTXDataBatch)
+        assert result.batch_size == 2
+        assert isinstance(result.images, torch.Tensor)
+        assert result.images.shape == (2, 3, 224, 224)
+        assert result.images.dtype == torch.float32
+        assert result.labels == [torch.tensor(0), torch.tensor(1)]
+
+    def test_collate_with_different_image_shapes(self):
+        """Test collating items with different image shapes."""
+        sample1 = Mock(spec=OTXSample)
+        sample1.image = torch.randn(3, 224, 224)
+        sample1.label = None
+        sample1.masks = None
+        sample1.bboxes = None
+        sample1.keypoints = None
+        sample1.polygons = None
+        sample1.img_info = None
+
+        sample2 = Mock(spec=OTXSample)
+        sample2.image = torch.randn(3, 256, 256)
+        sample2.label = None
+        sample2.masks = None
+        sample2.bboxes = None
+        sample2.keypoints = None
+        sample2.polygons = None
+        sample2.img_info = None
+
+        items = [sample1, sample2]
+        result = _default_collate_fn(items)
+
+        # When shapes are different, should return list instead of stacked tensor
+        assert isinstance(result.images, list)
+        assert len(result.images) == 2
+        assert result.labels is None
+
+
+class TestOTXDataset:
+    """Test OTXDataset class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        self.mock_dm_subset = Mock(spec=Dataset)
+        self.mock_dm_subset.__len__ = Mock(return_value=100)
+
+        # Mock schema attributes for label_info
+        mock_schema = Mock()
+        mock_attributes = {"label": Mock()}
+        mock_attributes["label"].categories = Mock()
+        # Configure labels to be a list with proper length support
+        mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"]
+        mock_schema.attributes = mock_attributes
+        self.mock_dm_subset.schema = mock_schema
+
+        self.mock_transforms = Mock()
+
+    def test_sample_another_idx(self):
+        """Test _sample_another_idx method."""
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=self.mock_transforms,
+            data_format="arrow",
+        )
+
+        with patch("numpy.random.randint", return_value=42):
+            idx = dataset._sample_another_idx()
+            assert idx == 42
+
+    def test_apply_transforms_with_compose(self):
+        """Test _apply_transforms with Compose transforms."""
+        from otx.data.transform_libs.torchvision import Compose
+
+        mock_compose = Mock(spec=Compose)
+        mock_entity = Mock(spec=OTXSample)
+        mock_result = Mock()
+        mock_compose.return_value = mock_result
+
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=mock_compose,
+            data_format="arrow",
+            to_tv_image=True,
+        )
+
+        result = dataset._apply_transforms(mock_entity)
+
+        mock_entity.as_tv_image.assert_called_once()
+        mock_compose.assert_called_once_with(mock_entity)
+        assert result == mock_result
+
+    def test_apply_transforms_with_callable(self):
+        """Test _apply_transforms with callable transform."""
+        mock_transform = Mock()
+        mock_entity = Mock(spec=OTXSample)
+        mock_result = Mock()
+        mock_transform.return_value = mock_result
+
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=mock_transform,
+            data_format="arrow",
+        )
+
+        result = dataset._apply_transforms(mock_entity)
+
+        mock_transform.assert_called_once_with(mock_entity)
+        assert result == mock_result
+
+    def test_apply_transforms_with_list(self):
+        """Test _apply_transforms with list of transforms."""
+        transform1 = Mock()
+        transform2 = Mock()
+
+        mock_entity = Mock(spec=OTXSample)
+        intermediate_result = Mock()
+        final_result = Mock()
+
+        transform1.return_value = intermediate_result
+        transform2.return_value = final_result
+
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=[transform1, transform2],
+            data_format="arrow",
+        )
+
+        result = dataset._apply_transforms(mock_entity)
+
+        transform1.assert_called_once_with(mock_entity)
+        transform2.assert_called_once_with(intermediate_result)
+        assert result == final_result
+
+    def test_apply_transforms_with_list_returns_none(self):
+        """Test _apply_transforms with list that returns None."""
+        transform1 = Mock()
+        transform2 = Mock()
+
+        mock_entity = Mock(spec=OTXSample)
+        transform1.return_value = None  # First transform returns None
+
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=[transform1, transform2],
+            data_format="arrow",
+        )
+
+        result = dataset._apply_transforms(mock_entity)
+
+        transform1.assert_called_once_with(mock_entity)
+        transform2.assert_not_called()  # Should not be called since first returned None
+        assert result is None
+
+    def test_iterable_transforms_with_non_list(self):
+        """Test _iterable_transforms with non-list iterable raises TypeError."""
+        dataset = OTXDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=self.mock_transforms,
data_format="arrow", + ) + + mock_entity = Mock(spec=OTXSample) + dataset.transforms = "not_a_list" # String is iterable but not a list + + with pytest.raises(TypeError): + dataset._iterable_transforms(mock_entity) + + def test_getitem_success(self): + """Test __getitem__ with successful retrieval.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + mock_transformed_item = Mock(spec=OTXSample) + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + with patch.object( + dataset, "_apply_transforms", return_value=mock_transformed_item + ): + result = dataset[5] + + self.mock_dm_subset.__getitem__.assert_called_once_with(5) + assert result == mock_transformed_item + + def test_getitem_with_refetch(self): + """Test __getitem__ with failed first attempt requiring refetch.""" + mock_item = Mock() + self.mock_dm_subset.__getitem__ = Mock(return_value=mock_item) + + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + max_refetch=2, + ) + + mock_transformed_item = Mock(spec=OTXSample) + + # First call returns None, second returns valid item + with patch.object( + dataset, "_apply_transforms", side_effect=[None, mock_transformed_item] + ), patch.object(dataset, "_sample_another_idx", return_value=10): + result = dataset[5] + + assert result == mock_transformed_item + assert dataset._apply_transforms.call_count == 2 + + def test_collate_fn_property(self): + """Test collate_fn property returns _default_collate_fn.""" + dataset = OTXDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.collate_fn == _default_collate_fn diff --git a/tests/unit/data/dataset/test_classification_new.py b/tests/unit/data/dataset/test_classification_new.py new file mode 100644 index 00000000000..eaf7b30373d --- /dev/null +++ b/tests/unit/data/dataset/test_classification_new.py @@ -0,0 +1,68 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for classification_new dataset.""" + +from __future__ import annotations + +from unittest.mock import Mock + +from datumaro.experimental import Dataset + +from otx.data.dataset.classification_new import OTXMulticlassClsDataset +from otx.data.entity.sample import ClassificationSample + + +class TestOTXMulticlassClsDataset: + """Test OTXMulticlassClsDataset class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_dm_subset = Mock(spec=Dataset) + self.mock_dm_subset.__len__ = Mock(return_value=10) + + # Mock schema attributes for label_info + mock_schema = Mock() + mock_attributes = {"label": Mock()} + mock_attributes["label"].categories = Mock() + # Configure labels to be a list with proper length support + mock_attributes["label"].categories.labels = ["class_0", "class_1", "class_2"] + mock_schema.attributes = mock_attributes + self.mock_dm_subset.schema = mock_schema + + self.mock_transforms = Mock() + + def test_init_sets_sample_type(self): + """Test that initialization sets sample_type to ClassificationSample.""" + dataset = OTXMulticlassClsDataset( + dm_subset=self.mock_dm_subset, + transforms=self.mock_transforms, + data_format="arrow", + ) + + assert dataset.sample_type == ClassificationSample + + def test_get_idx_list_per_classes_single_class(self): + """Test get_idx_list_per_classes with single class.""" + # Mock dataset items with labels + mock_items = [] + for i in 
+            mock_item = Mock()
+            mock_item.label.item.return_value = 0  # All items have label 0
+            mock_items.append(mock_item)
+
+        self.mock_dm_subset.__getitem__ = Mock(side_effect=mock_items)
+
+        dataset = OTXMulticlassClsDataset(
+            dm_subset=self.mock_dm_subset,
+            transforms=self.mock_transforms,
+            data_format="arrow",
+        )
+
+        # Override length for this test
+        dataset.dm_subset.__len__ = Mock(return_value=5)
+
+        result = dataset.get_idx_list_per_classes()
+
+        expected = {0: [0, 1, 2, 3, 4]}
+        assert result == expected
diff --git a/tests/unit/data/entity/test_sample.py b/tests/unit/data/entity/test_sample.py
new file mode 100644
index 00000000000..51b075e9a5e
--- /dev/null
+++ b/tests/unit/data/entity/test_sample.py
@@ -0,0 +1,173 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for sample entity classes."""
+
+from __future__ import annotations
+
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+import torch
+from datumaro import DatasetItem
+from datumaro.components.annotation import Label
+from datumaro.components.media import Image
+from torchvision import tv_tensors
+
+from otx.data.entity.base import ImageInfo
+from otx.data.entity.sample import ClassificationSample, OTXSample
+
+
+class TestOTXSample:
+    """Test OTXSample base class."""
+
+    def setup_method(self):
+        """Set up test fixtures."""
+        # Create a mock sample for testing
+        self.sample = OTXSample()
+
+    def test_as_tv_image_with_tv_image(self):
+        """Test as_tv_image when image is already tv_tensors.Image."""
+        tv_image = tv_tensors.Image(torch.randn(3, 224, 224))
+        self.sample.image = tv_image
+
+        # Should not change anything
+        self.sample.as_tv_image()
+        assert isinstance(self.sample.image, tv_tensors.Image)
+        assert torch.equal(self.sample.image, tv_image)
+
+    def test_as_tv_image_with_numpy_array(self):
+        """Test as_tv_image with numpy array."""
+        np_image = np.random.rand(3, 224, 224).astype(np.float32)
+        self.sample.image = np_image
+
+        self.sample.as_tv_image()
+
+        assert isinstance(self.sample.image, tv_tensors.Image)
+        assert torch.allclose(self.sample.image, torch.from_numpy(np_image))
+
+    def test_as_tv_image_with_torch_tensor(self):
+        """Test as_tv_image with torch.Tensor."""
+        tensor_image = torch.randn(3, 224, 224)
+        self.sample.image = tensor_image
+
+        self.sample.as_tv_image()
+
+        assert isinstance(self.sample.image, tv_tensors.Image)
+        assert torch.equal(self.sample.image, tensor_image)
+
+    def test_as_tv_image_with_invalid_type(self):
+        """Test as_tv_image with invalid image type raises ValueError."""
+        self.sample.image = "invalid_image"
+
+        with pytest.raises(ValueError, match="OTXSample must have an image"):
+            self.sample.as_tv_image()
+
+    def test_img_info_property_with_image(self):
+        """Test img_info property creates ImageInfo from image."""
+        self.sample.image = torch.randn(3, 224, 224)
+
+        img_info = self.sample.img_info
+
+        assert isinstance(img_info, ImageInfo)
+        assert img_info.img_idx == 0
+        assert img_info.img_shape == (3, 224)  # First two dimensions
+        assert img_info.ori_shape == (3, 224)
+
+    def test_img_info_setter(self):
+        """Test setting img_info manually."""
+        custom_info = ImageInfo(img_idx=5, img_shape=(100, 200), ori_shape=(100, 200))
+
+        self.sample.img_info = custom_info
+
+        assert self.sample.img_info is custom_info
+        assert self.sample.img_info.img_idx == 5
+
+
+class TestClassificationSample:
+    """Test ClassificationSample class."""
+
+    def test_inheritance(self):
+        """Test that ClassificationSample inherits from OTXSample."""
OTXSample.""" + sample = ClassificationSample( + image=np.random.rand(3, 224, 224).astype(np.uint8), label=torch.tensor(1) + ) + + assert isinstance(sample, OTXSample) + + def test_init_with_numpy_image_and_tensor_label(self): + """Test initialization with numpy image and tensor label.""" + image = np.random.rand(3, 224, 224).astype(np.uint8) + label = torch.tensor(1) + + sample = ClassificationSample(image=image, label=label) + + assert np.array_equal(sample.image, image) + assert torch.equal(sample.label, label) + + def test_init_with_tv_image(self): + """Test initialization with tv_tensors.Image.""" + image = tv_tensors.Image(torch.randn(3, 224, 224)) + label = torch.tensor(0) + + sample = ClassificationSample(image=image, label=label) + + assert torch.equal(sample.image, image) + assert torch.equal(sample.label, label) + + def test_from_dm_item_with_image_and_annotation(self): + """Test from_dm_item with image and annotation.""" + # Mock DatasetItem + mock_item = Mock(spec=DatasetItem) + + # Mock image + mock_media = Mock(spec=Image) + mock_media.data = np.random.rand(224, 224, 3).astype(np.uint8) + mock_item.media_as.return_value = mock_media + + # Mock annotation + mock_annotation = Mock(spec=Label) + mock_annotation.label = 2 + mock_item.annotations = [mock_annotation] + + sample = ClassificationSample.from_dm_item(mock_item) + + assert isinstance(sample, ClassificationSample) + assert np.array_equal(sample.image, mock_media.data) + assert torch.equal(sample.label, torch.tensor(2, dtype=torch.long)) + + # Check img_info + assert isinstance(sample._img_info, ImageInfo) + assert sample._img_info.img_idx == 0 + assert sample._img_info.img_shape == (224, 224) + assert sample._img_info.ori_shape == (224, 224) + + def test_from_dm_item_without_annotation(self): + """Test from_dm_item without annotations.""" + # Mock DatasetItem without annotations + mock_item = Mock(spec=DatasetItem) + + # Mock image + mock_media = Mock(spec=Image) + mock_media.data = np.random.rand(100, 100, 3).astype(np.uint8) + mock_item.media_as.return_value = mock_media + + # No annotations + mock_item.annotations = [] + + sample = ClassificationSample.from_dm_item(mock_item) + + assert isinstance(sample, ClassificationSample) + assert np.array_equal(sample.image, mock_media.data) + # When no annotation, from_dm_item should return tensor(-1) as default + assert torch.equal(sample.label, torch.tensor(-1, dtype=torch.long)) + + def test_label_property_override(self): + """Test that ClassificationSample has actual label property (not None).""" + sample = ClassificationSample( + image=np.random.rand(3, 224, 224).astype(np.uint8), label=torch.tensor(42) + ) + + assert sample.label is not None + assert torch.equal(sample.label, torch.tensor(42))