"""
- FastAPI interface for processing images through the localization and classification pipelines.
+ FastAPI interface for processing images through the localization and classification
+ pipelines.
"""

import enum
import time

import fastapi
import pydantic
- from rich import print
+ from fastapi.middleware.gzip import GZipMiddleware

from ..common.logs import logger  # noqa: F401
from . import settings
@@ … @@
    MothClassifierUKDenmark,
)
from .models.localization import APIMothDetector
- from .schemas import Detection, SourceImage
+ from .schemas import (
+     AlgorithmCategoryMapResponse,
+     AlgorithmConfigResponse,
+     PipelineConfigResponse,
+ )
+ from .schemas import PipelineRequest as PipelineRequest_
+ from .schemas import PipelineResultsResponse as PipelineResponse_
+ from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse

app = fastapi.FastAPI()
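+ # Compress large JSON responses (pipeline results can contain many detections).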
+ app.add_middleware(GZipMiddleware)


- class SourceImageRequest(pydantic.BaseModel):
-     model_config = pydantic.ConfigDict(extra="ignore")
-
-     # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
-     id: str
-     url: str
-     # b64: str | None = None
-
-
- class SourceImageResponse(pydantic.BaseModel):
-     model_config = pydantic.ConfigDict(extra="ignore")
-
-     id: str
-     url: str
-
-
- PIPELINE_CHOICES = {
+ CLASSIFIER_CHOICES = {
    "panama_moths_2023": MothClassifierPanama,
    "panama_moths_2024": MothClassifierPanama2024,
    "quebec_vermont_moths_2023": MothClassifierQuebecVermont,
    "uk_denmark_moths_2023": MothClassifierUKDenmark,
    "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
    "anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
    "global_moths_2024": MothClassifierGlobal,
+     # "moth_binary": MothClassifierBinary,
}
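+ # Map each slug to itself so the enum below is keyed and valued by the same
+ # human-readable slug string.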
- _pipeline_choices = dict(zip(PIPELINE_CHOICES.keys(), list(PIPELINE_CHOICES.keys())))
+ _classifier_choices = dict(
+     zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
+ )
+

+ PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)

- PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)


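+ # Summarize a model's category map for API responses; assumes every label is
+ # reported at a single default taxon rank.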
+ def make_category_map_response(
+     model: APIMothDetector | APIMothClassifier,
+     default_taxon_rank: str = "SPECIES",
+ ) -> AlgorithmCategoryMapResponse:
+     categories_sorted_by_index = sorted(model.category_map.items(), key=lambda x: x[0])
+     # as list of dicts:
+     categories_sorted_by_index = [
+         {
+             "index": index,
+             "label": label,
+             "taxon_rank": default_taxon_rank,
+         }
+         for index, label in categories_sorted_by_index
+     ]
+     label_strings_sorted_by_index = [cat["label"] for cat in categories_sorted_by_index]
+     return AlgorithmCategoryMapResponse(
+         data=categories_sorted_by_index,
+         labels=label_strings_sorted_by_index,
+         uri=model.labels_path,
+     )
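+ # e.g. data=[{"index": 0, "label": "<species>", "taxon_rank": "SPECIES"}, ...]
+ # with labels=["<species>", ...] in the same index order.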


- class PipelineRequest(pydantic.BaseModel):
-     pipeline: PipelineChoice
-     source_images: list[SourceImageRequest]
-
+
+ def make_algorithm_response(
+     model: APIMothDetector | APIMothClassifier,
+ ) -> AlgorithmConfigResponse:

- class PipelineResponse(pydantic.BaseModel):
-     pipeline: PipelineChoice
-     total_time: float
-     source_images: list[SourceImageResponse]
-     detections: list[Detection]
+     category_map = make_category_map_response(model) if model.category_map else None
+     return AlgorithmConfigResponse(
+         name=model.name,
+         key=model.get_key(),
+         task_type=model.task_type,
+         description=model.description,
+         category_map=category_map,
+         uri=model.weights_path,
+     )
+
+
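+ # Same as make_algorithm_response, but assumes the model always has a category map.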
+ def make_algorithm_config_response(
+     model: APIMothDetector | APIMothClassifier,
+ ) -> AlgorithmConfigResponse:
+     category_map = make_category_map_response(model)
+     return AlgorithmConfigResponse(
+         name=model.name,
+         key=model.get_key(),
+         task_type=model.task_type,
+         description=model.description,
+         category_map=category_map,
+         uri=model.weights_path,
+     )
+
+
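+ # Describe one full pipeline (detector, binary filter, species classifier) for
+ # the /info endpoint; the models are instantiated with empty inputs only to
+ # read their metadata.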
+ def make_pipeline_config_response(
+     Classifier: type[APIMothClassifier],
+ ) -> PipelineConfigResponse:
+     detector = APIMothDetector(
+         source_images=[],
+     )
+
+     binary_classifier = MothClassifierBinary(
+         source_images=[],
+         detections=[],
+     )
+
+     classifier = Classifier(
+         source_images=[],
+         detections=[],
+         batch_size=settings.classification_batch_size,
+         num_workers=settings.num_workers,
+         terminal=True,
+     )
+
+     return PipelineConfigResponse(
+         name=classifier.name,
+         slug=classifier.get_key(),
+         description=classifier.description,
+         version=1,
+         algorithms=[
+             make_algorithm_config_response(detector),
+             make_algorithm_config_response(binary_classifier),
+             make_algorithm_config_response(classifier),
+         ],
+     )
+
+
+ # @TODO This requires loading all models into memory! Can we avoid this?
+ PIPELINE_CONFIGS = [
+     make_pipeline_config_response(classifier_class)
+     for classifier_class in CLASSIFIER_CHOICES.values()
+ ]
+
+
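+ # Subclass the shared request/response schemas so the OpenAPI docs show this
+ # service's own pipeline slugs as examples.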
+ class PipelineRequest(PipelineRequest_):
+     pipeline: PipelineChoice = pydantic.Field(
+         description=PipelineRequest_.model_fields["pipeline"].description,
+         examples=list(_classifier_choices.keys()),
+     )
+
+
+ class PipelineResponse(PipelineResponse_):
+     pipeline: PipelineChoice = pydantic.Field(
+         description=PipelineResponse_.model_fields["pipeline"].description,
+         examples=list(_classifier_choices.keys()),
+     )


@app.get("/")
async def root():
    return fastapi.responses.RedirectResponse("/docs")


- @app.post("/pipeline/process")
- @app.post("/pipeline/process/")
+ @app.post(
+     "/pipeline/process/", deprecated=True, tags=["services"]
+ )  # old endpoint, deprecated, remove after Jan 2025
+ @app.post("/process/", tags=["services"])  # new endpoint
async def process(data: PipelineRequest) -> PipelineResponse:
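+     """
+     Run the full pipeline: localization, binary filtering, then species
+     classification.
+
+     A sketch of a request body, assuming the schemas module keeps the
+     id/url source-image fields used by the previous inline schema:
+     {"pipeline": "panama_moths_2023",
+      "source_images": [{"id": "1", "url": "https://example.com/trap.jpg"}]}
+     """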
+     algorithms_used: dict[str, AlgorithmConfigResponse] = {}
+
    # Ensure that the source images are unique, filter out duplicates
    source_images_index = {
        source_image.id: source_image for source_image in data.source_images
    }
    incoming_source_images = list(source_images_index.values())
    if len(incoming_source_images) != len(data.source_images):
        logger.warning(
-             f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
+             f"Removed {len(data.source_images) - len(incoming_source_images)} "
+             "duplicate source images"
        )

    source_image_results = [
@@ -106,6 +201,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
    )
    detector_results = detector.run()
    num_pre_filter = len(detector_results)
+     algorithms_used[detector.get_key()] = make_algorithm_response(detector)

    filter = MothClassifierBinary(
        source_images=source_images,
@@ -114,15 +210,15 @@ async def process(data: PipelineRequest) -> PipelineResponse:
        num_workers=settings.num_workers,
        # single=True if len(detector_results) == 1 else False,
        single=True,  # @TODO solve issues with reading images in multiprocessing
-         filter_results=False,  # Only save results with the positive_binary_label, @TODO make this configurable from request
+         terminal=False,
    )
    filter.run()
-     # all_binary_classifications = filter.results
+     algorithms_used[filter.get_key()] = make_algorithm_response(filter)

    # Compare num detections with num moth detections
    num_post_filter = len(filter.results)
    logger.info(
-         f"Binary classifier returned {num_post_filter} out of {num_pre_filter} detections"
+         f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
    )

    # Filter results based on positive_binary_label
@@ -137,21 +233,25 @@ async def process(data: PipelineRequest) -> PipelineResponse:
            break

    logger.info(
-         f"Sending {len(moth_detections)} out of {num_pre_filter} detections to the classifier"
+         f"Sending {len(moth_detections)} of {num_pre_filter} "
+         "detections to the classifier"
    )

-     Classifier = PIPELINE_CHOICES[data.pipeline.value]
+     Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
    classifier: APIMothClassifier = Classifier(
        source_images=source_images,
        detections=moth_detections,
        batch_size=settings.classification_batch_size,
        num_workers=settings.num_workers,
        # single=True if len(filtered_detections) == 1 else False,
        single=True,  # @TODO solve issues with reading images in multiprocessing
+         example_config_param=data.config.example_config_param,
+         terminal=True,
    )
    classifier.run()
    end_time = time.time()
    seconds_elapsed = float(end_time - start_time)
+     algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)

    # Return all detections, including those that were not classified as moths
    all_detections = classifier.results + non_moth_detections
@@ -160,23 +260,64 @@ async def process(data: PipelineRequest) -> PipelineResponse:
        f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
    )
    logger.info(f"Returning {len(all_detections)} detections")
-     print(all_detections)
+     # print(all_detections)

    # If the number of detections is greater than 100, it's suspicious. Log it.
    if len(all_detections) > 100:
        logger.warning(
-             f"Detected {len(all_detections)} detections. This is suspicious and may contain duplicates."
+             f"Detected {len(all_detections)} detections. "
+             "This is suspicious and may contain duplicates."
        )

    response = PipelineResponse(
        pipeline=data.pipeline,
+         algorithms=algorithms_used,
        source_images=source_image_results,
        detections=all_detections,
        total_time=seconds_elapsed,
    )
    return response


+
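+ # Service discovery endpoint: describes this service and the pipelines it hosts
+ # so that clients (e.g. the Antenna platform) can register them.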
+ @app .get ("/info" , tags = ["services" ])
283
+ async def info () -> ProcessingServiceInfoResponse :
284
+ info = ProcessingServiceInfoResponse (
285
+ name = "Antenna Inference API" ,
286
+ description = (
287
+ "The primary endpoint for processing images for the Antenna platform. "
288
+ "This API provides access to multiple detection and classification "
289
+ "algorithms by multiple labs for processing images of moths."
290
+ ),
291
+ pipelines = PIPELINE_CONFIGS ,
292
+ # algorithms=list(algorithm_choices.values()),
293
+ )
294
+ return info
+
+
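+ # The health-check endpoints follow the Kubernetes convention: /livez reports
+ # process liveness, /readyz reports readiness (which pipelines can accept work).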
+ # Check if the server is online
+ @app.get("/livez", tags=["health checks"])
+ async def livez():
+     return fastapi.responses.JSONResponse(status_code=200, content={"status": True})
+
+
+ # Check if the pipelines are ready to process data
+ @app.get("/readyz", tags=["health checks"])
+ async def readyz():
+     """
+     Check if the server is ready to process data.
+
+     Returns a list of pipeline slugs that are online and ready to process data.
+     @TODO may need to simplify this to just return True/False. Pipeline algorithms
+     will likely be loaded into memory on-demand when the pipeline is selected.
+     """
+     if _classifier_choices:
+         return fastapi.responses.JSONResponse(
+             status_code=200, content={"status": list(_classifier_choices.keys())}
+         )
+     else:
+         return fastapi.responses.JSONResponse(status_code=503, content={"status": []})
+
+
# Future methods

# batch processing