From 97a1cc2ead707fb818caf98a931f1c1cb44bd64e Mon Sep 17 00:00:00 2001
From: Michael Bunsen
Date: Tue, 26 Nov 2024 15:49:36 -0800
Subject: [PATCH] feat: add logits, always return all scores, assume calibrated

---
 trapdata/api/api.py                   |  8 ++--
 trapdata/api/models/classification.py | 53 +++++++++++++++++----------
 trapdata/api/schemas.py               |  1 +
 trapdata/api/tests/test_api.py        |  2 +-
 4 files changed, 40 insertions(+), 24 deletions(-)

diff --git a/trapdata/api/api.py b/trapdata/api/api.py
index 234a138..f247128 100644
--- a/trapdata/api/api.py
+++ b/trapdata/api/api.py
@@ -72,9 +72,9 @@ class PipelineConfig(pydantic.BaseModel):
     Configuration for the processing pipeline.
     """
 
-    max_predictions_per_classification: int | None = pydantic.Field(
+    example_config_param: int | None = pydantic.Field(
         default=None,
-        description="Number of predictions to return for each classification. If null/None, return all predictions.",
+        description="Example of a configuration parameter for a pipeline.",
         examples=[3],
     )
 
@@ -84,7 +84,7 @@ class PipelineRequest(pydantic.BaseModel):
     source_images: list[SourceImageRequest]
     config: PipelineConfig = pydantic.Field(
         default=PipelineConfig(),
-        examples=[PipelineConfig(max_predictions_per_classification=3)],
+        examples=[PipelineConfig(example_config_param=3)],
     )
 
     class Config:
@@ -179,7 +179,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
         num_workers=settings.num_workers,
         # single=True if len(filtered_detections) == 1 else False,
         single=True,  # @TODO solve issues with reading images in multiprocessing
-        top_n=data.config.max_predictions_per_classification,
+        example_config_param=data.config.example_config_param,
     )
     classifier.run()
     end_time = time.time()
diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py
index 6a2998c..3261a1e 100644
--- a/trapdata/api/models/classification.py
+++ b/trapdata/api/models/classification.py
@@ -30,14 +30,12 @@ def __init__(
         self,
         source_images: typing.Iterable[SourceImage],
         detections: typing.Iterable[Detection],
-        top_n: int | None = None,
         *args,
         **kwargs,
     ):
         self.source_images = source_images
         self.detections = list(detections)
         self.results: list[Detection] = []
-        self.top_n = top_n
         super().__init__(*args, **kwargs)
         logger.info(
             f"Initialized {self.__class__.__name__} with {len(self.detections)} detections"
@@ -51,25 +49,40 @@ def get_dataset(self):
             batch_size=self.batch_size,
         )
 
-    def post_process_batch(self, output):
-        predictions = torch.nn.functional.softmax(output, dim=1)
+    def post_process_batch(self, logits: torch.Tensor):
+        """
+        Return the labels, softmax/calibrated scores, and the original logits for each image in the batch.
+        """
+        predictions = torch.nn.functional.softmax(logits, dim=1)
         predictions = predictions.cpu().numpy()
-        # Ensure that top_n is not greater than the number of categories
-        # (e.g. binary classification will have only 2 categories)
-        if self.top_n is None:
-            num_results = predictions.shape[1]
-        else:
-            num_results = min(self.top_n, predictions.shape[1])
+        indices = np.tile(np.arange(predictions.shape[1]), (len(predictions), 1))
+
+        # @TODO Calibrate the scores here,
+        scores = predictions
 
-        indices = np.argpartition(predictions, -num_results, axis=1)[:, -num_results:]
-        scores = predictions[np.arange(predictions.shape[0])[:, None], indices]
         labels = np.array([[self.category_map[i] for i in row] for row in indices])
 
-        result = [list(zip(labels, scores)) for labels, scores in zip(labels, scores)]
-        result = [sorted(items, key=lambda x: x[1], reverse=True) for items in result]
-        logger.debug(f"Post-processing result batch: {result}")
-        return result
+        return [list(zip(l, s, g)) for l, s, g in zip(labels, scores, logits.tolist())]
+
+    def get_best_label(self, predictions):
+        """
+        Convenience method to get the best label from the predictions, which are a list of tuples
+        ordered by the model's class index, NOT by score.
+
+        This must not modify the predictions list!
+
+        predictions look like:
+        [
+            ('label1', score1, logit1),
+            ('label2', score2, logit2),
+            ...
+        ]
+        """
+
+        best_pred = max(predictions, key=lambda x: x[1])
+        best_label = best_pred[0]
+        return best_label
 
     def save_results(
         self, metadata, batch_output, seconds_per_item, *args, **kwargs
     ):
@@ -81,10 +94,12 @@
         ):
             detection = self.detections[detection_idx]
             assert detection.source_image_id == image_id
+            labels, scores, logits = zip(*predictions)
             classification = Classification(
-                classification=predictions[0][0],
-                labels=[label for (label, _) in list(predictions)],
-                scores=[score for (_, score) in list(predictions)],
+                classification=self.get_best_label(predictions),
+                labels=labels,  # @TODO move this to the Algorithm class instead of repeating it for every prediction
+                scores=scores,
+                logits=logits,
                 inference_time=seconds_per_item,
                 algorithm=self.name,
                 timestamp=datetime.datetime.now(),
diff --git a/trapdata/api/schemas.py b/trapdata/api/schemas.py
index fb176d7..878a85b 100644
--- a/trapdata/api/schemas.py
+++ b/trapdata/api/schemas.py
@@ -70,6 +70,7 @@ class Classification(pydantic.BaseModel):
     classification: str
     labels: list[str] = []
     scores: list[float] = []
+    logits: list[float] = []
     inference_time: float | None = None
     algorithm: str | None = None
     terminal: bool = True
diff --git a/trapdata/api/tests/test_api.py b/trapdata/api/tests/test_api.py
index 0c23c7c..a7f81e6 100644
--- a/trapdata/api/tests/test_api.py
+++ b/trapdata/api/tests/test_api.py
@@ -82,7 +82,7 @@ def test_config_num_classification_predictions(self):
 
         def _send_request(max_predictions_per_classification: int | None):
             config = PipelineConfig(
-                max_predictions_per_classification=max_predictions_per_classification
+                example_config_param=max_predictions_per_classification
             )
             pipeline_request = PipelineRequest(
                 pipeline=PipelineChoice[test_pipeline_slug],
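
Note (not part of the patch): post_process_batch now returns raw softmax
scores and leaves calibration as a @TODO. A minimal sketch of one common
calibration approach, temperature scaling, assuming a scalar temperature
fitted on held-out validation logits (the function name and the temperature
value below are hypothetical, not from this codebase):

    import torch

    def calibrated_scores(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Temperature-scaled softmax: divide the logits by a fitted
        # temperature before normalizing. temperature=1.0 reproduces the
        # plain softmax that the patch currently returns as `scores`.
        return torch.nn.functional.softmax(logits / temperature, dim=1)

    # Example: a batch of two 3-class logit rows.
    logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
    print(calibrated_scores(logits))                   # plain softmax scores
    print(calibrated_scores(logits, temperature=2.0))  # softened, calibrated scores

Because dividing by a positive scalar preserves each row's ordering, swapping
this in for the `scores = predictions` placeholder would not change the label
chosen by get_best_label.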