Skip to content

Commit 8e2c885

Browse files
authored
ML Pipeline v2 (#67)
* feat: allow all softmax scores to be returned * feat: test API requests in addition to underlying methods * fix: initial api tests working with test image server * test: ensure pipeline requests respect the classification_num_predictions config * feat: update parameter name and API examples * chore: update dependencies lock file * Try new test workflow * fix: run tests in importlib mode * feat: add logits, always return all scores, assume calibrated * fix: line-lengths * fix: most deprecation warnings that distracted from test output * fix: image server for tests * feat: update schema to return algorithms and all scores * feat: return category map data in API response * feat: compress response (getting long!) * fix: save non-moth pipeline addition for another day * fix: tests * feat: make port configurable * feat: update schemas and category map response * fix: don't print the massive output * feat: update schemas and endpoints to match new API * feat: return all algorithm & category map details in /info
1 parent 0f55d0b commit 8e2c885

21 files changed

+1886
-1089
lines changed

.github/workflows/test.yml

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
# This workflow will install Python dependencies, run tests and lint with a single version of Python
2-
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3-
41
name: Run tests
52

63
on:
@@ -13,20 +10,43 @@ permissions:
1310
contents: read
1411

1512
jobs:
16-
build:
13+
test:
14+
name: Run Python Tests
1715
runs-on: ubuntu-latest
1816

1917
steps:
20-
- uses: actions/checkout@v3
18+
- uses: actions/checkout@v4
19+
2120
- name: Set up Python 3.10
22-
uses: actions/setup-python@main # Need latest version to use pyproject.toml instead of requirements.txt
21+
uses: actions/setup-python@v5
2322
with:
2423
python-version: "3.10"
2524
cache: "pip"
25+
cache-dependency-path: |
26+
poetry.lock
27+
pyproject.toml
28+
29+
- name: Install Poetry
30+
uses: snok/install-poetry@v1
31+
with:
32+
version: latest
33+
virtualenvs-create: true
34+
virtualenvs-in-project: true
35+
36+
- name: Load cached Poetry virtualenv
37+
id: cached-poetry-dependencies
38+
uses: actions/cache@v3
39+
with:
40+
path: .venv
41+
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
42+
2643
- name: Install dependencies
27-
run: |
28-
python -m pip install --upgrade pip
29-
pip install .
44+
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
45+
run: poetry install --no-interaction
46+
3047
- name: Run tests
3148
run: |
32-
pytest
49+
# Clean any cached Python files before running tests
50+
find . -type d -name "__pycache__" -exec rm -r {} +
51+
find . -type f -name "*.pyc" -delete
52+
poetry run pytest --import-mode=importlib

poetry.lock

Lines changed: 1089 additions & 907 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

trapdata/api/api.py

Lines changed: 181 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""
2-
Fast API interface for processing images through the localization and classification pipelines.
2+
Fast API interface for processing images through the localization and classification
3+
pipelines.
34
"""
45

56
import enum
67
import time
78

89
import fastapi
910
import pydantic
10-
from rich import print
11+
from fastapi.middleware.gzip import GZipMiddleware
1112

1213
from ..common.logs import logger # noqa: F401
1314
from . import settings
@@ -23,70 +24,164 @@
2324
MothClassifierUKDenmark,
2425
)
2526
from .models.localization import APIMothDetector
26-
from .schemas import Detection, SourceImage
27+
from .schemas import (
28+
AlgorithmCategoryMapResponse,
29+
AlgorithmConfigResponse,
30+
PipelineConfigResponse,
31+
)
32+
from .schemas import PipelineRequest as PipelineRequest_
33+
from .schemas import PipelineResultsResponse as PipelineResponse_
34+
from .schemas import ProcessingServiceInfoResponse, SourceImage, SourceImageResponse
2735

2836
app = fastapi.FastAPI()
37+
app.add_middleware(GZipMiddleware)
2938

3039

31-
class SourceImageRequest(pydantic.BaseModel):
32-
model_config = pydantic.ConfigDict(extra="ignore")
33-
34-
# @TODO bring over new SourceImage & b64 validation from the lepsAI repo
35-
id: str
36-
url: str
37-
# b64: str | None = None
38-
39-
40-
class SourceImageResponse(pydantic.BaseModel):
41-
model_config = pydantic.ConfigDict(extra="ignore")
42-
43-
id: str
44-
url: str
45-
46-
47-
PIPELINE_CHOICES = {
40+
CLASSIFIER_CHOICES = {
4841
"panama_moths_2023": MothClassifierPanama,
4942
"panama_moths_2024": MothClassifierPanama2024,
5043
"quebec_vermont_moths_2023": MothClassifierQuebecVermont,
5144
"uk_denmark_moths_2023": MothClassifierUKDenmark,
5245
"costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
5346
"anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
5447
"global_moths_2024": MothClassifierGlobal,
48+
# "moth_binary": MothClassifierBinary,
5549
}
56-
_pipeline_choices = dict(zip(PIPELINE_CHOICES.keys(), list(PIPELINE_CHOICES.keys())))
50+
_classifier_choices = dict(
51+
zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
52+
)
53+
5754

55+
PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)
5856

59-
PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)
6057

58+
def make_category_map_response(
    model: APIMothDetector | APIMothClassifier,
    default_taxon_rank: str = "SPECIES",
) -> AlgorithmCategoryMapResponse:
    """
    Build the category map section of an algorithm description.

    The items of ``model.category_map`` are ordered by their key (unpacked here
    as the category index) and emitted both as a list of dicts and as a plain
    list of label strings in the same order.

    Args:
        model: Detector or classifier exposing ``category_map`` and
            ``labels_path``.
        default_taxon_rank: Value written to the ``taxon_rank`` key of every
            category entry.

    Returns:
        An ``AlgorithmCategoryMapResponse`` with ``data``, ``labels`` and ``uri``
        (taken from ``model.labels_path``).
    """
    category_entries = []
    label_strings = []
    # Single pass over the items, ordered by key, so both lists stay aligned.
    for index, label in sorted(model.category_map.items(), key=lambda item: item[0]):
        category_entries.append(
            {
                "index": index,
                "label": label,
                "taxon_rank": default_taxon_rank,
            }
        )
        label_strings.append(label)

    return AlgorithmCategoryMapResponse(
        data=category_entries,
        labels=label_strings,
        uri=model.labels_path,
    )
6178

62-
class PipelineRequest(pydantic.BaseModel):
63-
pipeline: PipelineChoice
64-
source_images: list[SourceImageRequest]
6579

80+
def make_algorithm_response(
81+
model: APIMothDetector | APIMothClassifier,
82+
) -> AlgorithmConfigResponse:
6683

67-
class PipelineResponse(pydantic.BaseModel):
68-
pipeline: PipelineChoice
69-
total_time: float
70-
source_images: list[SourceImageResponse]
71-
detections: list[Detection]
84+
category_map = make_category_map_response(model) if model.category_map else None
85+
return AlgorithmConfigResponse(
86+
name=model.name,
87+
key=model.get_key(),
88+
task_type=model.task_type,
89+
description=model.description,
90+
category_map=category_map,
91+
uri=model.weights_path,
92+
)
93+
94+
95+
def make_algorithm_config_response(
    model: APIMothDetector | APIMothClassifier,
) -> AlgorithmConfigResponse:
    """
    Describe a model as an ``AlgorithmConfigResponse`` for pipeline configs.

    This used to be a copy of ``make_algorithm_response`` that always called
    ``make_category_map_response``, which fails on ``.items()`` when the model
    has no category map. It now delegates to ``make_algorithm_response`` so both
    entry points share one implementation and the same guard: a falsy
    ``model.category_map`` yields ``category_map=None``.

    Args:
        model: Detector or classifier instance to describe.

    Returns:
        The ``AlgorithmConfigResponse`` produced by ``make_algorithm_response``.
    """
    return make_algorithm_response(model)
107+
108+
109+
def make_pipeline_config_response(
    Classifier: type[APIMothClassifier],
) -> PipelineConfigResponse:
    """
    Build the public description of one pipeline.

    A pipeline is described by three algorithms, in order: the detector, the
    binary classifier and the given classifier class. Each is instantiated with
    empty inputs only so its metadata can be read.

    Args:
        Classifier: Classifier class whose name, key and description identify
            the pipeline.

    Returns:
        A ``PipelineConfigResponse`` with ``version=1`` and the three algorithm
        configs.
    """
    no_images: list = []
    no_detections: list = []

    # Construction order is kept as detector -> binary classifier -> classifier.
    detector = APIMothDetector(source_images=no_images)
    binary_classifier = MothClassifierBinary(
        source_images=no_images,
        detections=no_detections,
    )
    classifier = Classifier(
        source_images=no_images,
        detections=no_detections,
        batch_size=settings.classification_batch_size,
        num_workers=settings.num_workers,
        terminal=True,  # NOTE(review): presumably marks the last stage - confirm
    )

    return PipelineConfigResponse(
        name=classifier.name,
        slug=classifier.get_key(),
        description=classifier.description,
        version=1,
        algorithms=[
            make_algorithm_config_response(stage)
            for stage in (detector, binary_classifier, classifier)
        ],
    )
140+
141+
142+
# One PipelineConfigResponse per registered classifier, built once at import time
# and served as-is by the /info endpoint.
# @TODO This requires loading all models into memory! Can we avoid this?
PIPELINE_CONFIGS = list(
    map(make_pipeline_config_response, CLASSIFIER_CHOICES.values())
)
147+
148+
149+
class PipelineRequest(PipelineRequest_):
    # Narrows the inherited `pipeline` field to the pipelines this service
    # offers. The description is reused from the base schema; the examples list
    # every registered pipeline key.
    # (A `#` comment is used instead of a docstring so the generated schema
    # description is not altered.)
    pipeline: PipelineChoice = pydantic.Field(
        description=PipelineRequest_.model_fields["pipeline"].description,
        examples=list(_classifier_choices),
    )
154+
155+
156+
class PipelineResponse(PipelineResponse_):
    # Narrows the inherited `pipeline` field to the pipelines this service
    # offers. The description is reused from the base schema; the examples list
    # every registered pipeline key.
    #
    # The first positional argument of `pydantic.Field` is the field *default*.
    # The enum class itself used to be passed there, which made the field
    # optional with an invalid default. The field is now required, matching
    # `PipelineRequest`.
    # (A `#` comment is used instead of a docstring so the generated schema
    # description is not altered.)
    pipeline: PipelineChoice = pydantic.Field(
        description=PipelineResponse_.model_fields["pipeline"].description,
        examples=list(_classifier_choices.keys()),
    )
72162

73163

74164
@app.get("/")
async def root():
    # Send visitors of the bare service URL to the interactive API docs.
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect
77167

78168

79-
@app.post("/pipeline/process")
80-
@app.post("/pipeline/process/")
169+
@app.post(
170+
"/pipeline/process/", deprecated=True, tags=["services"]
171+
) # old endpoint, deprecated, remove after jan 2025
172+
@app.post("/process/", tags=["services"]) # new endpoint
81173
async def process(data: PipelineRequest) -> PipelineResponse:
174+
algorithms_used: dict[str, AlgorithmConfigResponse] = {}
175+
82176
# Ensure that the source images are unique, filter out duplicates
83177
source_images_index = {
84178
source_image.id: source_image for source_image in data.source_images
85179
}
86180
incoming_source_images = list(source_images_index.values())
87181
if len(incoming_source_images) != len(data.source_images):
88182
logger.warning(
89-
f"Removed {len(data.source_images) - len(incoming_source_images)} duplicate source images"
183+
f"Removed {len(data.source_images) - len(incoming_source_images)} "
184+
"duplicate source images"
90185
)
91186

92187
source_image_results = [
@@ -106,6 +201,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
106201
)
107202
detector_results = detector.run()
108203
num_pre_filter = len(detector_results)
204+
algorithms_used[detector.get_key()] = make_algorithm_response(detector)
109205

110206
filter = MothClassifierBinary(
111207
source_images=source_images,
@@ -114,15 +210,15 @@ async def process(data: PipelineRequest) -> PipelineResponse:
114210
num_workers=settings.num_workers,
115211
# single=True if len(detector_results) == 1 else False,
116212
single=True, # @TODO solve issues with reading images in multiprocessing
117-
filter_results=False, # Only save results with the positive_binary_label, @TODO make this configurable from request
213+
terminal=False,
118214
)
119215
filter.run()
120-
# all_binary_classifications = filter.results
216+
algorithms_used[filter.get_key()] = make_algorithm_response(filter)
121217

122218
# Compare num detections with num moth detections
123219
num_post_filter = len(filter.results)
124220
logger.info(
125-
f"Binary classifier returned {num_post_filter} out of {num_pre_filter} detections"
221+
f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
126222
)
127223

128224
# Filter results based on positive_binary_label
@@ -137,21 +233,25 @@ async def process(data: PipelineRequest) -> PipelineResponse:
137233
break
138234

139235
logger.info(
140-
f"Sending {len(moth_detections)} out of {num_pre_filter} detections to the classifier"
236+
f"Sending {len(moth_detections)} of {num_pre_filter} "
237+
"detections to the classifier"
141238
)
142239

143-
Classifier = PIPELINE_CHOICES[data.pipeline.value]
240+
Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
144241
classifier: APIMothClassifier = Classifier(
145242
source_images=source_images,
146243
detections=moth_detections,
147244
batch_size=settings.classification_batch_size,
148245
num_workers=settings.num_workers,
149246
# single=True if len(filtered_detections) == 1 else False,
150247
single=True, # @TODO solve issues with reading images in multiprocessing
248+
example_config_param=data.config.example_config_param,
249+
terminal=True,
151250
)
152251
classifier.run()
153252
end_time = time.time()
154253
seconds_elapsed = float(end_time - start_time)
254+
algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)
155255

156256
# Return all detections, including those that were not classified as moths
157257
all_detections = classifier.results + non_moth_detections
@@ -160,23 +260,64 @@ async def process(data: PipelineRequest) -> PipelineResponse:
160260
f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
161261
)
162262
logger.info(f"Returning {len(all_detections)} detections")
163-
print(all_detections)
263+
# print(all_detections)
164264

165265
# If the number of detections is greater than 100, it's suspicious. Log it.
166266
if len(all_detections) > 100:
167267
logger.warning(
168-
f"Detected {len(all_detections)} detections. This is suspicious and may contain duplicates."
268+
f"Detected {len(all_detections)} detections. "
269+
"This is suspicious and may contain duplicates."
169270
)
170271

171272
response = PipelineResponse(
172273
pipeline=data.pipeline,
274+
algorithms=algorithms_used,
173275
source_images=source_image_results,
174276
detections=all_detections,
175277
total_time=seconds_elapsed,
176278
)
177279
return response
178280

179281

282+
@app.get("/info", tags=["services"])
async def info() -> ProcessingServiceInfoResponse:
    # Static service metadata plus the pipeline configs prepared at import time.
    service_description = (
        "The primary endpoint for processing images for the Antenna platform. "
        "This API provides access to multiple detection and classification "
        "algorithms by multiple labs for processing images of moths."
    )
    return ProcessingServiceInfoResponse(
        name="Antenna Inference API",
        description=service_description,
        pipelines=PIPELINE_CONFIGS,
    )
295+
296+
297+
# Liveness probe: answers 200 whenever the server process can handle a request.
@app.get("/livez", tags=["health checks"])
async def livez():
    alive_payload = {"status": True}
    return fastapi.responses.JSONResponse(content=alive_payload, status_code=200)
301+
302+
303+
# Readiness probe: reports which pipelines are registered
@app.get("/readyz", tags=["health checks"])
async def readyz():
    """
    Check if the server is ready to process data.

    Returns a list of pipeline slugs that are online and ready to process data.
    @TODO may need to simplify this to just return True/False. Pipeline algorithms will
    likely be loaded into memory on-demand when the pipeline is selected.
    """
    pipeline_slugs = list(_classifier_choices)
    # 200 with the slugs when at least one pipeline is registered, otherwise
    # 503 with an empty list.
    status_code = 200 if pipeline_slugs else 503
    return fastapi.responses.JSONResponse(
        status_code=status_code, content={"status": pipeline_slugs}
    )
319+
320+
180321
# Future methods
181322

182323
# batch processing

0 commit comments

Comments
 (0)