|
2 | 2 | import copy
|
3 | 3 | from itertools import groupby
|
4 | 4 | from operator import itemgetter
|
5 |
| -from typing import Generator, List, Tuple, Union |
| 5 | +from typing import Any, Dict, Generator, List, Tuple, Union |
6 | 6 | from uuid import uuid4
|
7 | 7 |
|
8 | 8 | from pydantic import BaseModel
|
@@ -168,25 +168,79 @@ def _create_video_annotations(
|
@classmethod
def _create_audio_annotations(
    cls, label: Label
) -> Generator[BaseModel, None, None]:
    """Yield one NDJSON model per audio classification found in ``label``.

    Audio classification annotations are first bucketed by
    ``feature_schema_id`` (falling back to ``name``).  Within each bucket,
    every distinct answer becomes one entry of the form
    ``{"name" | "value": <answer>, "frames": [{"start": ..., "end": ...}]}``:

    * checklist answers (``value.answer`` is a list) contribute one
      ``"name"`` entry per selected option, so no option is dropped;
    * radio answers (``value.answer`` has a ``name``) contribute a single
      ``"name"`` entry;
    * any other ``value.answer`` (text) contributes a ``"value"`` entry;
    * values without an ``answer`` attribute fall back to ``str(value)``
      under the ``"value"`` field.

    Args:
        label: The label whose ``annotations`` are scanned for
            ``AudioClassificationAnnotation`` instances.

    Yields:
        A pydantic model with ``name``, ``answer`` and ``dataRow`` fields,
        one per classification bucket that produced at least one answer.
    """

    # Defined once per call instead of once per classification bucket.
    class AudioNDJSON(BaseModel):
        name: str
        answer: List[Dict[str, Any]]
        dataRow: Dict[str, str]

    # Collect audio annotations by schema id (preferred) or name.
    audio_annotations = defaultdict(list)
    for annot in label.annotations:
        if isinstance(annot, AudioClassificationAnnotation):
            audio_annotations[annot.feature_schema_id or annot.name].append(annot)

    for classification_name, annotation_group in audio_annotations.items():
        # Maps (output field, answer value) -> list of frame ranges.
        # Insertion order is preserved, so answers appear in the order in
        # which they were first seen.
        frames_by_answer = defaultdict(list)

        for ann in annotation_group:
            # A missing/None end_frame means a single-frame annotation.
            end_frame = getattr(ann, "end_frame", None)
            if end_frame is None:
                end_frame = ann.start_frame

            if not hasattr(ann.value, "answer"):
                # Unknown value type: use its string form for both the
                # grouping key and the emitted value.
                answer_keys = [("value", str(ann.value))]
            elif isinstance(ann.value.answer, list):
                # Checklist: one entry per selected option; dict.fromkeys
                # de-duplicates repeated options while keeping order.  An
                # empty checklist contributes nothing.
                answer_keys = [
                    ("name", option_name)
                    for option_name in dict.fromkeys(
                        item.name for item in ann.value.answer
                    )
                ]
            elif hasattr(ann.value.answer, "name"):
                # Radio: the answer is a single named option.
                answer_keys = [("name", ann.value.answer.name)]
            else:
                # Text: the answer is the raw value itself.
                answer_keys = [("value", ann.value.answer)]

            for answer_key in answer_keys:
                # A fresh dict per entry so answers never share frame objects.
                frames_by_answer[answer_key].append(
                    {"start": ann.start_frame, "end": end_frame}
                )

        if not frames_by_answer:
            # Nothing to export for this classification (e.g. only empty
            # checklists were present).
            continue

        answer_items = [
            {field: answer_value, "frames": frames}
            for (field, answer_value), frames in frames_by_answer.items()
        ]

        # NOTE(review): classification_name may be a feature_schema_id rather
        # than a name, and dataRow assumes label.data.global_key is set
        # (a None value would not satisfy Dict[str, str]) - confirm callers.
        yield AudioNDJSON(
            name=classification_name,
            answer=answer_items,
            dataRow={"globalKey": label.data.global_key},
        )
190 | 244 |
|
191 | 245 |
|
192 | 246 |
|
|
0 commit comments