27
27
from .schemas import (
28
28
AlgorithmCategoryMapResponse ,
29
29
AlgorithmConfigResponse ,
30
+ DetectionResponse ,
30
31
PipelineConfigResponse ,
31
32
)
32
33
from .schemas import PipelineRequest as PipelineRequest_
45
46
"costa_rica_moths_turing_2024" : MothClassifierTuringCostaRica ,
46
47
"anguilla_moths_turing_2024" : MothClassifierTuringAnguilla ,
47
48
"global_moths_2024" : MothClassifierGlobal ,
48
- # "moth_binary": MothClassifierBinary,
49
+ "moth_binary" : MothClassifierBinary ,
49
50
}
50
51
_classifier_choices = dict (
51
52
zip (CLASSIFIER_CHOICES .keys (), list (CLASSIFIER_CHOICES .keys ()))
55
56
PipelineChoice = enum .Enum ("PipelineChoice" , _classifier_choices )
56
57
57
58
59
def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
    """Decide whether the binary moth/non-moth filter should run before *Classifier*.

    The binary classifier itself needs the raw, unfiltered detections, so the
    filter step is skipped when it is the selected pipeline; every other
    (species-level) classifier gets pre-filtered detections.

    Args:
        Classifier: The classifier class selected for the pipeline.

    Returns:
        False when *Classifier* is the binary filter model itself, True otherwise.
    """
    # `is` compares class identity; `==` would also work for classes but
    # identity states the intent (exactly this class, not an equal one).
    return Classifier is not MothClassifierBinary
64
+
65
+
58
66
def make_category_map_response (
59
67
model : APIMothDetector | APIMothClassifier ,
60
68
default_taxon_rank : str = "SPECIES" ,
@@ -113,14 +121,20 @@ def make_pipeline_config_response(
113
121
"""
114
122
Create a configuration for an entire pipeline, given a species classifier class.
115
123
"""
124
+ algorithms = []
125
+
116
126
detector = APIMothDetector (
117
127
source_images = [],
118
128
)
129
+ algorithms .append (make_algorithm_config_response (detector ))
119
130
120
- binary_classifier = MothClassifierBinary (
121
- source_images = [],
122
- detections = [],
123
- )
131
+ if should_filter_detections (Classifier ):
132
+ binary_classifier = MothClassifierBinary (
133
+ source_images = [],
134
+ detections = [],
135
+ terminal = False ,
136
+ )
137
+ algorithms .append (make_algorithm_config_response (binary_classifier ))
124
138
125
139
classifier = Classifier (
126
140
source_images = [],
@@ -129,17 +143,14 @@ def make_pipeline_config_response(
129
143
num_workers = settings .num_workers ,
130
144
terminal = True ,
131
145
)
146
+ algorithms .append (make_algorithm_config_response (classifier ))
132
147
133
148
return PipelineConfigResponse (
134
149
name = classifier .name ,
135
150
slug = slug ,
136
151
description = classifier .description ,
137
152
version = 1 ,
138
- algorithms = [
139
- make_algorithm_config_response (detector ),
140
- make_algorithm_config_response (binary_classifier ),
141
- make_algorithm_config_response (classifier ),
142
- ],
153
+ algorithms = algorithms ,
143
154
)
144
155
145
156
@@ -173,6 +184,7 @@ async def root():
173
184
@app .post (
174
185
"/pipeline/process/" , deprecated = True , tags = ["services" ]
175
186
) # old endpoint, deprecated, remove after jan 2025
187
+ @app .post ("/process" , tags = ["services" ]) # new endpoint
176
188
@app .post ("/process/" , tags = ["services" ]) # new endpoint
177
189
async def process (data : PipelineRequest ) -> PipelineResponse :
178
190
algorithms_used : dict [str , AlgorithmConfigResponse ] = {}
@@ -196,6 +208,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
196
208
]
197
209
198
210
start_time = time .time ()
211
+
212
+ Classifier = CLASSIFIER_CHOICES [str (data .pipeline )]
213
+
199
214
detector = APIMothDetector (
200
215
source_images = source_images ,
201
216
batch_size = settings .localization_batch_size ,
@@ -207,77 +222,87 @@ async def process(data: PipelineRequest) -> PipelineResponse:
207
222
num_pre_filter = len (detector_results )
208
223
algorithms_used [detector .get_key ()] = make_algorithm_response (detector )
209
224
210
- filter = MothClassifierBinary (
211
- source_images = source_images ,
212
- detections = detector_results ,
213
- batch_size = settings .classification_batch_size ,
214
- num_workers = settings .num_workers ,
215
- # single=True if len(detector_results) == 1 else False,
216
- single = True , # @TODO solve issues with reading images in multiprocessing
217
- terminal = False ,
218
- )
219
- filter .run ()
220
- algorithms_used [filter .get_key ()] = make_algorithm_response (filter )
225
+ detections_for_terminal_classifier : list [DetectionResponse ] = []
226
+ detections_to_return : list [DetectionResponse ] = []
227
+
228
+ if should_filter_detections (Classifier ):
229
+ filter = MothClassifierBinary (
230
+ source_images = source_images ,
231
+ detections = detector_results ,
232
+ batch_size = settings .classification_batch_size ,
233
+ num_workers = settings .num_workers ,
234
+ # single=True if len(detector_results) == 1 else False,
235
+ single = True , # @TODO solve issues with reading images in multiprocessing
236
+ terminal = False ,
237
+ )
238
+ filter .run ()
239
+ algorithms_used [filter .get_key ()] = make_algorithm_response (filter )
221
240
222
- # Compare num detections with num moth detections
223
- num_post_filter = len (filter .results )
224
- logger .info (
225
- f"Binary classifier returned { num_post_filter } of { num_pre_filter } detections"
226
- )
241
+ # Compare num detections with num moth detections
242
+ num_post_filter = len (filter .results )
243
+ logger .info (
244
+ f"Binary classifier returned { num_post_filter } of { num_pre_filter } detections"
245
+ )
227
246
228
- # Filter results based on positive_binary_label
229
- moth_detections = []
230
- non_moth_detections = []
231
- for detection in filter .results :
232
- for classification in detection .classifications :
233
- if classification .classification == filter .positive_binary_label :
234
- moth_detections .append (detection )
235
- elif classification .classification == filter .negative_binary_label :
236
- non_moth_detections .append (detection )
237
- break
247
+ # Filter results based on positive_binary_label
248
+ moth_detections = []
249
+ non_moth_detections = []
250
+ for detection in filter .results :
251
+ for classification in detection .classifications :
252
+ if classification .classification == filter .positive_binary_label :
253
+ moth_detections .append (detection )
254
+ elif classification .classification == filter .negative_binary_label :
255
+ non_moth_detections .append (detection )
256
+ break
257
+ detections_for_terminal_classifier += moth_detections
258
+ detections_to_return += non_moth_detections
259
+
260
+ else :
261
+ logger .info ("Skipping binary classification filter" )
262
+ detections_for_terminal_classifier += detector_results
238
263
239
264
logger .info (
240
- f"Sending { len (moth_detections )} of { num_pre_filter } "
265
+ f"Sending { len (detections_for_terminal_classifier )} of { num_pre_filter } "
241
266
"detections to the classifier"
242
267
)
243
268
244
- Classifier = CLASSIFIER_CHOICES [str (data .pipeline )]
245
269
classifier : APIMothClassifier = Classifier (
246
270
source_images = source_images ,
247
- detections = moth_detections ,
271
+ detections = detections_for_terminal_classifier ,
248
272
batch_size = settings .classification_batch_size ,
249
273
num_workers = settings .num_workers ,
250
274
# single=True if len(filtered_detections) == 1 else False,
251
275
single = True , # @TODO solve issues with reading images in multiprocessing
252
276
example_config_param = data .config .example_config_param ,
253
277
terminal = True ,
278
+ # critera=data.config.criteria, # @TODO another approach to intermediate filter models
254
279
)
255
280
classifier .run ()
256
281
end_time = time .time ()
257
282
seconds_elapsed = float (end_time - start_time )
258
283
algorithms_used [classifier .get_key ()] = make_algorithm_response (classifier )
259
284
260
285
# Return all detections, including those that were not classified as moths
261
- all_detections = classifier .results + non_moth_detections
286
+ detections_to_return + = classifier .results
262
287
263
288
logger .info (
264
289
f"Processed { len (source_images )} images in { seconds_elapsed :.2f} seconds"
265
290
)
266
- logger .info (f"Returning { len (all_detections )} detections" )
291
+ logger .info (f"Returning { len (detections_to_return )} detections" )
267
292
# print(all_detections)
268
293
269
294
# If the number of detections is greater than 100, its suspicious. Log it.
270
- if len (all_detections ) > 100 :
295
+ if len (detections_to_return ) > 100 :
271
296
logger .warning (
272
- f"Detected { len (all_detections )} detections. "
297
+ f"Detected { len (detections_to_return )} detections. "
273
298
"This is suspicious and may contain duplicates."
274
299
)
275
300
276
301
response = PipelineResponse (
277
302
pipeline = data .pipeline ,
278
303
algorithms = algorithms_used ,
279
304
source_images = source_image_results ,
280
- detections = all_detections ,
305
+ detections = detections_to_return ,
281
306
total_time = seconds_elapsed ,
282
307
)
283
308
return response
0 commit comments