-
Notifications
You must be signed in to change notification settings - Fork 462
Classification Dataset refactor #4606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3bc16d3
5ea1908
87c5a31
3df9e59
1d8f332
966ae56
53afc10
e9338d2
3ae6745
742f399
856d54e
785f5e6
df3c776
121bba7
6fb07d8
069474f
aaddb60
9ab6b7a
543963e
86e4acd
d089db9
b22fcc3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| # Copyright (C) 2023-2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Base class for OTXDataset using new Datumaro experimental Dataset.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import abc | ||
| from typing import TYPE_CHECKING, Callable, Iterable, List, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from torch.utils.data import Dataset as TorchDataset | ||
|
|
||
| from otx import LabelInfo, NullLabelInfo | ||
|
|
||
| if TYPE_CHECKING: | ||
| from datumaro.experimental import Dataset | ||
|
|
||
| from otx.data.entity.sample import OTXSample | ||
| from otx.data.entity.torch.torch import OTXDataBatch | ||
| from otx.data.transform_libs.torchvision import Compose | ||
| from otx.types.image import ImageColorChannel | ||
|
|
||
| Transforms = Union[Compose, Callable, List[Callable], dict[str, Compose | Callable | List[Callable]]] | ||
|
|
||
|
|
||
def _default_collate_fn(items: list[OTXSample]) -> OTXDataBatch:
    """Collate OTXSample items into an OTXDataBatch.

    Every image is converted to a float32 tensor. The images are stacked into a
    single tensor when they all share the same shape; otherwise they are passed
    on as a list of tensors. Optional fields (masks, bboxes, keypoints, polygons,
    img_info) are gathered into lists when at least one item provides them and
    are set to ``None`` otherwise.

    Args:
        items: List of OTXSample items to batch. An empty list yields a batch
            with ``batch_size=0`` and no optional fields.

    Returns:
        Batched OTXSample items with stacked tensors
    """

    def _gather(field: str) -> list | None:
        # One entry per item (entries may be None) if any item has the field.
        values = [getattr(item, field) for item in items]
        return values if any(value is not None for value in values) else None

    # Convert images to float32 tensors before stacking
    image_tensors = []
    for item in items:
        img = item.image
        if not isinstance(img, torch.Tensor):
            # NumPy array -> tensor
            img = torch.from_numpy(img)
        if img.dtype != torch.float32:
            img = img.float()
        image_tensors.append(img)

    # Stack only when there is at least one image and all shapes agree
    stackable = bool(image_tensors) and all(t.shape == image_tensors[0].shape for t in image_tensors)
    images = torch.stack(image_tensors) if stackable else image_tensors

    # Label presence is decided by the first item only; guard against an empty
    # batch so that ``items[0]`` cannot raise IndexError.
    has_labels = bool(items) and items[0].label is not None

    # Unlike the other fields, polygons drop items that have none, so the
    # resulting list can be shorter than the batch.
    polygons = [item.polygons for item in items if item.polygons is not None]

    return OTXDataBatch(
        batch_size=len(items),
        images=images,
        labels=[item.label for item in items] if has_labels else None,
        masks=_gather("masks"),
        bboxes=_gather("bboxes"),
        keypoints=_gather("keypoints"),
        polygons=polygons if polygons else None,
        imgs_info=_gather("img_info"),
    )
