feat: add logits, always return all scores, assume calibrated
mihow committed Nov 26, 2024
1 parent: b5d70b4 · commit: 97a1cc2
Showing 4 changed files with 40 additions and 24 deletions.
8 changes: 4 additions & 4 deletions trapdata/api/api.py
@@ -72,9 +72,9 @@ class PipelineConfig(pydantic.BaseModel):
     Configuration for the processing pipeline.
     """
 
-    max_predictions_per_classification: int | None = pydantic.Field(
+    example_config_param: int | None = pydantic.Field(
         default=None,
-        description="Number of predictions to return for each classification. If null/None, return all predictions.",
+        description="Example of a configuration parameter for a pipeline.",
         examples=[3],
     )

@@ -84,7 +84,7 @@ class PipelineRequest(pydantic.BaseModel):
     source_images: list[SourceImageRequest]
     config: PipelineConfig = pydantic.Field(
         default=PipelineConfig(),
-        examples=[PipelineConfig(max_predictions_per_classification=3)],
+        examples=[PipelineConfig(example_config_param=3)],
     )
 
     class Config:
@@ -179,7 +179,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
         num_workers=settings.num_workers,
         # single=True if len(filtered_detections) == 1 else False,
         single=True,  # @TODO solve issues with reading images in multiprocessing
-        top_n=data.config.max_predictions_per_classification,
+        example_config_param=data.config.example_config_param,
     )
     classifier.run()
     end_time = time.time()
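For reference, a request using the renamed parameter would be built roughly like this (a sketch based on the models above; the pipeline slug and the SourceImageRequest fields are invented for illustration, since that model is not shown in this diff):

config = PipelineConfig(example_config_param=3)
request = PipelineRequest(
    pipeline=PipelineChoice["example_pipeline"],  # hypothetical slug
    source_images=[
        # SourceImageRequest's fields are assumed for this example
        SourceImageRequest(id="img-1", url="https://example.com/img-1.jpg"),
    ],
    config=config,
)

Because example_config_param is only a placeholder, the old top-n truncation is gone entirely: every classification now returns the full list of scores.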
53 changes: 34 additions & 19 deletions trapdata/api/models/classification.py
@@ -30,14 +30,12 @@ def __init__(
         self,
         source_images: typing.Iterable[SourceImage],
         detections: typing.Iterable[Detection],
-        top_n: int | None = None,
         *args,
         **kwargs,
     ):
         self.source_images = source_images
         self.detections = list(detections)
         self.results: list[Detection] = []
-        self.top_n = top_n
         super().__init__(*args, **kwargs)
         logger.info(
             f"Initialized {self.__class__.__name__} with {len(self.detections)} detections"
@@ -51,25 +49,40 @@ def get_dataset(self):
             batch_size=self.batch_size,
         )
 
-    def post_process_batch(self, output):
-        predictions = torch.nn.functional.softmax(output, dim=1)
+    def post_process_batch(self, logits: torch.Tensor):
+        """
+        Return the labels, softmax/calibrated scores, and the original logits for each image in the batch.
+        """
+        predictions = torch.nn.functional.softmax(logits, dim=1)
         predictions = predictions.cpu().numpy()
 
-        # Ensure that top_n is not greater than the number of categories
-        # (e.g. binary classification will have only 2 categories)
-        if self.top_n is None:
-            num_results = predictions.shape[1]
-        else:
-            num_results = min(self.top_n, predictions.shape[1])
+        indices = np.arange(predictions.shape[1])
+
+        # @TODO Calibrate the scores here,
+        scores = predictions
 
-        indices = np.argpartition(predictions, -num_results, axis=1)[:, -num_results:]
-        scores = predictions[np.arange(predictions.shape[0])[:, None], indices]
         labels = np.array([[self.category_map[i] for i in row] for row in indices])
 
-        result = [list(zip(labels, scores)) for labels, scores in zip(labels, scores)]
-        result = [sorted(items, key=lambda x: x[1], reverse=True) for items in result]
-        logger.debug(f"Post-processing result batch: {result}")
-        return result
+        return zip(labels, scores, logits)

def get_best_label(self, predictions):
"""
Convenience method to get the best label from the predictions, which are a list of tuples
in the order of the model's class index, NOT the values.
This must not modify the predictions list!
predictions look like:
[
('label1', score1, logit1),
('label2', score2, logit2),
...
]
"""

best_pred = max(predictions, key=lambda x: x[1])
best_label = best_pred[0]
return best_label

def save_results(
self, metadata, batch_output, seconds_per_item, *args, **kwargs
@@ -81,10 +94,12 @@ def save_results(
         ):
             detection = self.detections[detection_idx]
             assert detection.source_image_id == image_id
+            labels, scores, logits = zip(*predictions)
             classification = Classification(
-                classification=predictions[0][0],
-                labels=[label for (label, _) in list(predictions)],
-                scores=[score for (_, score) in list(predictions)],
+                classification=self.get_best_label(predictions),
+                labels=labels,  # @TODO move this to the Algorithm class instead of repeating it every prediction
+                scores=scores,
+                logits=logits,
                 inference_time=seconds_per_item,
                 algorithm=self.name,
                 timestamp=datetime.datetime.now(),
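Taken together, post-processing now works like this: softmax the logits, keep every score (assumed calibrated, no top-n cutoff), pair scores with labels in class-index order, and pick the best label only when saving results. A minimal runnable sketch, using an invented three-class category map and random logits:

import torch

# Invented category map for illustration; the real map comes from the model.
category_map = {0: "moth", 1: "non-moth", 2: "unknown"}

logits = torch.randn(2, 3)  # fake logits: batch of 2 images, 3 classes

# Softmax over the class dimension; scores are assumed calibrated.
scores = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

# One label per class index, in model order (not sorted by score).
labels = [category_map[i] for i in range(scores.shape[1])]

for image_scores, image_logits in zip(scores, logits):
    # Per-image tuples in class-index order: (label, score, logit)
    preds = list(zip(labels, image_scores, image_logits.tolist()))
    best_label = max(preds, key=lambda x: x[1])[0]  # what get_best_label does
    print(best_label, preds)

The tuples stay in class-index order; only the max scan in get_best_label looks at scores, so consumers can rely on a stable label order across predictions.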
1 change: 1 addition & 0 deletions trapdata/api/schemas.py
@@ -70,6 +70,7 @@ class Classification(pydantic.BaseModel):
     classification: str
     labels: list[str] = []
     scores: list[float] = []
+    logits: list[float] = []
     inference_time: float | None = None
     algorithm: str | None = None
     terminal: bool = True
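With the new field, a serialized classification carries the raw logits alongside the softmax scores. A sketch of one record, with invented values:

import datetime

classification = Classification(
    classification="moth",                   # best label
    labels=["moth", "non-moth", "unknown"],  # invented class labels
    scores=[0.93, 0.05, 0.02],               # softmax scores, assumed calibrated
    logits=[4.1, 1.2, 0.3],                  # raw model outputs before softmax
    inference_time=0.012,
    algorithm="example-classifier",
    timestamp=datetime.datetime.now(),
)

Keeping the logits presumably lets clients re-calibrate or re-normalize scores downstream without re-running inference.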
2 changes: 1 addition & 1 deletion trapdata/api/tests/test_api.py
@@ -82,7 +82,7 @@ def test_config_num_classification_predictions(self):
 
         def _send_request(max_predictions_per_classification: int | None):
             config = PipelineConfig(
-                max_predictions_per_classification=max_predictions_per_classification
+                example_config_param=max_predictions_per_classification
             )
             pipeline_request = PipelineRequest(
                 pipeline=PipelineChoice[test_pipeline_slug],