feat: update schema to return algorithms and all scores
mihow committed Dec 6, 2024
1 parent 8cc7999 commit 65237cc
Showing 5 changed files with 107 additions and 140 deletions.
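In short, this commit adds an algorithms map to the pipeline response describing every model used in a run, and classification results now carry scores for every class along with the raw logits. A minimal sketch of what a response might look like under the new schema (field names follow the diffs below; the algorithm key and values are hypothetical stand-ins):

    # Hedged sketch of a PipelineResponse after this change; not the
    # authoritative serialization.
    example_response = {
        "pipeline": "global_moths_2024",
        "total_time": 4.2,
        "algorithms": {
            # keyed by each model's get_key(); this key is a stand-in
            "moth_detector": {
                "name": "Moth Object Detector",
                "key": "moth_detector",
                "task_type": "detection",
                "description": "Finds and crops moths in trap images.",
            },
        },
        "source_images": [{"id": "e124f3b4", "url": "https://example.com/trap.jpg"}],
        "detections": [],
    }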
87 changes: 30 additions & 57 deletions trapdata/api/api.py
@@ -24,38 +24,14 @@
MothClassifierUKDenmark,
)
from .models.localization import APIMothDetector
from .schemas import DetectionResponse, SourceImage
from .schemas import AlgorithmResponse
from .schemas import PipelineRequest as PipelineRequest_
from .schemas import PipelineResponse as PipelineResponse_
from .schemas import SourceImage, SourceImageResponse

app = fastapi.FastAPI()


class SourceImageRequest(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="ignore")

# @TODO bring over new SourceImage & b64 validation from the lepsAI repo
id: str = pydantic.Field(
description=(
"Unique identifier for the source image. This is returned in the response."
),
examples=["e124f3b4"],
)
url: str = pydantic.Field(
description="URL to the source image to be processed.",
examples=[
"https://static.dev.insectai.org/ami-trapdata/"
"vermont/RawImages/LUNA/2022/movement/2022_06_23/20220623050407-00-235.jpg"
],
)
# b64: str | None = None


class SourceImageResponse(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="ignore")

id: str
url: str


PIPELINE_CHOICES = {
"panama_moths_2023": MothClassifierPanama,
"panama_moths_2024": MothClassifierPanama2024,
@@ -64,44 +40,39 @@ class SourceImageResponse(pydantic.BaseModel):
"costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
"anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
"global_moths_2024": MothClassifierGlobal,
"moth_binary": MothClassifierBinary,
}
_pipeline_choices = dict(zip(PIPELINE_CHOICES.keys(), list(PIPELINE_CHOICES.keys())))


PipelineChoice = enum.Enum("PipelineChoice", _pipeline_choices)


class PipelineConfig(pydantic.BaseModel):
"""
Configuration for the processing pipeline.
"""

example_config_param: int | None = pydantic.Field(
default=None,
description="Example of a configuration parameter for a pipeline.",
examples=[3],
def make_algorithm_response(
Model: type[APIMothDetector] | type[APIMothClassifier],
) -> AlgorithmResponse:
return AlgorithmResponse(
name=Model.name,
key=Model.get_key(),
task_type=Model.type,
description=Model.description,
)


class PipelineRequest(pydantic.BaseModel):
model_config = pydantic.ConfigDict(use_enum_values=True)

pipeline: PipelineChoice
source_images: list[SourceImageRequest]
config: PipelineConfig = pydantic.Field(
default=PipelineConfig(),
examples=[PipelineConfig(example_config_param=3)],
class PipelineRequest(PipelineRequest_):
pipeline: PipelineChoice = pydantic.Field(
PipelineChoice,
description=PipelineRequest_.model_fields["pipeline"].description,
examples=list(_pipeline_choices.keys()),
)


class PipelineResponse(pydantic.BaseModel):
model_config = pydantic.ConfigDict(use_enum_values=True)

pipeline: PipelineChoice
total_time: float
source_images: list[SourceImageResponse]
detections: list[DetectionResponse]
config: PipelineConfig = PipelineConfig()
class PipelineResponse(PipelineResponse_):
pipeline: PipelineChoice = pydantic.Field(
PipelineChoice,
description=PipelineResponse_.model_fields["pipeline"].description,
examples=list(_pipeline_choices.keys()),
)


@app.get("/")
Expand All @@ -112,6 +83,8 @@ async def root():
@app.post("/pipeline/process")
@app.post("/pipeline/process/")
async def process(data: PipelineRequest) -> PipelineResponse:
algorithms_used: dict[str, AlgorithmResponse] = {}

# Ensure that the source images are unique, filter out duplicates
source_images_index = {
source_image.id: source_image for source_image in data.source_images
@@ -140,6 +113,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
)
detector_results = detector.run()
num_pre_filter = len(detector_results)
algorithms_used[detector.get_key()] = make_algorithm_response(APIMothDetector)

filter = MothClassifierBinary(
source_images=source_images,
@@ -148,12 +122,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
num_workers=settings.num_workers,
# single=True if len(detector_results) == 1 else False,
single=True, # @TODO solve issues with reading images in multiprocessing
# Only save results with the positive_binary_label,
# @TODO make this configurable from request
filter_results=False,
)
filter.run()
# all_binary_classifications = filter.results
algorithms_used[filter.get_key()] = make_algorithm_response(MothClassifierBinary)

# Compare num detections with num moth detections
num_post_filter = len(filter.results)
@@ -190,6 +161,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:
classifier.run()
end_time = time.time()
seconds_elapsed = float(end_time - start_time)
algorithms_used[classifier.get_key()] = make_algorithm_response(Classifier)

# Return all detections, including those that were not classified as moths
all_detections = classifier.results + non_moth_detections
@@ -209,6 +181,7 @@ async def process(data: PipelineRequest) -> PipelineResponse:

response = PipelineResponse(
pipeline=data.pipeline,
algorithms=algorithms_used,
source_images=source_image_results,
detections=all_detections,
total_time=seconds_elapsed,
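Taken together, process() now records an AlgorithmResponse for each stage it runs (the detector, the binary moth/non-moth filter, and the species classifier) and returns them alongside the detections. A hedged sketch of calling the updated endpoint with FastAPI's test client, assuming the algorithms field serializes as a mapping of key to metadata; the image id and URL are placeholders:

    from fastapi.testclient import TestClient

    from trapdata.api.api import app

    client = TestClient(app)
    payload = {
        "pipeline": "global_moths_2024",  # any key from PIPELINE_CHOICES
        "source_images": [
            # placeholder id and URL
            {"id": "e124f3b4", "url": "https://example.com/trap.jpg"},
        ],
    }
    result = client.post("/pipeline/process", json=payload).json()

    # New in this commit: metadata for every algorithm used in the run,
    # keyed by each model's get_key()
    for key, algorithm in result["algorithms"].items():
        print(key, algorithm["task_type"])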
87 changes: 36 additions & 51 deletions trapdata/api/models/classification.py
@@ -18,27 +18,37 @@
)

from ..datasets import ClassificationImageDataset
from ..schemas import ClassificationResponse, DetectionResponse, SourceImage
from ..schemas import (
AlgorithmReference,
ClassificationResponse,
DetectionResponse,
SourceImage,
)
from .base import APIInferenceBaseClass


class APIMothClassifier(
APIInferenceBaseClass,
InferenceBaseClass,
):
type = "classification"

def __init__(
self,
source_images: typing.Iterable[SourceImage],
detections: typing.Iterable[DetectionResponse],
terminal: bool = True,
*args,
**kwargs,
):
self.source_images = source_images
self.detections = list(detections)
self.terminal = terminal
self.results: list[DetectionResponse] = []
super().__init__(*args, **kwargs)
logger.info(
f"Initialized {self.__class__.__name__} with {len(self.detections)} detections"
f"Initialized {self.__class__.__name__} with {len(self.detections)} "
"detections"
)

def get_dataset(self):
@@ -51,19 +61,25 @@ def get_dataset(self):

def post_process_batch(self, logits: torch.Tensor):
"""
Return the labels, softmax/calibrated scores, and the original logits for each image in the batch.
Return the labels, softmax/calibrated scores, and the original logits for
each image in the batch.
Almost like the base class method, but we need to return the logits as well.
"""
predictions = torch.nn.functional.softmax(logits, dim=1)
predictions = predictions.cpu().numpy()

indices = np.arange(predictions.shape[1])

# @TODO Calibrate the scores here,
scores = predictions
batch_results = []
for pred in predictions:
# Get all class indices and their corresponding scores
class_indices = np.arange(len(pred))
scores = pred
labels = [self.category_map[i] for i in class_indices]
batch_results.append(list(zip(labels, scores, pred)))

labels = np.array([[self.category_map[i] for i in row] for row in indices])
logger.debug(f"Post-processing result batch: {batch_results}")

return zip(labels, scores, logits)
return batch_results

def get_best_label(self, predictions):
"""
@@ -79,7 +95,6 @@ def get_best_label(self, predictions):
...
]
"""

best_pred = max(predictions, key=lambda x: x[1])
best_label = best_pred[0]
return best_label
@@ -94,15 +109,15 @@ def save_results(
):
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id
labels, scores, logits = zip(*predictions)
_labels, scores, logits = zip(*predictions)
classification = ClassificationResponse(
classification=self.get_best_label(predictions),
labels=labels, # @TODO move this to the Algorithm class instead of repeating it every prediction
scores=scores,
logits=logits,
inference_time=seconds_per_item,
algorithm=self.name,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
terminal=self.terminal,
)
self.update_classification(detection, classification)

@@ -115,60 +130,30 @@ def update_classification(
) -> None:
# Remove all existing classifications from this algorithm
detection.classifications = [
c for c in detection.classifications if c.algorithm != self.name
c for c in detection.classifications if c.algorithm.name != self.name
]
# Add the new classification for this algorithm
detection.classifications.append(new_classification)
logger.debug(
f"Updated classification for detection {detection.bbox}. Total classifications: {len(detection.classifications)}"
f"Updated classification for detection {detection.bbox}. "
f"Total classifications: {len(detection.classifications)}"
)

def run(self) -> list[DetectionResponse]:
logger.info(
f"Starting {self.__class__.__name__} run with {len(self.results)} detections"
f"Starting {self.__class__.__name__} run with {len(self.results)} "
"detections"
)
super().run()
logger.info(
f"Finished {self.__class__.__name__} run. Processed {len(self.results)} detections"
f"Finished {self.__class__.__name__} run. "
f"Processed {len(self.results)} detections"
)
return self.results


class MothClassifierBinary(APIMothClassifier, MothNonMothClassifier):
def __init__(self, *args, **kwargs):
self.filter_results = kwargs.get("filter_results", True)
super().__init__(*args, **kwargs)

def save_results(
self, metadata, batch_output, seconds_per_item, *args, **kwargs
) -> list[DetectionResponse]:
"""
Override the base class method to save only the results that have the
label we are interested in.
"""
logger.info(f"Saving {len(batch_output)} detections with classifications")
image_ids = metadata[0]
detection_idxes = metadata[1]
for image_id, detection_idx, predictions in zip(
image_ids, detection_idxes, batch_output
):
detection = self.detections[detection_idx]
assert detection.source_image_id == image_id
classification = ClassificationResponse(
classification=predictions[0][0],
labels=[label for (label, _) in list(predictions)],
scores=[score for (_, score) in list(predictions)],
inference_time=seconds_per_item,
algorithm=self.name,
timestamp=datetime.datetime.now(),
# Specific to binary classification / the filter model
terminal=False,
)
self.update_classification(detection, classification)

self.results = self.detections
logger.info(f"Saving {len(self.results)} detections with classifications")
return self.results
pass


class MothClassifierPanama(
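For reference, post_process_batch now builds one (label, score, logit) triple per class for each image, and save_results unpacks those triples to populate a single ClassificationResponse with the full label, score, and logit lists. A toy reconstruction of that contract with a stand-in category map (the real map is loaded by InferenceBaseClass):

    import numpy as np
    import torch

    category_map = {0: "moth", 1: "nonmoth"}  # stand-in for the real map
    logits = torch.tensor([[2.0, -1.0], [0.3, 0.8]])  # toy batch of two crops

    predictions = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
    batch_results = []
    for pred in predictions:
        labels = [category_map[i] for i in np.arange(len(pred))]
        # Note: as committed, the third element is the softmax score again,
        # not the raw logit
        batch_results.append(list(zip(labels, pred, pred)))

    # save_results consumes one list of triples per detection:
    _labels, scores, logit_like = zip(*batch_results[0])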
4 changes: 2 additions & 2 deletions trapdata/api/models/localization.py
@@ -5,7 +5,7 @@
from trapdata.ml.models.localization import MothObjectDetector_FasterRCNN_2023

from ..datasets import LocalizationImageDataset
from ..schemas import BoundingBox, DetectionResponse, SourceImage
from ..schemas import AlgorithmReference, BoundingBox, DetectionResponse, SourceImage
from .base import APIInferenceBaseClass


@@ -35,7 +35,7 @@ def save_detection(image_id, coords):
source_image_id=image_id,
bbox=bbox,
inference_time=seconds_per_item,
algorithm=self.name,
algorithm=AlgorithmReference(name=self.name, key=self.get_key()),
timestamp=datetime.datetime.now(),
crop_image_url=None,
)
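Because detections and classifications now reference their producer through an AlgorithmReference instead of a bare name string, downstream code can select results by algorithm key. A minimal sketch, assuming AlgorithmReference exposes the name and key fields shown in the diffs above:

    from trapdata.api.schemas import ClassificationResponse, DetectionResponse


    def classifications_by_algorithm(
        detection: DetectionResponse, algorithm_key: str
    ) -> list[ClassificationResponse]:
        """Keep only the classifications produced by the given algorithm."""
        return [
            c for c in detection.classifications
            if c.algorithm.key == algorithm_key
        ]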
(2 remaining changed files not shown.)
