Skip to content

Commit

Permalink
feat: update parameter name and API examples
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow committed Nov 23, 2024
1 parent e4e6076 commit 455650e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
20 changes: 13 additions & 7 deletions trapdata/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ class SourceImageRequest(pydantic.BaseModel):
# @TODO bring over new SourceImage & b64 validation from the lepsAI repo
id: str = pydantic.Field(
description="Unique identifier for the source image. This is returned in the response.",
example="e124f3b4",
examples=["e124f3b4"],
)
url: str = pydantic.Field(
description="URL to the source image. This should be publicly accessible.",
example="https://static.dev.insectai.org/ami-trapdata/vermont/RawImages/LUNA/2022/movement/2022_06_23/20220623050407-00-235.jpg",
description="URL to the source image to be processed.",
examples=[
"https://static.dev.insectai.org/ami-trapdata/vermont/RawImages/LUNA/2022/movement/2022_06_23/20220623050407-00-235.jpg"
],
)
b64: str | None = None
# b64: str | None = None


class SourceImageResponse(pydantic.BaseModel):
Expand Down Expand Up @@ -70,16 +72,20 @@ class PipelineConfig(pydantic.BaseModel):
Configuration for the processing pipeline.
"""

classification_num_predictions: int | None = pydantic.Field(
max_predictions_per_classification: int | None = pydantic.Field(
default=None,
description="Number of predictions to return for each classification. If null/None, return all predictions.",
examples=[3],
)


class PipelineRequest(pydantic.BaseModel):
pipeline: PipelineChoice
source_images: list[SourceImageRequest]
config: PipelineConfig = PipelineConfig()
config: PipelineConfig = pydantic.Field(
default=PipelineConfig(),
examples=[PipelineConfig(max_predictions_per_classification=3)],
)

class Config:
use_enum_values = True
Expand Down Expand Up @@ -173,7 +179,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
num_workers=settings.num_workers,
# single=True if len(filtered_detections) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
top_n=data.config.classification_num_predictions,
top_n=data.config.max_predictions_per_classification,
)
classifier.run()
end_time = time.time()
Expand Down
28 changes: 16 additions & 12 deletions trapdata/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def get_test_images(self, subdir: str = "vermont", num: int = 2):
]
return source_images

def get_test_pipeline(self, slug: str = "quebec_vermont_moths_2023"):
def get_test_pipeline(
self, slug: str = "quebec_vermont_moths_2023"
) -> SpeciesClassifier:
pipeline = PIPELINE_CHOICES[slug]
return pipeline

Expand All @@ -69,20 +71,18 @@ def test_pipeline_request(self):

def test_config_num_classification_predictions(self):
"""
Test that the pipeline respects the `classification_num_predictions` configuration.
Test that the pipeline respects the `max_predictions_per_classification` configuration.
If the configuration is set to a number, the pipeline should return that number of labels/scores per prediction.
If the configuration is set to `None`, the pipeline should return all labels/scores per prediction.
"""
test_images = self.get_test_images(num=1)
test_pipeline_slug = "quebec_vermont_moths_2023"
terminal_classifier: SpeciesClassifier = self.get_test_pipeline(
test_pipeline_slug
)
terminal_classifier = self.get_test_pipeline(test_pipeline_slug)

def _send_request(classification_num_predictions: int | None):
def _send_request(max_predictions_per_classification: int | None):
config = PipelineConfig(
classification_num_predictions=classification_num_predictions
max_predictions_per_classification=max_predictions_per_classification
)
pipeline_request = PipelineRequest(
pipeline=PipelineChoice[test_pipeline_slug],
Expand All @@ -102,15 +102,19 @@ def _send_request(classification_num_predictions: int | None):
if classification.terminal
]
for classification in terminal_classifications:
if classification_num_predictions is None:
if max_predictions_per_classification is None:
# Ensure that a score is returned for every possible class
assert len(classification.labels) == terminal_classifier.num_classes
assert len(classification.scores) == terminal_classifier.num_classes
else:
# Ensure that the number of predictions is limited to the number specified
# There may be fewer predictions than the number specified if there are fewer classes.
assert len(classification.labels) <= classification_num_predictions
assert len(classification.scores) <= classification_num_predictions
assert (
len(classification.labels) <= max_predictions_per_classification
)
assert (
len(classification.scores) <= max_predictions_per_classification
)

_send_request(classification_num_predictions=1)
_send_request(classification_num_predictions=None)
_send_request(max_predictions_per_classification=1)
_send_request(max_predictions_per_classification=None)

0 comments on commit 455650e

Please sign in to comment.