-
Notifications
You must be signed in to change notification settings - Fork 68
PTDT-3807: Add temporal audio annotation support #2013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
e4fd630
dbcc7bf
dbb592f
ff298d4
16896fd
7a666cc
ac58ad0
67dd14a
a1600e5
b4d2f42
fadb14e
1e12596
c2a7b4c
26a35fd
b16f2ea
943cb73
a838513
0ca9cd6
7861537
6c3c50a
68773cf
58b30f7
400d5bb
c761dcf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Optional | ||
from pydantic import Field, AliasChoices | ||
|
||
from labelbox.data.annotation_types.annotation import ( | ||
ClassificationAnnotation, | ||
) | ||
|
||
|
||
class AudioClassificationAnnotation(ClassificationAnnotation):
    """Audio classification for a specific time range.

    Examples:
        - Speaker identification from 2500ms to 4100ms
        - Audio quality assessment for a segment
        - Language detection for audio segments

    Args:
        name (Optional[str]): Name of the classification
        feature_schema_id (Optional[Cuid]): Feature schema identifier
        value (Union[Text, Checklist, Radio]): Classification value
        start_frame (int): Start of the range in milliseconds (e.g., 2500 = 2.5 seconds)
        end_frame (Optional[int]): End of the range in milliseconds (for time ranges)
        segment_index (Optional[int]): Index of the audio segment this annotation belongs to
        extra (Dict[str, Any]): Additional metadata
    """

    # Accepts either "start_frame" or the legacy "frame" key on input.
    # Serializes as snake_case "start_frame" to stay consistent with
    # end_frame below (the previous alias "startframe" dropped the
    # underscore, producing mismatched keys in serialized output).
    start_frame: int = Field(
        validation_alias=AliasChoices("start_frame", "frame"),
        serialization_alias="start_frame",
    )
    # Accepts snake_case or camelCase on input; None means a point-in-time
    # annotation rather than a time range.
    end_frame: Optional[int] = Field(
        default=None,
        validation_alias=AliasChoices("end_frame", "endFrame"),
        serialization_alias="end_frame",
    )
    segment_index: Optional[int] = None
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,22 +60,6 @@ def serialize_model(self, handler): | |
return res | ||
|
||
|
||
class FrameLocation(BaseModel):
    """Inclusive frame range attached to a top-level video classification."""

    # NOTE(review): declared end-before-start in the original; order is
    # preserved because pydantic serializes fields in declaration order.
    end: int
    start: int
|
||
|
||
class VideoSupported(BaseModel):
    """Mixin that adds optional per-frame locations to a classification.

    Frames are only allowed as top-level inferences for video.
    """

    frames: Optional[List[FrameLocation]] = None

    @model_serializer(mode="wrap")
    def serialize_model(self, handler):
        serialized = handler(self)
        if self.frames is None:
            # No video frames were supplied: drop the key entirely rather
            # than emitting "frames": null in the serialized payload.
            serialized.pop("frames")
        return serialized
