5
5
from fastapi .testclient import TestClient
6
6
7
7
from trapdata .api .api import (
8
+ PIPELINE_CHOICES ,
8
9
PipelineChoice ,
9
10
PipelineConfig ,
10
11
PipelineRequest ,
13
14
app ,
14
15
)
15
16
from trapdata .api .tests .image_server import StaticFileTestServer
17
+ from trapdata .ml .models .classification import SpeciesClassifier
16
18
from trapdata .tests import TEST_IMAGES_BASE_PATH
17
19
18
20
logging .basicConfig (level = logging .INFO )
22
24
class TestInferenceAPI (TestCase ):
23
25
@classmethod
24
26
def setUpClass (cls ):
25
- cls .test_images_dir = pathlib .Path (TEST_IMAGES_BASE_PATH ) / "vermont"
27
+ cls .test_images_dir = pathlib .Path (TEST_IMAGES_BASE_PATH )
26
28
if not cls .test_images_dir .exists ():
27
29
raise FileNotFoundError (
28
30
f"Test images directory not found: { cls .test_images_dir } "
@@ -34,25 +36,81 @@ def setUpClass(cls):
34
36
def setUp (self ):
35
37
self .file_server = StaticFileTestServer (self .test_images_dir )
36
38
39
+ def get_test_images (self , subdir : str = "vermont" , num : int = 2 ):
40
+ images_dir = self .test_images_dir / subdir
41
+ source_image_urls = [
42
+ self .file_server .get_url (f .relative_to (images_dir ))
43
+ for f in self .test_images_dir .glob ("*.jpg" )
44
+ ][:num ]
45
+ source_images = [
46
+ SourceImageRequest (id = str (i ), url = url )
47
+ for i , url in enumerate (source_image_urls )
48
+ ]
49
+ return source_images
50
+
51
+ def get_test_pipeline (self , slug : str = "quebec_vermont_moths_2023" ):
52
+ pipeline = PIPELINE_CHOICES [slug ]
53
+ return pipeline
54
+
37
55
def test_pipeline_request (self ):
56
+ """
57
+ Ensure that the pipeline accepts a valid request and returns a valid response.
58
+ """
59
+ pipeline_request = PipelineRequest (
60
+ pipeline = PipelineChoice ["quebec_vermont_moths_2023" ],
61
+ source_images = self .get_test_images (num = 2 ),
62
+ )
38
63
with self .file_server :
39
- num_images = 2
40
- source_image_urls = [
41
- self .file_server .get_url (f .relative_to (self .test_images_dir ))
42
- for f in self .test_images_dir .glob ("*.jpg" )
43
- ][:num_images ]
44
- source_images = [
45
- SourceImageRequest (id = str (i ), url = url )
46
- for i , url in enumerate (source_image_urls )
47
- ]
48
- pipeline_request = PipelineRequest (
49
- pipeline = PipelineChoice ["quebec_vermont_moths_2023" ],
50
- source_images = source_images ,
51
- config = PipelineConfig (classification_num_predictions = 1 ),
52
- )
53
64
response = self .client .post (
54
65
"/pipeline/process" , json = pipeline_request .dict ()
55
66
)
56
67
assert response .status_code == 200
68
+ PipelineResponse (** response .json ())
69
+
70
+ def test_config_num_classification_predictions (self ):
71
+ """
72
+ Test that the pipeline respects the `classification_num_predictions` configuration.
73
+
74
+ If the configuration is set to a number, the pipeline should return that number of labels/scores per prediction.
75
+ If the configuration is set to `None`, the pipeline should return all labels/scores per prediction.
76
+ """
77
+ test_images = self .get_test_images (num = 1 )
78
+ test_pipeline_slug = "quebec_vermont_moths_2023"
79
+ terminal_classifier : SpeciesClassifier = self .get_test_pipeline (
80
+ test_pipeline_slug
81
+ )
82
+
83
+ def _send_request (classification_num_predictions : int | None ):
84
+ config = PipelineConfig (
85
+ classification_num_predictions = classification_num_predictions
86
+ )
87
+ pipeline_request = PipelineRequest (
88
+ pipeline = PipelineChoice [test_pipeline_slug ],
89
+ source_images = test_images ,
90
+ config = config ,
91
+ )
92
+ with self .file_server :
93
+ response = self .client .post (
94
+ "/pipeline/process" , json = pipeline_request .dict ()
95
+ )
96
+ assert response .status_code == 200
57
97
pipeline_response = PipelineResponse (** response .json ())
58
- assert len (pipeline_response .detections ) > 0
98
+ terminal_classifications = [
99
+ classification
100
+ for detection in pipeline_response .detections
101
+ for classification in detection .classifications
102
+ if classification .terminal
103
+ ]
104
+ for classification in terminal_classifications :
105
+ if classification_num_predictions is None :
106
+ # Ensure that a score is returned for every possible class
107
+ assert len (classification .labels ) == terminal_classifier .num_classes
108
+ assert len (classification .scores ) == terminal_classifier .num_classes
109
+ else :
110
+ # Ensure that the number of predictions is limited to the number specified
111
+ # There may be fewer predictions than the number specified if there are fewer classes.
112
+ assert len (classification .labels ) <= classification_num_predictions
113
+ assert len (classification .scores ) <= classification_num_predictions
114
+
115
+ _send_request (classification_num_predictions = 1 )
116
+ _send_request (classification_num_predictions = None )
0 commit comments