Skip to content

Commit c843686

Browse files
committed
fix: troubleshooting duplicate classifications
1 parent c587a93 commit c843686

File tree

4 files changed

+43
-18
lines changed

4 files changed

+43
-18
lines changed

trapdata/api/api.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,24 @@ def _get_source_image(source_images, source_image_id):
8181

8282

8383
@app.post("/pipeline/process")
84+
@app.post("/pipeline/process/")
8485
async def process(data: PipelineRequest) -> PipelineResponse:
86+
# Ensure that the source images are unique, filter out duplicates
87+
source_images_index = {
88+
source_image.id: source_image for source_image in data.source_images
89+
}
90+
incoming_source_images = list(source_images_index.values())
91+
if len(incoming_source_images) != len(data.source_images):
92+
logger.warning(
93+
f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
94+
)
95+
8596
source_image_results = [
86-
SourceImageResponse(**image.model_dump()) for image in data.source_images
97+
SourceImageResponse(**image.model_dump()) for image in incoming_source_images
98+
]
99+
source_images = [
100+
SourceImage(**image.model_dump()) for image in incoming_source_images
87101
]
88-
source_images = [SourceImage(**image.model_dump()) for image in data.source_images]
89102

90103
start_time = time.time()
91104
detector = MothDetector(

trapdata/api/models/classification.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import torch
6-
from rich import print
76

87
from trapdata.common.logs import logger
98
from trapdata.ml.models.classification import (
@@ -37,6 +36,9 @@ def __init__(
3736
self.detections = list(detections)
3837
self.results: list[Detection] = []
3938
super().__init__(*args, **kwargs)
39+
logger.info(
40+
f"Initialized {self.__class__.__name__} with {len(self.detections)} detections"
41+
)
4042

4143
def get_dataset(self):
4244
return ClassificationImageDataset(
@@ -89,19 +91,32 @@ def save_results(
8991
timestamp=datetime.datetime.now(),
9092
)
9193
self.update_classification(detection, classification)
92-
print(detection)
94+
# print(detection)
9395
self.results.extend(self.detections)
9496
logger.info(f"Saving {len(self.results)} detections with classifications")
9597
return self.results
9698

97-
def update_classification(self, detection: Detection, new_classification: Classification) -> None:
99+
def update_classification(
100+
self, detection: Detection, new_classification: Classification
101+
) -> None:
98102
# Remove all existing classifications from this algorithm
99-
detection.classifications = [c for c in detection.classifications if c.algorithm != self.name]
103+
detection.classifications = [
104+
c for c in detection.classifications if c.algorithm != self.name
105+
]
100106
# Add the new classification for this algorithm
101107
detection.classifications.append(new_classification)
108+
logger.debug(
109+
f"Updated classification for detection {detection.bbox}. Total classifications: {len(detection.classifications)}"
110+
)
102111

103112
def run(self) -> list[Detection]:
113+
logger.info(
114+
f"Starting {self.__class__.__name__} run with {len(self.results)} detections"
115+
)
104116
super().run()
117+
logger.info(
118+
f"Finished {self.__class__.__name__} run. Processed {len(self.results)} detections"
119+
)
105120
return self.results
106121

107122

@@ -134,8 +149,11 @@ def save_results(
134149
# Specific to binary classification / the filter model
135150
terminal=False,
136151
)
137-
print(detection)
138-
if not self.filter_results or classification.classification == self.positive_binary_label:
152+
# print(detection)
153+
if (
154+
not self.filter_results
155+
or classification.classification == self.positive_binary_label
156+
):
139157
self.update_classification(detection, classification)
140158

141159
self.results.extend(self.detections)
@@ -149,15 +167,11 @@ class MothClassifierPanama(
149167
pass
150168

151169

152-
class MothClassifierPanama2024(
153-
MothClassifier, PanamaMothSpeciesClassifier2024
154-
):
170+
class MothClassifierPanama2024(MothClassifier, PanamaMothSpeciesClassifier2024):
155171
pass
156172

157173

158-
class MothClassifierUKDenmark(
159-
MothClassifier, UKDenmarkMothSpeciesClassifier2024
160-
):
174+
class MothClassifierUKDenmark(MothClassifier, UKDenmarkMothSpeciesClassifier2024):
161175
pass
162176

163177

trapdata/api/models/localization.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import datetime
33
import typing
44

5-
from rich import print
6-
75
from trapdata.common.logs import logger
86
from trapdata.ml.models.localization import (
97
MothObjectDetector_FasterRCNN_2023,
@@ -98,7 +96,7 @@ def save_detection(image_id, coords):
9896
timestamp=datetime.datetime.now(),
9997
crop_image_url=crop_url,
10098
)
101-
print(detection)
99+
# print(detection)
102100
return detection
103101

104102
with concurrent.futures.ThreadPoolExecutor() as executor:

trapdata/cli/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def run_api():
9393
"""
9494
import uvicorn
9595

96-
uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2001, reload=True)
96+
uvicorn.run("trapdata.api.api:app", host="0.0.0.0", port=2000, reload=True)
9797

9898

9999
if __name__ == "__main__":

0 commit comments

Comments
 (0)