|
|
||
|
|
||
class OTXDataset(TorchDataset):
    """Base OTXDataset using new Datumaro experimental Dataset.

    Defines basic logic for OTX datasets.

    Args:
        dm_subset: Datumaro experimental dataset that provides the samples. If its
            schema exposes a ``"label"`` attribute, ``label_info`` is built from
            that attribute's category labels; otherwise a ``NullLabelInfo`` is used.
        transforms: Transforms to apply on images
        max_refetch: Maximum number of fetch attempts ``__getitem__`` makes before
            raising ``RuntimeError``.
        image_color_channel: Color channel of images
        stack_images: Whether or not to stack images in collate function in OTXBatchData entity.
        to_tv_image: Whether to call ``as_tv_image()`` on the sample before a
            ``Compose`` transform is applied.
        data_format: Stored on the instance; not otherwise used by this class.
        sample_type: Type of sample to use for this dataset
    """

    def __init__(
        self,
        dm_subset: Dataset,
        transforms: Transforms,
        max_refetch: int = 1000,
        image_color_channel: ImageColorChannel = ImageColorChannel.RGB,
        stack_images: bool = True,
        to_tv_image: bool = True,
        data_format: str = "",
        sample_type: type[OTXSample] = OTXSample,
    ) -> None:
        self.transforms = transforms
        self.image_color_channel = image_color_channel
        self.stack_images = stack_images
        self.to_tv_image = to_tv_image
        self.sample_type = sample_type
        self.max_refetch = max_refetch
        self.data_format = data_format
        if (
            hasattr(dm_subset, "schema")
            and hasattr(dm_subset.schema, "attributes")
            and "label" in dm_subset.schema.attributes
        ):
            # All labels are placed in a single group and get string ids "0", "1", ...
            labels = dm_subset.schema.attributes["label"].categories.labels
            self.label_info = LabelInfo(
                label_names=labels,
                label_groups=[labels],
                label_ids=[str(i) for i in range(len(labels))],
            )
        else:
            self.label_info = NullLabelInfo()
        self.dm_subset = dm_subset

    def __len__(self) -> int:
        """Return the number of samples in the underlying subset."""
        return len(self.dm_subset)

    def _sample_another_idx(self) -> int:
        """Return a uniformly random valid index as a plain Python ``int``."""
        # np.random.randint returns a NumPy integer; cast to honor the annotation.
        return int(np.random.randint(0, len(self)))

    def _apply_transforms(self, entity: OTXSample) -> OTXSample | None:
        """Apply the configured transforms to ``entity``.

        Returns:
            The transformed sample, or ``None`` if a transform produced ``None``
            or ``self.transforms`` is of an unsupported kind.
        """
        if isinstance(self.transforms, Compose):
            if self.to_tv_image:
                entity.as_tv_image()
            return self.transforms(entity)
        if isinstance(self.transforms, Iterable):
            return self._iterable_transforms(entity)
        if callable(self.transforms):
            return self.transforms(entity)
        return None

    def _iterable_transforms(self, item: OTXSample) -> OTXSample | None:
        """Apply a list of transforms in order, stopping early on ``None``.

        Raises:
            TypeError: If ``self.transforms`` is not a list.
        """
        if not isinstance(self.transforms, list):
            msg = f"Expected transforms to be a list of callables, got {type(self.transforms).__name__}"
            raise TypeError(msg)

        results = item
        for transform in self.transforms:
            results = transform(results)
            # MMCV transform can produce None. Please see
            # https://github.com/open-mmlab/mmengine/blob/26f22ed283ae4ac3a24b756809e5961efe6f9da8/mmengine/dataset/base_dataset.py#L59-L66
            if results is None:
                return None

        return results

    def __getitem__(self, index: int) -> OTXSample:
        """Fetch the sample at ``index``, retrying with random indices on failure.

        Raises:
            RuntimeError: If no sample could be produced within ``max_refetch`` attempts.
        """
        for _ in range(self.max_refetch):
            results = self._get_item_impl(index)

            if results is not None:
                return results

            # The transforms rejected this sample; try a random one instead.
            index = self._sample_another_idx()

        msg = f"Reach the maximum refetch number ({self.max_refetch})"
        raise RuntimeError(msg)

    def _get_item_impl(self, index: int) -> OTXSample | None:
        """Read the sample at ``index`` from the subset and transform it."""
        dm_item = self.dm_subset[index]
        return self._apply_transforms(dm_item)

    @property
    def collate_fn(self) -> Callable:
        """Collection function to collect samples into a batch in data loader."""
        return _default_collate_fn

    # NOTE(review): abc.abstractmethod only prevents instantiation when the class
    # uses an ABCMeta-based metaclass - confirm TorchDataset provides one,
    # otherwise this decorator is advisory only.
    @abc.abstractmethod
    def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
        """Get a dictionary with class labels as keys and lists of corresponding sample indices as values."""
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Module for OTXClassificationDatasets using new Datumaro experimental Dataset.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from otx.data.dataset.base_new import OTXDataset | ||
| from otx.data.entity.sample import ClassificationSample | ||
|
|
||
|
|
||
class OTXMulticlassClsDataset(OTXDataset):
    """Multi-class classification dataset built on the new Datumaro experimental Dataset."""

    def __init__(self, **kwargs) -> None:
        """Create the dataset, always using ``ClassificationSample`` as the sample type.

        Args:
            **kwargs: Keyword arguments to pass to OTXDataset
        """
        # Any caller-supplied sample_type is overridden.
        super().__init__(**{**kwargs, "sample_type": ClassificationSample})

    # NOTE(review): per the PR discussion, this lives on the subclass because the
    # implementation is expected to depend on the sample type; the new OTXDataset
    # base does not implement it. It can move to the base class if it turns out
    # to be uniform across sample types.
    def get_idx_list_per_classes(self, use_string_label: bool = False) -> dict[int, list[int]]:
        """Group sample indices by class label.

        Args:
            use_string_label (bool): If True, use string class labels as keys.
                If False, use integer indices as keys.

        Returns:
            Dictionary mapping each class key to the list of sample indices that
            carry that label, in ascending index order.
        """
        grouped: dict[int, list[int]] = {}
        for sample_idx in range(len(self)):
            class_key = self.dm_subset[sample_idx].label.item()
            if use_string_label:
                class_key = self.label_info.label_names[class_key]
            grouped.setdefault(class_key, []).append(sample_idx)
        return grouped
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| # Copyright (C) 2025 Intel Corporation | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Sample classes for OTX data entities.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| import numpy as np | ||
| import polars as pl | ||
| import torch | ||
| from datumaro import Mask | ||
| from datumaro.components.media import Image | ||
| from datumaro.experimental.dataset import Sample | ||
| from datumaro.experimental.fields import image_field, label_field | ||
| from torchvision import tv_tensors | ||
|
|
||
| from otx.data.entity.base import ImageInfo | ||
|
|
||
| if TYPE_CHECKING: | ||
| from datumaro import DatasetItem, Polygon | ||
| from torchvision.tv_tensors import BoundingBoxes, Mask | ||
|
|
||
|
|
||
class OTXSample(Sample):
    """Common base for all OTX data samples."""

    image: np.ndarray | torch.Tensor | tv_tensors.Image | Any

    def as_tv_image(self) -> None:
        """Wrap ``self.image`` in a torchvision ``tv_tensors.Image`` in place.

        Raises:
            ValueError: If the image is neither a NumPy array nor a torch tensor.
        """
        current = self.image
        if isinstance(current, tv_tensors.Image):
            # Already in the target format.
            return
        if not isinstance(current, (np.ndarray, torch.Tensor)):
            msg = "OTXSample must have an image"
            raise ValueError(msg)
        self.image = tv_tensors.Image(current)

    # NOTE(review): per the PR discussion, the None-returning properties below
    # mimic the OTXDataItem interface so that existing code (e.g. the torchvision
    # transforms) keeps working until OTXDataItem is replaced by OTXSample.

    @property
    def masks(self) -> Mask | None:
        """Masks of the sample; the base class has none."""
        return None

    @property
    def bboxes(self) -> BoundingBoxes | None:
        """Bounding boxes of the sample; the base class has none."""
        return None

    @property
    def keypoints(self) -> torch.Tensor | None:
        """Keypoints of the sample; the base class has none."""
        return None

    @property
    def polygons(self) -> list[Polygon] | None:
        """Polygons of the sample; the base class has none."""
        return None

    @property
    def label(self) -> torch.Tensor | None:
        """Label of the sample; the base class has none."""
        return None

    @property
    def img_info(self) -> ImageInfo | None:
        """Image information for the sample, created lazily on first access.

        Returns:
            The stored ``ImageInfo``; if none is stored, one derived from the
            image shape; or ``None`` when the sample has no image with a
            three-dimensional ``shape``.
        """
        cached = getattr(self, "_img_info", None)
        if cached is not None:
            return cached

        image = getattr(self, "image", None)
        # NOTE(review): per the PR discussion, returning None (instead of raising)
        # keeps the existing OTXDataItem.img_info behavior for now.
        if image is None or not hasattr(image, "shape") or len(image.shape) != 3:
            return None

        # NOTE(review): takes the first two dims as the image shape - presumably
        # assumes a height-width-channel layout; confirm for tensor images.
        leading_dims = image.shape[:2]
        self._img_info = ImageInfo(
            img_idx=0,
            img_shape=leading_dims,
            ori_shape=leading_dims,
        )
        return self._img_info

    @img_info.setter
    def img_info(self, value: ImageInfo) -> None:
        self._img_info = value