|
||
|
||
class NDTextSubclass(NDAnswer): | ||
|
@@ -242,13 +226,14 @@ def from_common( | |
name=name, | ||
schema_id=feature_schema_id, | ||
uuid=uuid, | ||
frames=extra.get("frames"), | ||
message_id=message_id, | ||
confidence=text.confidence, | ||
custom_metrics=text.custom_metrics, | ||
) | ||
|
||
|
||
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): | ||
class NDChecklist(NDAnnotation, NDChecklistSubclass): | ||
@model_serializer(mode="wrap") | ||
def serialize_model(self, handler): | ||
res = handler(self) | ||
|
@@ -295,7 +280,7 @@ def from_common( | |
) | ||
|
||
|
||
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): | ||
class NDRadio(NDAnnotation, NDRadioSubclass): | ||
@classmethod | ||
def from_common( | ||
cls, | ||
|
@@ -425,7 +410,8 @@ def to_common( | |
def from_common( | ||
cls, | ||
annotation: Union[ | ||
ClassificationAnnotation, VideoClassificationAnnotation | ||
ClassificationAnnotation, | ||
VideoClassificationAnnotation, | ||
], | ||
data: GenericDataRowData, | ||
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: | ||
|
@@ -448,7 +434,8 @@ def from_common( | |
@staticmethod | ||
def lookup_classification( | ||
annotation: Union[ | ||
ClassificationAnnotation, VideoClassificationAnnotation | ||
ClassificationAnnotation, | ||
VideoClassificationAnnotation, | ||
], | ||
) -> Union[NDText, NDChecklist, NDRadio]: | ||
return {Text: NDText, Checklist: NDChecklist, Radio: NDRadio}.get( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
import copy | ||
from itertools import groupby | ||
from operator import itemgetter | ||
from typing import Generator, List, Tuple, Union | ||
from typing import Any, Dict, Generator, List, Tuple, Union | ||
from uuid import uuid4 | ||
|
||
from pydantic import BaseModel | ||
|
@@ -24,6 +24,10 @@ | |
VideoMaskAnnotation, | ||
VideoObjectAnnotation, | ||
) | ||
from typing import List | ||
from ...annotation_types.audio import ( | ||
AudioClassificationAnnotation, | ||
) | ||
from labelbox.types import DocumentRectangle, DocumentEntity | ||
from .classification import ( | ||
NDChecklistSubclass, | ||
|
@@ -69,6 +73,7 @@ def from_common( | |
yield from cls._create_relationship_annotations(label) | ||
yield from cls._create_non_video_annotations(label) | ||
yield from cls._create_video_annotations(label) | ||
yield from cls._create_audio_annotations(label) | ||
|
||
@staticmethod | ||
def _get_consecutive_frames( | ||
|
@@ -80,6 +85,7 @@ def _get_consecutive_frames( | |
consecutive.append((group[0], group[-1])) | ||
return consecutive | ||
|
||
|
||
@classmethod | ||
def _get_segment_frame_ranges( | ||
cls, | ||
|
@@ -153,12 +159,91 @@ def _create_video_annotations( | |
for annotation in annotation_group: | ||
if ( | ||
annotation.keyframe | ||
and start_frame <= annotation.frame <= end_frame | ||
and start_frame <= annotation.start_frame <= end_frame | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Incorrect Attribute Usage in Video AnnotationsThe Additional Locations (1) |
||
): | ||
segment.append(annotation) | ||
segments.append(segment) | ||
yield NDObject.from_common(segments, label.data) | ||
|
||
@classmethod
def _create_audio_annotations(
    cls, label: Label
) -> Generator[BaseModel, None, None]:
    """Yield NDJSON-style payloads for audio classification annotations.

    Audio annotations are grouped by ``feature_schema_id`` (falling back
    to ``name``), then by classification value, so that annotations
    sharing the same value are emitted once with a list of
    ``{"start", "end"}`` frame ranges — mirroring the v2.py video format.

    Args:
        label: Label whose ``annotations`` are scanned for
            ``AudioClassificationAnnotation`` instances.

    Yields:
        One pydantic model per classification group, with ``name``,
        ``answer`` and ``dataRow`` fields.
    """

    # Defined once per call; previously this class was re-created inside
    # the per-classification loop on every iteration.
    class AudioNDJSON(BaseModel):
        name: str
        answer: List[Dict[str, Any]]
        dataRow: Dict[str, str]

    def _grouping_key(annotation) -> str:
        """Stable string key so annotations with equal values share one entry."""
        answer = getattr(annotation.value, "answer", None)
        if answer is None:
            # Value type without an .answer attribute — fall back to repr.
            return str(annotation.value)
        if isinstance(answer, list):
            # Checklist: order-insensitive key built from the answer names.
            return str(sorted(item.name for item in answer))
        if hasattr(answer, "name"):
            # Radio: ClassificationAnswer carrying a name.
            return answer.name
        # Text: the answer is the string itself.
        return answer

    # Collect audio annotations keyed by schema id (preferred) or name.
    audio_annotations = defaultdict(list)
    for annotation in label.annotations:
        if isinstance(annotation, AudioClassificationAnnotation):
            key = annotation.feature_schema_id or annotation.name
            audio_annotations[key].append(annotation)

    for classification_name, annotation_group in audio_annotations.items():
        # Bucket annotations that carry the same value together.
        value_groups = defaultdict(list)
        for annotation in annotation_group:
            value_groups[_grouping_key(annotation)].append(annotation)

        answer_items: List[Dict[str, Any]] = []
        for grouped in value_groups.values():
            frames = [
                {"start": a.start_frame, "end": a.end_frame} for a in grouped
            ]
            # All members share the same value; use the first as the source.
            answer = getattr(grouped[0].value, "answer", None)
            if isinstance(answer, list):
                # Checklist — uses "name" like v2.py.
                # TODO: only the first checked item is emitted; multi-answer
                # checklists lose information here.
                answer_items.append(
                    {"name": answer[0].name, "frames": frames}
                )
            elif answer is not None and hasattr(answer, "name"):
                # Radio — uses "name" like v2.py.
                answer_items.append({"name": answer.name, "frames": frames})
            else:
                # Text (or answerless value) — uses "value" like v2.py.
                # getattr-guarded: the original dereferenced .answer
                # unconditionally here and could raise AttributeError for
                # values the grouping step had already accepted.
                answer_items.append({"value": answer, "frames": frames})

        yield AudioNDJSON(
            name=classification_name,
            answer=answer_items,
            # NOTE(review): assumes the data row is addressed by global key;
            # rows identified only by uid would fail validation — confirm.
            dataRow={"globalKey": label.data.global_key},
        )
|
||
|
||
|
||
@classmethod | ||
def _create_non_video_annotations(cls, label: Label): | ||
non_video_annotations = [ | ||
|
@@ -170,6 +255,7 @@ def _create_non_video_annotations(cls, label: Label): | |
VideoClassificationAnnotation, | ||
VideoObjectAnnotation, | ||
VideoMaskAnnotation, | ||
AudioClassificationAnnotation, | ||
RelationshipAnnotation, | ||
), | ||
) | ||
|
@@ -187,7 +273,7 @@ def _create_non_video_annotations(cls, label: Label): | |
yield NDMessageTask.from_common(annotation, label.data) | ||
else: | ||
raise TypeError( | ||
f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value',annotation))}`" | ||
f"Unable to convert object to MAL format. `{type(getattr(annotation, 'value', annotation))}`" | ||
) | ||
|
||
@classmethod | ||
|
Uh oh!
There was an error while loading. Please reload this page.