Commit 324591d

feat: allow using moth/non-moth model as terminal classifier
1 parent 8e2c885 commit 324591d

4 files changed: +135 -50 lines changed


trapdata/api/api.py

Lines changed: 72 additions & 44 deletions
@@ -27,6 +27,7 @@
 from .schemas import (
     AlgorithmCategoryMapResponse,
     AlgorithmConfigResponse,
+    DetectionResponse,
     PipelineConfigResponse,
 )
 from .schemas import PipelineRequest as PipelineRequest_
@@ -45,7 +46,7 @@
     "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
     "anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
     "global_moths_2024": MothClassifierGlobal,
-    # "moth_binary": MothClassifierBinary,
+    "moth_binary": MothClassifierBinary,
 }
 _classifier_choices = dict(
     zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
@@ -55,6 +56,13 @@
 PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)


+def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
+    if Classifier == MothClassifierBinary:
+        return False
+    else:
+        return True
+
+
 def make_category_map_response(
     model: APIMothDetector | APIMothClassifier,
     default_taxon_rank: str = "SPECIES",
@@ -109,14 +117,23 @@ def make_algorithm_config_response(
 def make_pipeline_config_response(
     Classifier: type[APIMothClassifier],
 ) -> PipelineConfigResponse:
+    """
+    Create a configuration for an entire pipeline, given a species classifier class.
+    """
+    algorithms = []
+
     detector = APIMothDetector(
         source_images=[],
     )
+    algorithms.append(make_algorithm_config_response(detector))

-    binary_classifier = MothClassifierBinary(
-        source_images=[],
-        detections=[],
-    )
+    if should_filter_detections(Classifier):
+        binary_classifier = MothClassifierBinary(
+            source_images=[],
+            detections=[],
+            terminal=False,
+        )
+        algorithms.append(make_algorithm_config_response(binary_classifier))

     classifier = Classifier(
         source_images=[],
@@ -125,17 +142,14 @@ def make_pipeline_config_response(
         num_workers=settings.num_workers,
         terminal=True,
     )
+    algorithms.append(make_algorithm_config_response(classifier))

     return PipelineConfigResponse(
         name=classifier.name,
         slug=classifier.get_key(),
         description=classifier.description,
         version=1,
-        algorithms=[
-            make_algorithm_config_response(detector),
-            make_algorithm_config_response(binary_classifier),
-            make_algorithm_config_response(classifier),
-        ],
+        algorithms=algorithms,
     )


@@ -169,6 +183,7 @@ async def root():
 @app.post(
     "/pipeline/process/", deprecated=True, tags=["services"]
 )  # old endpoint, deprecated, remove after jan 2025
+@app.post("/process", tags=["services"])  # new endpoint
 @app.post("/process/", tags=["services"])  # new endpoint
 async def process(data: PipelineRequest) -> PipelineResponse:
     algorithms_used: dict[str, AlgorithmConfigResponse] = {}
@@ -192,6 +207,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
     ]

     start_time = time.time()
+
+    Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
+
     detector = APIMothDetector(
         source_images=source_images,
         batch_size=settings.localization_batch_size,
@@ -203,77 +221,87 @@ async def process(data: PipelineRequest) -> PipelineResponse:
     num_pre_filter = len(detector_results)
     algorithms_used[detector.get_key()] = make_algorithm_response(detector)

-    filter = MothClassifierBinary(
-        source_images=source_images,
-        detections=detector_results,
-        batch_size=settings.classification_batch_size,
-        num_workers=settings.num_workers,
-        # single=True if len(detector_results) == 1 else False,
-        single=True,  # @TODO solve issues with reading images in multiprocessing
-        terminal=False,
-    )
-    filter.run()
-    algorithms_used[filter.get_key()] = make_algorithm_response(filter)
+    detections_for_terminal_classifier: list[DetectionResponse] = []
+    detections_to_return: list[DetectionResponse] = []
+
+    if should_filter_detections(Classifier):
+        filter = MothClassifierBinary(
+            source_images=source_images,
+            detections=detector_results,
+            batch_size=settings.classification_batch_size,
+            num_workers=settings.num_workers,
+            # single=True if len(detector_results) == 1 else False,
+            single=True,  # @TODO solve issues with reading images in multiprocessing
+            terminal=False,
+        )
+        filter.run()
+        algorithms_used[filter.get_key()] = make_algorithm_response(filter)

-    # Compare num detections with num moth detections
-    num_post_filter = len(filter.results)
-    logger.info(
-        f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
-    )
+        # Compare num detections with num moth detections
+        num_post_filter = len(filter.results)
+        logger.info(
+            f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
+        )

-    # Filter results based on positive_binary_label
-    moth_detections = []
-    non_moth_detections = []
-    for detection in filter.results:
-        for classification in detection.classifications:
-            if classification.classification == filter.positive_binary_label:
-                moth_detections.append(detection)
-            elif classification.classification == filter.negative_binary_label:
-                non_moth_detections.append(detection)
-            break
+        # Filter results based on positive_binary_label
+        moth_detections = []
+        non_moth_detections = []
+        for detection in filter.results:
+            for classification in detection.classifications:
+                if classification.classification == filter.positive_binary_label:
+                    moth_detections.append(detection)
+                elif classification.classification == filter.negative_binary_label:
+                    non_moth_detections.append(detection)
+                break
+        detections_for_terminal_classifier += moth_detections
+        detections_to_return += non_moth_detections
+
+    else:
+        logger.info("Skipping binary classification filter")
+        detections_for_terminal_classifier += detector_results

     logger.info(
-        f"Sending {len(moth_detections)} of {num_pre_filter} "
+        f"Sending {len(detections_for_terminal_classifier)} of {num_pre_filter} "
         "detections to the classifier"
     )

-    Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
     classifier: APIMothClassifier = Classifier(
         source_images=source_images,
-        detections=moth_detections,
+        detections=detections_for_terminal_classifier,
         batch_size=settings.classification_batch_size,
         num_workers=settings.num_workers,
         # single=True if len(filtered_detections) == 1 else False,
         single=True,  # @TODO solve issues with reading images in multiprocessing
         example_config_param=data.config.example_config_param,
         terminal=True,
+        # critera=data.config.criteria,  # @TODO another approach to intermediate filter models
     )
     classifier.run()
     end_time = time.time()
     seconds_elapsed = float(end_time - start_time)
     algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)

     # Return all detections, including those that were not classified as moths
-    all_detections = classifier.results + non_moth_detections
+    detections_to_return += classifier.results

     logger.info(
         f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
     )
-    logger.info(f"Returning {len(all_detections)} detections")
+    logger.info(f"Returning {len(detections_to_return)} detections")
     # print(all_detections)

     # If the number of detections is greater than 100, its suspicious. Log it.
-    if len(all_detections) > 100:
+    if len(detections_to_return) > 100:
         logger.warning(
-            f"Detected {len(all_detections)} detections. "
+            f"Detected {len(detections_to_return)} detections. "
             "This is suspicious and may contain duplicates."
         )

     response = PipelineResponse(
         pipeline=data.pipeline,
         algorithms=algorithms_used,
         source_images=source_image_results,
-        detections=all_detections,
+        detections=detections_to_return,
         total_time=seconds_elapsed,
     )
     return response
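
For orientation, here is a minimal, self-contained sketch of the dispatch rule introduced above. Only should_filter_detections is copied from the diff; the surrounding classes are stand-ins for the real trapdata models. The point: the moth/non-moth binary model skips the binary pre-filter stage because it is itself the terminal classifier, while every other pipeline keeps the detector, binary filter, and species classifier in sequence.

class APIMothClassifier:  # stand-in for trapdata's classifier base class
    pass


class MothClassifierBinary(APIMothClassifier):  # stand-in: moth/non-moth model
    pass


class MothClassifierQuebecVermont(APIMothClassifier):  # stand-in: species model
    pass


def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
    # Copied from the commit: only the binary model opts out of the
    # pre-filter, since it would otherwise be filtering with itself.
    if Classifier == MothClassifierBinary:
        return False
    else:
        return True


# Species pipelines: detector -> binary filter -> species classifier (3 stages).
# Binary-only pipeline: detector -> binary classifier as terminal (2 stages).
assert should_filter_detections(MothClassifierQuebecVermont) is True
assert should_filter_detections(MothClassifierBinary) is False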

trapdata/api/schemas.py

Lines changed: 2 additions & 0 deletions
@@ -153,6 +153,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel):
                 {"label": "Not a moth", "index": 1, "gbif_key": 5678},
             ]
         ],
+        repr=False,  # Too long to display in the repr
     )
     labels: list[str] = pydantic.Field(
         default_factory=list,
@@ -161,6 +162,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel):
             "the model."
         ),
         examples=[["Moth", "Not a moth"]],
+        repr=False,  # Too long to display in the repr
     )
     version: str | None = pydantic.Field(
         default=None,
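
The repr=False flag added above is standard pydantic Field behavior: the field is dropped from the model's repr/str output, which keeps long category maps from flooding logs. A minimal sketch with a made-up model (not the actual AlgorithmCategoryMapResponse):

import pydantic


class CategoryMapSketch(pydantic.BaseModel):  # hypothetical stand-in model
    version: str
    labels: list[str] = pydantic.Field(default_factory=list, repr=False)


cm = CategoryMapSketch(version="v1", labels=["Moth", "Not a moth"] * 500)
print(cm)  # -> version='v1'  (the 1000-entry labels list is omitted)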

trapdata/api/tests/test_api.py

Lines changed: 53 additions & 4 deletions
@@ -10,6 +10,8 @@
     PipelineRequest,
     PipelineResponse,
     app,
+    make_algorithm_response,
+    make_pipeline_config_response,
 )
 from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest
 from trapdata.api.tests.image_server import StaticFileTestServer
@@ -63,11 +65,10 @@ def test_pipeline_request(self):
             source_images=self.get_test_images(num=2),
         )
         with self.file_server:
-            response = self.client.post(
-                "/pipeline/process", json=pipeline_request.model_dump()
-            )
+            response = self.client.post("/process", json=pipeline_request.model_dump())
             assert response.status_code == 200
-            PipelineResponse(**response.json())
+            results = PipelineResponse(**response.json())
+            return results

     def test_config_num_classification_predictions(self):
         """
@@ -124,3 +125,51 @@ def _send_request(max_predictions_per_classification: int | None):

         _send_request(max_predictions_per_classification=1)
         _send_request(max_predictions_per_classification=None)
+
+    def test_pipeline_config_with_binary_classifier(self):
+        BinaryClassifier = CLASSIFIER_CHOICES["moth_binary"]
+        BinaryClassifierResponse = make_algorithm_response(BinaryClassifier)
+
+        SpeciesClassifier = CLASSIFIER_CHOICES["quebec_vermont_moths_2023"]
+        SpeciesClassifierResponse = make_algorithm_response(SpeciesClassifier)
+
+        # Test using a pipeline that finishes with a full species classifier
+        pipeline_config = make_pipeline_config_response(SpeciesClassifier)
+
+        self.assertEqual(len(pipeline_config.algorithms), 3)
+        self.assertEqual(
+            pipeline_config.algorithms[-1].key, SpeciesClassifierResponse.key
+        )
+        self.assertEqual(
+            pipeline_config.algorithms[1].key, BinaryClassifierResponse.key
+        )
+
+        # Test using a pipeline that finishes only with a binary classifier
+        pipeline_config_binary_only = make_pipeline_config_response(BinaryClassifier)
+
+        self.assertEqual(len(pipeline_config_binary_only.algorithms), 2)
+        self.assertEqual(
+            pipeline_config_binary_only.algorithms[-1].key, BinaryClassifierResponse.key
+        )
+        # self.assertTrue(pipeline_config_binary_only.algorithms[-1].terminal)
+
+    def test_processing_with_only_binary_classifier(self):
+        binary_algorithm_key = "moth_binary"
+        binary_algorithm = CLASSIFIER_CHOICES[binary_algorithm_key]
+        pipeline_request = PipelineRequest(
+            pipeline=PipelineChoice[binary_algorithm_key],
+            source_images=self.get_test_images(num=2),
+        )
+        with self.file_server:
+            response = self.client.post("/process", json=pipeline_request.model_dump())
+            assert response.status_code == 200
+            results = PipelineResponse(**response.json())
+
+        for detection in results.detections:
+            for classification in detection.classifications:
+                assert classification.algorithm.key == binary_algorithm_key
+                assert classification.terminal
+                assert classification.labels
+                assert len(classification.labels) == binary_algorithm.num_classes
+                assert classification.scores
+                assert len(classification.scores) == binary_algorithm.num_classes
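
Beyond the test client, the new binary-only pipeline can be exercised over plain HTTP. A sketch under stated assumptions: the base URL, image id, and image URL are placeholders, the payload mirrors the PipelineRequest fields used in the tests, and the response field names follow the assertions above.

import requests  # third-party HTTP client, assumed installed

payload = {
    "pipeline": "moth_binary",  # a valid PipelineChoice as of this commit
    "source_images": [
        {"id": "example-1", "url": "https://example.com/trap-image.jpg"},
    ],
}
resp = requests.post("http://localhost:8000/process", json=payload, timeout=120)
resp.raise_for_status()
for detection in resp.json()["detections"]:
    for c in detection["classifications"]:
        # With the binary pipeline every classification should be terminal
        # and come from the moth_binary algorithm.
        print(c["algorithm"]["key"], c["terminal"], c["scores"])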

trapdata/api/tests/test_models.py

Lines changed: 8 additions & 2 deletions
@@ -74,10 +74,16 @@ def make_image():


 def get_test_images(
-    subdirs: typing.Iterable[str] = ("vermont", "panama"), limit: int = 6
+    subdirs: typing.Iterable[str] = ("vermont", "panama"),
+    limit: int = 6,
+    with_urls: bool = False,
 ) -> list[SourceImage]:
     return [
-        SourceImage(id=str(img["path"].name), filepath=img["path"])
+        SourceImage(
+            id=str(img["path"].name),
+            filepath=img["path"],
+            url=img["url"] if with_urls else None,
+        )
         for subdir in subdirs
         for img in find_images(pathlib.Path(TEST_IMAGES_BASE_PATH) / subdir)
     ][:limit]
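
A sketch of how the extended helper might be used, assuming the repository is installed and its bundled test images are present (the subdir name comes from the defaults above):

from trapdata.api.tests.test_models import get_test_images

# with_urls=True asks the helper to carry each image's URL through to the
# SourceImage, e.g. for endpoints that fetch by URL rather than local path.
images = get_test_images(subdirs=("vermont",), limit=2, with_urls=True)
for img in images:
    assert img.filepath is not None
    assert img.url is not None  # populated only when with_urls=True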
