Skip to content

Commit 15b25c9

Browse files
authored
Allow using moth/non-moth model as terminal classifier (#70)
* feat: allow using moth/non-moth model as terminal classifier
* fix: update tests & cleanup console output
1 parent d6768e3 commit 15b25c9

File tree

4 files changed

+142
-50
lines changed

4 files changed

+142
-50
lines changed

trapdata/api/api.py

Lines changed: 69 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .schemas import (
2828
AlgorithmCategoryMapResponse,
2929
AlgorithmConfigResponse,
30+
DetectionResponse,
3031
PipelineConfigResponse,
3132
)
3233
from .schemas import PipelineRequest as PipelineRequest_
@@ -45,7 +46,7 @@
4546
"costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
4647
"anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
4748
"global_moths_2024": MothClassifierGlobal,
48-
# "moth_binary": MothClassifierBinary,
49+
"moth_binary": MothClassifierBinary,
4950
}
5051
_classifier_choices = dict(
5152
zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
@@ -55,6 +56,13 @@
5556
PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)
5657

5758

59+
def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
60+
if Classifier == MothClassifierBinary:
61+
return False
62+
else:
63+
return True
64+
65+
5866
def make_category_map_response(
5967
model: APIMothDetector | APIMothClassifier,
6068
default_taxon_rank: str = "SPECIES",
@@ -113,14 +121,20 @@ def make_pipeline_config_response(
113121
"""
114122
Create a configuration for an entire pipeline, given a species classifier class.
115123
"""
124+
algorithms = []
125+
116126
detector = APIMothDetector(
117127
source_images=[],
118128
)
129+
algorithms.append(make_algorithm_config_response(detector))
119130

120-
binary_classifier = MothClassifierBinary(
121-
source_images=[],
122-
detections=[],
123-
)
131+
if should_filter_detections(Classifier):
132+
binary_classifier = MothClassifierBinary(
133+
source_images=[],
134+
detections=[],
135+
terminal=False,
136+
)
137+
algorithms.append(make_algorithm_config_response(binary_classifier))
124138

125139
classifier = Classifier(
126140
source_images=[],
@@ -129,17 +143,14 @@ def make_pipeline_config_response(
129143
num_workers=settings.num_workers,
130144
terminal=True,
131145
)
146+
algorithms.append(make_algorithm_config_response(classifier))
132147

133148
return PipelineConfigResponse(
134149
name=classifier.name,
135150
slug=slug,
136151
description=classifier.description,
137152
version=1,
138-
algorithms=[
139-
make_algorithm_config_response(detector),
140-
make_algorithm_config_response(binary_classifier),
141-
make_algorithm_config_response(classifier),
142-
],
153+
algorithms=algorithms,
143154
)
144155

145156

@@ -173,6 +184,7 @@ async def root():
173184
@app.post(
174185
"/pipeline/process/", deprecated=True, tags=["services"]
175186
) # old endpoint, deprecated, remove after jan 2025
187+
@app.post("/process", tags=["services"]) # new endpoint
176188
@app.post("/process/", tags=["services"]) # new endpoint
177189
async def process(data: PipelineRequest) -> PipelineResponse:
178190
algorithms_used: dict[str, AlgorithmConfigResponse] = {}
@@ -196,6 +208,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
196208
]
197209

198210
start_time = time.time()
211+
212+
Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
213+
199214
detector = APIMothDetector(
200215
source_images=source_images,
201216
batch_size=settings.localization_batch_size,
@@ -207,77 +222,87 @@ async def process(data: PipelineRequest) -> PipelineResponse:
207222
num_pre_filter = len(detector_results)
208223
algorithms_used[detector.get_key()] = make_algorithm_response(detector)
209224

210-
filter = MothClassifierBinary(
211-
source_images=source_images,
212-
detections=detector_results,
213-
batch_size=settings.classification_batch_size,
214-
num_workers=settings.num_workers,
215-
# single=True if len(detector_results) == 1 else False,
216-
single=True, # @TODO solve issues with reading images in multiprocessing
217-
terminal=False,
218-
)
219-
filter.run()
220-
algorithms_used[filter.get_key()] = make_algorithm_response(filter)
225+
detections_for_terminal_classifier: list[DetectionResponse] = []
226+
detections_to_return: list[DetectionResponse] = []
227+
228+
if should_filter_detections(Classifier):
229+
filter = MothClassifierBinary(
230+
source_images=source_images,
231+
detections=detector_results,
232+
batch_size=settings.classification_batch_size,
233+
num_workers=settings.num_workers,
234+
# single=True if len(detector_results) == 1 else False,
235+
single=True, # @TODO solve issues with reading images in multiprocessing
236+
terminal=False,
237+
)
238+
filter.run()
239+
algorithms_used[filter.get_key()] = make_algorithm_response(filter)
221240

222-
# Compare num detections with num moth detections
223-
num_post_filter = len(filter.results)
224-
logger.info(
225-
f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
226-
)
241+
# Compare num detections with num moth detections
242+
num_post_filter = len(filter.results)
243+
logger.info(
244+
f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
245+
)
227246

228-
# Filter results based on positive_binary_label
229-
moth_detections = []
230-
non_moth_detections = []
231-
for detection in filter.results:
232-
for classification in detection.classifications:
233-
if classification.classification == filter.positive_binary_label:
234-
moth_detections.append(detection)
235-
elif classification.classification == filter.negative_binary_label:
236-
non_moth_detections.append(detection)
237-
break
247+
# Filter results based on positive_binary_label
248+
moth_detections = []
249+
non_moth_detections = []
250+
for detection in filter.results:
251+
for classification in detection.classifications:
252+
if classification.classification == filter.positive_binary_label:
253+
moth_detections.append(detection)
254+
elif classification.classification == filter.negative_binary_label:
255+
non_moth_detections.append(detection)
256+
break
257+
detections_for_terminal_classifier += moth_detections
258+
detections_to_return += non_moth_detections
259+
260+
else:
261+
logger.info("Skipping binary classification filter")
262+
detections_for_terminal_classifier += detector_results
238263

239264
logger.info(
240-
f"Sending {len(moth_detections)} of {num_pre_filter} "
265+
f"Sending {len(detections_for_terminal_classifier)} of {num_pre_filter} "
241266
"detections to the classifier"
242267
)
243268

244-
Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
245269
classifier: APIMothClassifier = Classifier(
246270
source_images=source_images,
247-
detections=moth_detections,
271+
detections=detections_for_terminal_classifier,
248272
batch_size=settings.classification_batch_size,
249273
num_workers=settings.num_workers,
250274
# single=True if len(filtered_detections) == 1 else False,
251275
single=True, # @TODO solve issues with reading images in multiprocessing
252276
example_config_param=data.config.example_config_param,
253277
terminal=True,
278+
# criteria=data.config.criteria, # @TODO another approach to intermediate filter models
254279
)
255280
classifier.run()
256281
end_time = time.time()
257282
seconds_elapsed = float(end_time - start_time)
258283
algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)
259284

260285
# Return all detections, including those that were not classified as moths
261-
all_detections = classifier.results + non_moth_detections
286+
detections_to_return += classifier.results
262287

263288
logger.info(
264289
f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
265290
)
266-
logger.info(f"Returning {len(all_detections)} detections")
291+
logger.info(f"Returning {len(detections_to_return)} detections")
267292
# print(all_detections)
268293

269294
# If the number of detections is greater than 100, it's suspicious. Log it.
270-
if len(all_detections) > 100:
295+
if len(detections_to_return) > 100:
271296
logger.warning(
272-
f"Detected {len(all_detections)} detections. "
297+
f"Detected {len(detections_to_return)} detections. "
273298
"This is suspicious and may contain duplicates."
274299
)
275300

276301
response = PipelineResponse(
277302
pipeline=data.pipeline,
278303
algorithms=algorithms_used,
279304
source_images=source_image_results,
280-
detections=all_detections,
305+
detections=detections_to_return,
281306
total_time=seconds_elapsed,
282307
)
283308
return response

trapdata/api/schemas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,23 @@ class ClassificationResponse(pydantic.BaseModel):
8282
"classification in the response. Use the category map from the algorithm "
8383
"to get the full list of labels and metadata."
8484
),
85+
repr=False, # Too long to display in the repr
8586
)
8687
scores: list[float] = pydantic.Field(
8788
default_factory=list,
8889
description=(
8990
"The calibrated probabilities for each class label, most commonly "
9091
"the softmax output."
9192
),
93+
repr=False, # Too long to display in the repr
9294
)
9395
logits: list[float] = pydantic.Field(
9496
default_factory=list,
9597
description=(
9698
"The raw logits output by the model, before any calibration or "
9799
"normalization."
98100
),
101+
repr=False, # Too long to display in the repr
99102
)
100103
inference_time: float | None = None
101104
algorithm: AlgorithmReference
@@ -153,6 +156,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel):
153156
{"label": "Not a moth", "index": 1, "gbif_key": 5678},
154157
]
155158
],
159+
repr=False, # Too long to display in the repr
156160
)
157161
labels: list[str] = pydantic.Field(
158162
default_factory=list,
@@ -161,6 +165,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel):
161165
"the model."
162166
),
163167
examples=[["Moth", "Not a moth"]],
168+
repr=False, # Too long to display in the repr
164169
)
165170
version: str | None = pydantic.Field(
166171
default=None,

trapdata/api/tests/test_api.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
PipelineRequest,
1111
PipelineResponse,
1212
app,
13+
make_algorithm_response,
14+
make_pipeline_config_response,
1315
)
1416
from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest
1517
from trapdata.api.tests.image_server import StaticFileTestServer
@@ -63,11 +65,10 @@ def test_pipeline_request(self):
6365
source_images=self.get_test_images(num=2),
6466
)
6567
with self.file_server:
66-
response = self.client.post(
67-
"/pipeline/process", json=pipeline_request.model_dump()
68-
)
68+
response = self.client.post("/process", json=pipeline_request.model_dump())
6969
assert response.status_code == 200
70-
PipelineResponse(**response.json())
70+
results = PipelineResponse(**response.json())
71+
return results
7172