|
|
||
|
|
||
class ClassificationSample(OTXSample):
    """Sample for classification tasks: an image together with a label tensor."""

    # NOTE(review): per the PR discussion, these field types follow existing
    # OTXDataItem logic, and the OTX codebase already handles the different types.
    image: np.ndarray | tv_tensors.Image = image_field(dtype=pl.UInt8)
    label: torch.Tensor = label_field(pl.Int32())

    @classmethod
    def from_dm_item(cls, item: DatasetItem) -> ClassificationSample:
        """Build a ClassificationSample from a Datumaro DatasetItem.

        The label is taken from the item's first annotation. When the item has no
        annotations, or that annotation has no label, the label tensor is ``-1``.

        Args:
            item: Datumaro DatasetItem containing image and label

        Returns:
            ClassificationSample: Instance with image, label and img_info set
        """
        pixels = item.media_as(Image).data
        raw_label = item.annotations[0].label if item.annotations else None

        # Image info is derived from the first two dimensions of the pixel array.
        leading_dims = pixels.shape[:2]
        info = ImageInfo(
            img_idx=0,
            img_shape=leading_dims,
            ori_shape=leading_dims,
        )

        if raw_label is not None:
            label_tensor = torch.as_tensor(raw_label, dtype=torch.long)
        else:
            # -1 marks a missing label.
            label_tensor = torch.tensor(-1, dtype=torch.long)

        new_sample = cls(image=pixels, label=label_tensor)
        new_sample.img_info = info
        return new_sample
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: before the feature branch (feature/datumaro) can be merged to develop, this needs to be updated to point to an official datumaro release.