Allow using moth/non-moth model as terminal classifier #70

Merged 2 commits on Jan 26, 2025

113 changes: 69 additions & 44 deletions trapdata/api/api.py
```diff
@@ -27,6 +27,7 @@
 from .schemas import (
     AlgorithmCategoryMapResponse,
     AlgorithmConfigResponse,
+    DetectionResponse,
     PipelineConfigResponse,
 )
 from .schemas import PipelineRequest as PipelineRequest_
@@ -45,7 +46,7 @@
     "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica,
     "anguilla_moths_turing_2024": MothClassifierTuringAnguilla,
     "global_moths_2024": MothClassifierGlobal,
-    # "moth_binary": MothClassifierBinary,
+    "moth_binary": MothClassifierBinary,
 }
 _classifier_choices = dict(
     zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys()))
```
```diff
@@ -55,6 +56,13 @@
 PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices)
 
 
+def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool:
+    if Classifier == MothClassifierBinary:
+        return False
+    else:
+        return True
+
+
 def make_category_map_response(
     model: APIMothDetector | APIMothClassifier,
     default_taxon_rank: str = "SPECIES",
```
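
The helper special-cases the binary model: when it is itself the terminal classifier, there is no point running it again as a pre-filter. A quick behavioural sketch (import path inferred from this file, trapdata/api/api.py; pipeline keys are from CLASSIFIER_CHOICES above):

```python
from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections

# Binary model as the terminal classifier: skip the moth/non-moth pre-filter.
assert should_filter_detections(CLASSIFIER_CHOICES["moth_binary"]) is False

# A species classifier still gets the binary pre-filter first.
assert should_filter_detections(CLASSIFIER_CHOICES["global_moths_2024"]) is True
```

The if/else could be collapsed to `return Classifier != MothClassifierBinary`; the explicit branch simply reads more clearly.
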
```diff
@@ -113,14 +121,20 @@ def make_pipeline_config_response(
     """
     Create a configuration for an entire pipeline, given a species classifier class.
     """
+    algorithms = []
+
     detector = APIMothDetector(
         source_images=[],
     )
+    algorithms.append(make_algorithm_config_response(detector))
 
-    binary_classifier = MothClassifierBinary(
-        source_images=[],
-        detections=[],
-    )
+    if should_filter_detections(Classifier):
+        binary_classifier = MothClassifierBinary(
+            source_images=[],
+            detections=[],
+            terminal=False,
+        )
+        algorithms.append(make_algorithm_config_response(binary_classifier))
 
     classifier = Classifier(
         source_images=[],
@@ -129,17 +143,14 @@
         num_workers=settings.num_workers,
         terminal=True,
     )
+    algorithms.append(make_algorithm_config_response(classifier))
 
     return PipelineConfigResponse(
         name=classifier.name,
         slug=slug,
         description=classifier.description,
         version=1,
-        algorithms=[
-            make_algorithm_config_response(detector),
-            make_algorithm_config_response(binary_classifier),
-            make_algorithm_config_response(classifier),
-        ],
+        algorithms=algorithms,
     )
 
 
```
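
With the list built up conditionally, a binary-only pipeline reports two algorithms (detector, terminal binary classifier) while species pipelines report three. A sketch mirroring the new tests further down:

```python
from trapdata.api.api import CLASSIFIER_CHOICES, make_pipeline_config_response

binary_config = make_pipeline_config_response(
    CLASSIFIER_CHOICES["moth_binary"], slug="moth_binary"
)
assert len(binary_config.algorithms) == 2  # detector + terminal binary classifier

species_config = make_pipeline_config_response(
    CLASSIFIER_CHOICES["quebec_vermont_moths_2023"],
    slug="quebec_vermont_moths_2023",
)
assert len(species_config.algorithms) == 3  # detector + filter + species classifier
```
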
```diff
@@ -173,6 +184,7 @@ async def root():
 @app.post(
     "/pipeline/process/", deprecated=True, tags=["services"]
 ) # old endpoint, deprecated, remove after jan 2025
+@app.post("/process", tags=["services"]) # new endpoint
 @app.post("/process/", tags=["services"]) # new endpoint
 async def process(data: PipelineRequest) -> PipelineResponse:
     algorithms_used: dict[str, AlgorithmConfigResponse] = {}
```
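
Both /process and /process/ now accept the same payload as the deprecated /pipeline/process/. A hedged request sketch (payload fields follow PipelineRequest as exercised in the tests; host, port, and image URL are placeholders):

```python
import requests

payload = {
    "pipeline": "moth_binary",  # any key from CLASSIFIER_CHOICES
    "source_images": [
        {"id": "img-1", "url": "http://localhost:8000/static/example.jpg"}
    ],
}
response = requests.post("http://localhost:8000/process", json=payload)
response.raise_for_status()
print(len(response.json()["detections"]), "detections returned")
```
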
```diff
@@ -196,6 +208,9 @@ async def process(data: PipelineRequest) -> PipelineResponse:
     ]
 
     start_time = time.time()
+
+    Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
+
     detector = APIMothDetector(
         source_images=source_images,
         batch_size=settings.localization_batch_size,
@@ -207,77 +222,87 @@
     num_pre_filter = len(detector_results)
     algorithms_used[detector.get_key()] = make_algorithm_response(detector)
 
-    filter = MothClassifierBinary(
-        source_images=source_images,
-        detections=detector_results,
-        batch_size=settings.classification_batch_size,
-        num_workers=settings.num_workers,
-        # single=True if len(detector_results) == 1 else False,
-        single=True, # @TODO solve issues with reading images in multiprocessing
-        terminal=False,
-    )
-    filter.run()
-    algorithms_used[filter.get_key()] = make_algorithm_response(filter)
+    detections_for_terminal_classifier: list[DetectionResponse] = []
+    detections_to_return: list[DetectionResponse] = []
+
+    if should_filter_detections(Classifier):
+        filter = MothClassifierBinary(
+            source_images=source_images,
+            detections=detector_results,
+            batch_size=settings.classification_batch_size,
+            num_workers=settings.num_workers,
+            # single=True if len(detector_results) == 1 else False,
+            single=True, # @TODO solve issues with reading images in multiprocessing
+            terminal=False,
+        )
+        filter.run()
+        algorithms_used[filter.get_key()] = make_algorithm_response(filter)
 
-    # Compare num detections with num moth detections
-    num_post_filter = len(filter.results)
-    logger.info(
-        f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
-    )
+        # Compare num detections with num moth detections
+        num_post_filter = len(filter.results)
+        logger.info(
+            f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections"
+        )
 
-    # Filter results based on positive_binary_label
-    moth_detections = []
-    non_moth_detections = []
-    for detection in filter.results:
-        for classification in detection.classifications:
-            if classification.classification == filter.positive_binary_label:
-                moth_detections.append(detection)
-            elif classification.classification == filter.negative_binary_label:
-                non_moth_detections.append(detection)
-                break
+        # Filter results based on positive_binary_label
+        moth_detections = []
+        non_moth_detections = []
+        for detection in filter.results:
+            for classification in detection.classifications:
+                if classification.classification == filter.positive_binary_label:
+                    moth_detections.append(detection)
+                elif classification.classification == filter.negative_binary_label:
+                    non_moth_detections.append(detection)
+                    break
+        detections_for_terminal_classifier += moth_detections
+        detections_to_return += non_moth_detections
+
+    else:
+        logger.info("Skipping binary classification filter")
+        detections_for_terminal_classifier += detector_results
 
     logger.info(
-        f"Sending {len(moth_detections)} of {num_pre_filter} "
+        f"Sending {len(detections_for_terminal_classifier)} of {num_pre_filter} "
         "detections to the classifier"
     )
 
-    Classifier = CLASSIFIER_CHOICES[str(data.pipeline)]
     classifier: APIMothClassifier = Classifier(
         source_images=source_images,
-        detections=moth_detections,
+        detections=detections_for_terminal_classifier,
         batch_size=settings.classification_batch_size,
         num_workers=settings.num_workers,
         # single=True if len(filtered_detections) == 1 else False,
        single=True, # @TODO solve issues with reading images in multiprocessing
         example_config_param=data.config.example_config_param,
         terminal=True,
         # critera=data.config.criteria, # @TODO another approach to intermediate filter models
     )
     classifier.run()
     end_time = time.time()
     seconds_elapsed = float(end_time - start_time)
     algorithms_used[classifier.get_key()] = make_algorithm_response(classifier)
 
     # Return all detections, including those that were not classified as moths
-    all_detections = classifier.results + non_moth_detections
+    detections_to_return += classifier.results
 
     logger.info(
         f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds"
     )
-    logger.info(f"Returning {len(all_detections)} detections")
+    logger.info(f"Returning {len(detections_to_return)} detections")
     # print(all_detections)
 
     # If the number of detections is greater than 100, its suspicious. Log it.
-    if len(all_detections) > 100:
+    if len(detections_to_return) > 100:
         logger.warning(
-            f"Detected {len(all_detections)} detections. "
+            f"Detected {len(detections_to_return)} detections. "
             "This is suspicious and may contain duplicates."
         )
 
     response = PipelineResponse(
         pipeline=data.pipeline,
         algorithms=algorithms_used,
         source_images=source_image_results,
-        detections=all_detections,
+        detections=detections_to_return,
         total_time=seconds_elapsed,
     )
     return response
```
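
The net effect: detections either pass through the binary filter and are split (species pipelines), or go straight to the terminal classifier (binary pipeline). A self-contained toy version of the split, using a hypothetical helper that is not part of the PR:

```python
def route_detections(
    use_filter: bool,
    moths: list[str],
    non_moths: list[str],
    all_detections: list[str],
) -> tuple[list[str], list[str]]:
    """Return (for_terminal_classifier, returned_unclassified)."""
    if use_filter:
        # Species pipeline: moths continue to the species classifier;
        # non-moths are returned with only their binary label.
        return list(moths), list(non_moths)
    # Binary pipeline: every raw detection goes to the binary classifier.
    return list(all_detections), []


assert route_detections(True, ["m1"], ["n1"], ["m1", "n1"]) == (["m1"], ["n1"])
assert route_detections(False, [], [], ["m1", "n1"]) == (["m1", "n1"], [])
```
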
5 changes: 5 additions & 0 deletions trapdata/api/schemas.py
```diff
@@ -82,20 +82,23 @@ class ClassificationResponse(pydantic.BaseModel):
             "classification in the response. Use the category map from the algorithm "
             "to get the full list of labels and metadata."
         ),
+        repr=False,  # Too long to display in the repr
     )
     scores: list[float] = pydantic.Field(
         default_factory=list,
         description=(
             "The calibrated probabilities for each class label, most commonly "
             "the softmax output."
         ),
+        repr=False,  # Too long to display in the repr
     )
     logits: list[float] = pydantic.Field(
         default_factory=list,
         description=(
             "The raw logits output by the model, before any calibration or "
             "normalization."
         ),
+        repr=False,  # Too long to display in the repr
     )
     inference_time: float | None = None
     algorithm: AlgorithmReference
@@ -153,6 +156,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel):
                 {"label": "Not a moth", "index": 1, "gbif_key": 5678},
             ]
         ],
+        repr=False,  # Too long to display in the repr
     )
     labels: list[str] = pydantic.Field(
         default_factory=list,
@@ -161,6 +165,7 @@
             "the model."
         ),
         examples=[["Moth", "Not a moth"]],
+        repr=False,  # Too long to display in the repr
     )
     version: str | None = pydantic.Field(
         default=None,
```
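
repr=False keeps these potentially very long vectors (one score per class, thousands for the species models) out of printed responses and log lines. A minimal standalone illustration, not from the codebase:

```python
import pydantic


class Example(pydantic.BaseModel):
    label: str
    scores: list[float] = pydantic.Field(default_factory=list, repr=False)


print(Example(label="Moth", scores=[0.9] * 1000))
# Example(label='Moth')  <- scores are omitted from the repr
```
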
64 changes: 60 additions & 4 deletions trapdata/api/tests/test_api.py
```diff
@@ -10,6 +10,8 @@
     PipelineRequest,
     PipelineResponse,
     app,
+    make_algorithm_response,
+    make_pipeline_config_response,
 )
 from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest
 from trapdata.api.tests.image_server import StaticFileTestServer
@@ -63,11 +65,10 @@ def test_pipeline_request(self):
             source_images=self.get_test_images(num=2),
         )
         with self.file_server:
-            response = self.client.post(
-                "/pipeline/process", json=pipeline_request.model_dump()
-            )
+            response = self.client.post("/process", json=pipeline_request.model_dump())
             assert response.status_code == 200
-            PipelineResponse(**response.json())
+            results = PipelineResponse(**response.json())
+            return results
 
     def test_config_num_classification_predictions(self):
         """
@@ -124,3 +125,58 @@ def _send_request(max_predictions_per_classification: int | None):
 
         _send_request(max_predictions_per_classification=1)
         _send_request(max_predictions_per_classification=None)
+
+    def test_pipeline_config_with_binary_classifier(self):
+        binary_classifier_pipeline_choice = "moth_binary"
+        BinaryClassifier = CLASSIFIER_CHOICES[binary_classifier_pipeline_choice]
+        BinaryClassifierResponse = make_algorithm_response(BinaryClassifier)
+
+        species_classifier_pipeline_choice = "quebec_vermont_moths_2023"
+        SpeciesClassifier = CLASSIFIER_CHOICES[species_classifier_pipeline_choice]
+        SpeciesClassifierResponse = make_algorithm_response(SpeciesClassifier)
+
+        # Test using a pipeline that finishes with a full species classifier
+        pipeline_config = make_pipeline_config_response(
+            SpeciesClassifier,
+            slug=species_classifier_pipeline_choice,
+        )
+
+        self.assertEqual(len(pipeline_config.algorithms), 3)
+        self.assertEqual(
+            pipeline_config.algorithms[-1].key, SpeciesClassifierResponse.key
+        )
+        self.assertEqual(
+            pipeline_config.algorithms[1].key, BinaryClassifierResponse.key
+        )
+
+        # Test using a pipeline that finishes only with a binary classifier
+        pipeline_config_binary_only = make_pipeline_config_response(
+            BinaryClassifier, slug=binary_classifier_pipeline_choice
+        )
+
+        self.assertEqual(len(pipeline_config_binary_only.algorithms), 2)
+        self.assertEqual(
+            pipeline_config_binary_only.algorithms[-1].key, BinaryClassifierResponse.key
+        )
+        # self.assertTrue(pipeline_config_binary_only.algorithms[-1].terminal)
+
+    def test_processing_with_only_binary_classifier(self):
+        binary_algorithm_key = "moth_binary"
+        binary_algorithm = CLASSIFIER_CHOICES[binary_algorithm_key]
+        pipeline_request = PipelineRequest(
+            pipeline=PipelineChoice[binary_algorithm_key],
+            source_images=self.get_test_images(num=2),
+        )
+        with self.file_server:
+            response = self.client.post("/process", json=pipeline_request.model_dump())
+            assert response.status_code == 200
+            results = PipelineResponse(**response.json())
+
+        for detection in results.detections:
+            for classification in detection.classifications:
+                assert classification.algorithm.key == binary_algorithm_key
+                assert classification.terminal
+                assert classification.labels
+                assert len(classification.labels) == binary_algorithm.num_classes
+                assert classification.scores
+                assert len(classification.scores) == binary_algorithm.num_classes
```
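
For reference, the new test implies a terminal binary classification in the response looks roughly like this (field names from trapdata/api/schemas.py; the label strings and scores are illustrative assumptions):

```python
classification = {
    "classification": "moth",       # assumed positive_binary_label value
    "labels": ["moth", "nonmoth"],  # binary model: num_classes == 2
    "scores": [0.97, 0.03],         # illustrative values
    "terminal": True,               # the binary model is now terminal
    "algorithm": {"key": "moth_binary"},
}
```
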
10 changes: 8 additions & 2 deletions trapdata/api/tests/test_models.py
```diff
@@ -74,10 +74,16 @@ def make_image():
 
 
 def get_test_images(
-    subdirs: typing.Iterable[str] = ("vermont", "panama"), limit: int = 6
+    subdirs: typing.Iterable[str] = ("vermont", "panama"),
+    limit: int = 6,
+    with_urls: bool = False,
 ) -> list[SourceImage]:
     return [
-        SourceImage(id=str(img["path"].name), filepath=img["path"])
+        SourceImage(
+            id=str(img["path"].name),
+            filepath=img["path"],
+            url=img["url"] if with_urls else None,
+        )
         for subdir in subdirs
         for img in find_images(pathlib.Path(TEST_IMAGES_BASE_PATH) / subdir)
     ][:limit]
```
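
The with_urls flag lets callers populate SourceImage.url (e.g., when fixtures are served through the static test server) without touching existing call sites. Usage sketch:

```python
from trapdata.api.tests.test_models import get_test_images

local_only = get_test_images(limit=2)  # filepath set, url=None
with_urls = get_test_images(limit=2, with_urls=True)  # url from find_images()
```
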