7273
def test_config_num_classification_predictions(self):
7374
"""
@@ -124,3 +125,58 @@ def _send_request(max_predictions_per_classification: int | None):
124125

125126
_send_request(max_predictions_per_classification=1)
126127
_send_request(max_predictions_per_classification=None)
128+
129+
def test_pipeline_config_with_binary_classifier(self):
130+
binary_classifier_pipeline_choice = "moth_binary"
131+
BinaryClassifier = CLASSIFIER_CHOICES[binary_classifier_pipeline_choice]
132+
BinaryClassifierResponse = make_algorithm_response(BinaryClassifier)
133+
134+
species_classifier_pipeline_choice = "quebec_vermont_moths_2023"
135+
SpeciesClassifier = CLASSIFIER_CHOICES[species_classifier_pipeline_choice]
136+
SpeciesClassifierResponse = make_algorithm_response(SpeciesClassifier)
137+
138+
# Test using a pipeline that finishes with a full species classifier
139+
pipeline_config = make_pipeline_config_response(
140+
SpeciesClassifier,
141+
slug=species_classifier_pipeline_choice,
142+
)
143+
144+
self.assertEqual(len(pipeline_config.algorithms), 3)
145+
self.assertEqual(
146+
pipeline_config.algorithms[-1].key, SpeciesClassifierResponse.key
147+
)
148+
self.assertEqual(
149+
pipeline_config.algorithms[1].key, BinaryClassifierResponse.key
150+
)
151+
152+
# Test using a pipeline that finishes only with a binary classifier
153+
pipeline_config_binary_only = make_pipeline_config_response(
154+
BinaryClassifier, slug=binary_classifier_pipeline_choice
155+
)
156+
157+
self.assertEqual(len(pipeline_config_binary_only.algorithms), 2)
158+
self.assertEqual(
159+
pipeline_config_binary_only.algorithms[-1].key, BinaryClassifierResponse.key
160+
)
161+
# self.assertTrue(pipeline_config_binary_only.algorithms[-1].terminal)
162+
163+
def test_processing_with_only_binary_classifier(self):
164+
binary_algorithm_key = "moth_binary"
165+
binary_algorithm = CLASSIFIER_CHOICES[binary_algorithm_key]
166+
pipeline_request = PipelineRequest(
167+
pipeline=PipelineChoice[binary_algorithm_key],
168+
source_images=self.get_test_images(num=2),
169+
)
170+
with self.file_server:
171+
response = self.client.post("/process", json=pipeline_request.model_dump())
172+
assert response.status_code == 200
173+
results = PipelineResponse(**response.json())
174+
175+
for detection in results.detections:
176+
for classification in detection.classifications:
177+
assert classification.algorithm.key == binary_algorithm_key
178+
assert classification.terminal
179+
assert classification.labels
180+
assert len(classification.labels) == binary_algorithm.num_classes
181+
assert classification.scores
182+
assert len(classification.scores) == binary_algorithm.num_classes

trapdata/api/tests/test_models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,16 @@ def make_image():
7474

7575

7676
def get_test_images(
77-
subdirs: typing.Iterable[str] = ("vermont", "panama"), limit: int = 6
77+
subdirs: typing.Iterable[str] = ("vermont", "panama"),
78+
limit: int = 6,
79+
with_urls: bool = False,
7880
) -> list[SourceImage]:
7981
return [
80-
SourceImage(id=str(img["path"].name), filepath=img["path"])
82+
SourceImage(
83+
id=str(img["path"].name),
84+
filepath=img["path"],
85+
url=img["url"] if with_urls else None,
86+
)
8187
for subdir in subdirs
8288
for img in find_images(pathlib.Path(TEST_IMAGES_BASE_PATH) / subdir)
8389
][:limit]

0 commit comments

Comments
 (0)