27
27
from .schemas import (
28
28
AlgorithmCategoryMapResponse ,
29
29
AlgorithmConfigResponse ,
30
+ DetectionResponse ,
30
31
PipelineConfigResponse ,
31
32
)
32
33
from .schemas import PipelineRequest as PipelineRequest_
45
46
"costa_rica_moths_turing_2024" : MothClassifierTuringCostaRica ,
46
47
"anguilla_moths_turing_2024" : MothClassifierTuringAnguilla ,
47
48
"global_moths_2024" : MothClassifierGlobal ,
48
- # "moth_binary": MothClassifierBinary,
49
+ "moth_binary" : MothClassifierBinary ,
49
50
}
50
51
_classifier_choices = dict (
51
52
zip (CLASSIFIER_CHOICES .keys (), list (CLASSIFIER_CHOICES .keys ()))
55
56
# Build an Enum of pipeline slugs dynamically via the enum functional API.
# `_classifier_choices` maps each classifier slug to itself (see the
# dict(zip(keys, keys)) construction above), so every enum member's name
# and value are the same slug string — letting FastAPI validate the
# request's `pipeline` field against the registered classifier choices.
PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)
56
57
57
58
59
def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
    """
    Return True if the pipeline should run the binary moth/non-moth
    filter on detections before the terminal species classifier.

    The binary classifier is the only pipeline that skips the filtering
    step, since it *is* the filter and must see every raw detection.

    NOTE(review): `Classifier` is capitalized (it is a class object, not
    an instance); kept as-is because it is part of the public signature.
    """
    # Return the boolean directly instead of an if/else that returns
    # literal True/False; identity comparison is the idiomatic check
    # for "is this exact class".
    return Classifier is not MothClassifierBinary
