Skip to content

Commit e4e6076

Browse files
committed
test: ensure pipeline requests respect the classification_num_predictions config
1 parent d9d9a7f commit e4e6076

File tree

1 file changed

+74
-16
lines changed

1 file changed

+74
-16
lines changed

trapdata/api/tests/test_api.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi.testclient import TestClient
66

77
from trapdata.api.api import (
8+
PIPELINE_CHOICES,
89
PipelineChoice,
910
PipelineConfig,
1011
PipelineRequest,
@@ -13,6 +14,7 @@
1314
app,
1415
)
1516
from trapdata.api.tests.image_server import StaticFileTestServer
17+
from trapdata.ml.models.classification import SpeciesClassifier
1618
from trapdata.tests import TEST_IMAGES_BASE_PATH
1719

1820
logging.basicConfig(level=logging.INFO)
@@ -22,7 +24,7 @@
2224
class TestInferenceAPI(TestCase):
2325
@classmethod
2426
def setUpClass(cls):
25-
cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) / "vermont"
27+
cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH)
2628
if not cls.test_images_dir.exists():
2729
raise FileNotFoundError(
2830
f"Test images directory not found: {cls.test_images_dir}"
@@ -34,25 +36,81 @@ def setUpClass(cls):
3436
def setUp(self):
3537
self.file_server = StaticFileTestServer(self.test_images_dir)
3638

39+
def get_test_images(self, subdir: str = "vermont", num: int = 2):
40+
images_dir = self.test_images_dir / subdir
41+
source_image_urls = [
42+
self.file_server.get_url(f.relative_to(images_dir))
43+
for f in self.test_images_dir.glob("*.jpg")
44+
][:num]
45+
source_images = [
46+
SourceImageRequest(id=str(i), url=url)
47+
for i, url in enumerate(source_image_urls)
48+
]
49+
return source_images
50+
51+
def get_test_pipeline(self, slug: str = "quebec_vermont_moths_2023"):
52+
pipeline = PIPELINE_CHOICES[slug]
53+
return pipeline
54+
3755
def test_pipeline_request(self):
56+
"""
57+
Ensure that the pipeline accepts a valid request and returns a valid response.
58+
"""
59+
pipeline_request = PipelineRequest(
60+
pipeline=PipelineChoice["quebec_vermont_moths_2023"],
61+
source_images=self.get_test_images(num=2),
62+
)
3863
with self.file_server:
39-
num_images = 2
40-
source_image_urls = [
41-
self.file_server.get_url(f.relative_to(self.test_images_dir))
42-
for f in self.test_images_dir.glob("*.jpg")
43-
][:num_images]
44-
source_images = [
45-
SourceImageRequest(id=str(i), url=url)
46-
for i, url in enumerate(source_image_urls)
47-
]
48-
pipeline_request = PipelineRequest(
49-
pipeline=PipelineChoice["quebec_vermont_moths_2023"],
50-
source_images=source_images,
51-
config=PipelineConfig(classification_num_predictions=1),
52-
)
5364
response = self.client.post(
5465
"/pipeline/process", json=pipeline_request.dict()
5566
)
5667
assert response.status_code == 200
68+
PipelineResponse(**response.json())
69+
70+
def test_config_num_classification_predictions(self):
71+
"""
72+
Test that the pipeline respects the `classification_num_predictions` configuration.
73+
74+
If the configuration is set to a number, the pipeline should return that number of labels/scores per prediction.
75+
If the configuration is set to `None`, the pipeline should return all labels/scores per prediction.
76+
"""
77+
test_images = self.get_test_images(num=1)
78+
test_pipeline_slug = "quebec_vermont_moths_2023"
79+
terminal_classifier: SpeciesClassifier = self.get_test_pipeline(
80+
test_pipeline_slug
81+
)
82+
83+
def _send_request(classification_num_predictions: int | None):
84+
config = PipelineConfig(
85+
classification_num_predictions=classification_num_predictions
86+
)
87+
pipeline_request = PipelineRequest(
88+
pipeline=PipelineChoice[test_pipeline_slug],
89+
source_images=test_images,
90+
config=config,
91+
)
92+
with self.file_server:
93+
response = self.client.post(
94+
"/pipeline/process", json=pipeline_request.dict()
95+
)
96+
assert response.status_code == 200
5797
pipeline_response = PipelineResponse(**response.json())
58-
assert len(pipeline_response.detections) > 0
98+
terminal_classifications = [
99+
classification
100+
for detection in pipeline_response.detections
101+
for classification in detection.classifications
102+
if classification.terminal
103+
]
104+
for classification in terminal_classifications:
105+
if classification_num_predictions is None:
106+
# Ensure that a score is returned for every possible class
107+
assert len(classification.labels) == terminal_classifier.num_classes
108+
assert len(classification.scores) == terminal_classifier.num_classes
109+
else:
110+
# Ensure that the number of predictions is limited to the number specified
111+
# There may be fewer predictions than the number specified if there are fewer classes.
112+
assert len(classification.labels) <= classification_num_predictions
113+
assert len(classification.scores) <= classification_num_predictions
114+
115+
_send_request(classification_num_predictions=1)
116+
_send_request(classification_num_predictions=None)

0 commit comments

Comments
 (0)