64
+
65
+
58
66
def make_category_map_response (
59
67
model : APIMothDetector | APIMothClassifier ,
60
68
default_taxon_rank : str = "SPECIES" ,
@@ -109,14 +117,23 @@ def make_algorithm_config_response(
109
117
def make_pipeline_config_response (
110
118
Classifier : type [APIMothClassifier ],
111
119
) -> PipelineConfigResponse :
120
+ """
121
+ Create a configuration for an entire pipeline, given a species classifier class.
122
+ """
123
+ algorithms = []
124
+
112
125
detector = APIMothDetector (
113
126
source_images = [],
114
127
)
128
+ algorithms .append (make_algorithm_config_response (detector ))
115
129
116
- binary_classifier = MothClassifierBinary (
117
- source_images = [],
118
- detections = [],
119
- )
130
+ if should_filter_detections (Classifier ):
131
+ binary_classifier = MothClassifierBinary (
132
+ source_images = [],
133
+ detections = [],
134
+ terminal = False ,
135
+ )
136
+ algorithms .append (make_algorithm_config_response (binary_classifier ))
120
137
121
138
classifier = Classifier (
122
139
source_images = [],
@@ -125,17 +142,14 @@ def make_pipeline_config_response(
125
142
num_workers = settings .num_workers ,
126
143
terminal = True ,
127
144
)
145
+ algorithms .append (make_algorithm_config_response (classifier ))
128
146
129
147
return PipelineConfigResponse (
130
148
name = classifier .name ,
131
149
slug = classifier .get_key (),
132
150
description = classifier .description ,
133
151
version = 1 ,
134
- algorithms = [
135
- make_algorithm_config_response (detector ),
136
- make_algorithm_config_response (binary_classifier ),
137
- make_algorithm_config_response (classifier ),
138
- ],
152
+ algorithms = algorithms ,
139
153
)
140
154
141
155
@@ -169,6 +183,7 @@ async def root():
169
183
@app .post (
170
184
"/pipeline/process/" , deprecated = True , tags = ["services" ]
171
185
) # old endpoint, deprecated, remove after jan 2025
186
+ @app .post ("/process" , tags = ["services" ]) # new endpoint
172
187
@app .post ("/process/" , tags = ["services" ]) # new endpoint
173
188
async def process (data : PipelineRequest ) -> PipelineResponse :
174
189
algorithms_used : dict [str , AlgorithmConfigResponse ] = {}
@@ -192,6 +207,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
192
207
]
193
208
194
209
start_time = time .time ()
210
+
211
+ Classifier = CLASSIFIER_CHOICES [str (data .pipeline )]
212
+
195
213
detector = APIMothDetector (
196
214
source_images = source_images ,
197
215
batch_size = settings .localization_batch_size ,
@@ -203,77 +221,87 @@ async def process(data: PipelineRequest) -> PipelineResponse:
203
221
num_pre_filter = len (detector_results )
204
222
algorithms_used [detector .get_key ()] = make_algorithm_response (detector )
205
223
206
- filter = MothClassifierBinary (
207
- source_images = source_images ,
208
- detections = detector_results ,
209
- batch_size = settings .classification_batch_size ,
210
- num_workers = settings .num_workers ,
211
- # single=True if len(detector_results) == 1 else False,
212
- single = True , # @TODO solve issues with reading images in multiprocessing
213
- terminal = False ,
214
- )
215
- filter .run ()
216
- algorithms_used [filter .get_key ()] = make_algorithm_response (filter )
224
+ detections_for_terminal_classifier : list [DetectionResponse ] = []
225
+ detections_to_return : list [DetectionResponse ] = []
226
+
227
+ if should_filter_detections (Classifier ):
228
+ filter = MothClassifierBinary (
229
+ source_images = source_images ,
230
+ detections = detector_results ,
231
+ batch_size = settings .classification_batch_size ,
232
+ num_workers = settings .num_workers ,
233
+ # single=True if len(detector_results) == 1 else False,
234
+ single = True , # @TODO solve issues with reading images in multiprocessing
235
+ terminal = False ,
236
+ )
237
+ filter .run ()
238
+ algorithms_used [filter .get_key ()] = make_algorithm_response (filter )
217
239
218
- # Compare num detections with num moth detections
219
- num_post_filter = len (filter .results )
220
- logger .info (
221
- f"Binary classifier returned { num_post_filter } of { num_pre_filter } detections"
222
- )
240
+ # Compare num detections with num moth detections
241
+ num_post_filter = len (filter .results )
242
+ logger .info (
243
+ f"Binary classifier returned { num_post_filter } of { num_pre_filter } detections"
244
+ )
223
245
224
- # Filter results based on positive_binary_label
225
- moth_detections = []
226
- non_moth_detections = []
227
- for detection in filter .results :
228
- for classification in detection .classifications :
229
- if classification .classification == filter .positive_binary_label :
230
- moth_detections .append (detection )
231
- elif classification .classification == filter .negative_binary_label :
232
- non_moth_detections .append (detection )
233
- break
246
+ # Filter results based on positive_binary_label
247
+ moth_detections = []
248
+ non_moth_detections = []
249
+ for detection in filter .results :
250
+ for classification in detection .classifications :
251
+ if classification .classification == filter .positive_binary_label :
252
+ moth_detections .append (detection )
253
+ elif classification .classification == filter .negative_binary_label :
254
+ non_moth_detections .append (detection )
255
+ break
256
+ detections_for_terminal_classifier += moth_detections
257
+ detections_to_return += non_moth_detections
258
+
259
+ else :
260
+ logger .info ("Skipping binary classification filter" )
261
+ detections_for_terminal_classifier += detector_results
234
262
235
263
logger .info (
236
- f"Sending { len (moth_detections )} of { num_pre_filter } "
264
+ f"Sending { len (detections_for_terminal_classifier )} of { num_pre_filter } "
237
265
"detections to the classifier"
238
266
)
239
267
240
- Classifier = CLASSIFIER_CHOICES [str (data .pipeline )]
241
268
classifier : APIMothClassifier = Classifier (
242
269
source_images = source_images ,
243
- detections = moth_detections ,
270
+ detections = detections_for_terminal_classifier ,
244
271
batch_size = settings .classification_batch_size ,
245
272
num_workers = settings .num_workers ,
246
273
# single=True if len(filtered_detections) == 1 else False,
247
274
single = True , # @TODO solve issues with reading images in multiprocessing
248
275
example_config_param = data .config .example_config_param ,
249
276
terminal = True ,
277
+ # critera=data.config.criteria, # @TODO another approach to intermediate filter models
250
278
)
251
279
classifier .run ()
252
280
end_time = time .time ()
253
281
seconds_elapsed = float (end_time - start_time )
254
282
algorithms_used [classifier .get_key ()] = make_algorithm_response (classifier )
255
283
256
284
# Return all detections, including those that were not classified as moths
257
- all_detections = classifier .results + non_moth_detections
285
+ detections_to_return + = classifier .results
258
286
259
287
logger .info (
260
288
f"Processed { len (source_images )} images in { seconds_elapsed :.2f} seconds"
261
289
)
262
- logger .info (f"Returning { len (all_detections )} detections" )
290
+ logger .info (f"Returning { len (detections_to_return )} detections" )
263
291
# print(all_detections)
264
292
265
293
# If the number of detections is greater than 100, its suspicious. Log it.
266
- if len (all_detections ) > 100 :
294
+ if len (detections_to_return ) > 100 :
267
295
logger .warning (
268
- f"Detected { len (all_detections )} detections. "
296
+ f"Detected { len (detections_to_return )} detections. "
269
297
"This is suspicious and may contain duplicates."
270
298
)
271
299
272
300
response = PipelineResponse (
273
301
pipeline = data .pipeline ,
274
302
algorithms = algorithms_used ,
275
303
source_images = source_image_results ,
276
- detections = all_detections ,
304
+ detections = detections_to_return ,
277
305
total_time = seconds_elapsed ,
278
306
)
279
307
return response
0 commit comments