From 589ce5985430e746fa10eb8effd188ddbd2c0072 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 26 Jan 2025 13:59:57 -0800 Subject: [PATCH 1/2] feat: allow using moth/non-moth model as terminal classifier --- trapdata/api/api.py | 113 ++++++++++++++++++------------ trapdata/api/schemas.py | 2 + trapdata/api/tests/test_api.py | 57 +++++++++++++-- trapdata/api/tests/test_models.py | 10 ++- 4 files changed, 132 insertions(+), 50 deletions(-) diff --git a/trapdata/api/api.py b/trapdata/api/api.py index 6ec7280..5b3120a 100644 --- a/trapdata/api/api.py +++ b/trapdata/api/api.py @@ -27,6 +27,7 @@ from .schemas import ( AlgorithmCategoryMapResponse, AlgorithmConfigResponse, + DetectionResponse, PipelineConfigResponse, ) from .schemas import PipelineRequest as PipelineRequest_ @@ -45,7 +46,7 @@ "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica, "anguilla_moths_turing_2024": MothClassifierTuringAnguilla, "global_moths_2024": MothClassifierGlobal, - # "moth_binary": MothClassifierBinary, + "moth_binary": MothClassifierBinary, } _classifier_choices = dict( zip(CLASSIFIER_CHOICES.keys(), list(CLASSIFIER_CHOICES.keys())) @@ -55,6 +56,13 @@ PipelineChoice = enum.Enum("PipelineChoice", _classifier_choices) +def should_filter_detections(Classifier: type[APIMothClassifier]) -> bool: + if Classifier == MothClassifierBinary: + return False + else: + return True + + def make_category_map_response( model: APIMothDetector | APIMothClassifier, default_taxon_rank: str = "SPECIES", @@ -113,14 +121,20 @@ def make_pipeline_config_response( """ Create a configuration for an entire pipeline, given a species classifier class. """ + algorithms = [] + detector = APIMothDetector( source_images=[], ) + algorithms.append(make_algorithm_config_response(detector)) - binary_classifier = MothClassifierBinary( - source_images=[], - detections=[], - ) + if should_filter_detections(Classifier): + binary_classifier = MothClassifierBinary( + source_images=[], + detections=[], + terminal=False, + ) + algorithms.append(make_algorithm_config_response(binary_classifier)) classifier = Classifier( source_images=[], @@ -129,17 +143,14 @@ def make_pipeline_config_response( num_workers=settings.num_workers, terminal=True, ) + algorithms.append(make_algorithm_config_response(classifier)) return PipelineConfigResponse( name=classifier.name, slug=slug, description=classifier.description, version=1, - algorithms=[ - make_algorithm_config_response(detector), - make_algorithm_config_response(binary_classifier), - make_algorithm_config_response(classifier), - ], + algorithms=algorithms, ) @@ -173,6 +184,7 @@ async def root(): @app.post( "/pipeline/process/", deprecated=True, tags=["services"] ) # old endpoint, deprecated, remove after jan 2025 +@app.post("/process", tags=["services"]) # new endpoint @app.post("/process/", tags=["services"]) # new endpoint async def process(data: PipelineRequest) -> PipelineResponse: algorithms_used: dict[str, AlgorithmConfigResponse] = {} @@ -196,6 +208,9 @@ async def process(data: PipelineRequest) -> PipelineResponse: ] start_time = time.time() + + Classifier = CLASSIFIER_CHOICES[str(data.pipeline)] + detector = APIMothDetector( source_images=source_images, batch_size=settings.localization_batch_size, @@ -207,50 +222,60 @@ async def process(data: PipelineRequest) -> PipelineResponse: num_pre_filter = len(detector_results) algorithms_used[detector.get_key()] = make_algorithm_response(detector) - filter = MothClassifierBinary( - source_images=source_images, - detections=detector_results, - batch_size=settings.classification_batch_size, - num_workers=settings.num_workers, - # single=True if len(detector_results) == 1 else False, - single=True, # @TODO solve issues with reading images in multiprocessing - terminal=False, - ) - filter.run() - algorithms_used[filter.get_key()] = make_algorithm_response(filter) + detections_for_terminal_classifier: list[DetectionResponse] = [] + detections_to_return: list[DetectionResponse] = [] + + if should_filter_detections(Classifier): + filter = MothClassifierBinary( + source_images=source_images, + detections=detector_results, + batch_size=settings.classification_batch_size, + num_workers=settings.num_workers, + # single=True if len(detector_results) == 1 else False, + single=True, # @TODO solve issues with reading images in multiprocessing + terminal=False, + ) + filter.run() + algorithms_used[filter.get_key()] = make_algorithm_response(filter) - # Compare num detections with num moth detections - num_post_filter = len(filter.results) - logger.info( - f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections" - ) + # Compare num detections with num moth detections + num_post_filter = len(filter.results) + logger.info( + f"Binary classifier returned {num_post_filter} of {num_pre_filter} detections" + ) - # Filter results based on positive_binary_label - moth_detections = [] - non_moth_detections = [] - for detection in filter.results: - for classification in detection.classifications: - if classification.classification == filter.positive_binary_label: - moth_detections.append(detection) - elif classification.classification == filter.negative_binary_label: - non_moth_detections.append(detection) - break + # Filter results based on positive_binary_label + moth_detections = [] + non_moth_detections = [] + for detection in filter.results: + for classification in detection.classifications: + if classification.classification == filter.positive_binary_label: + moth_detections.append(detection) + elif classification.classification == filter.negative_binary_label: + non_moth_detections.append(detection) + break + detections_for_terminal_classifier += moth_detections + detections_to_return += non_moth_detections + + else: + logger.info("Skipping binary classification filter") + detections_for_terminal_classifier += detector_results logger.info( - f"Sending {len(moth_detections)} of {num_pre_filter} " + f"Sending {len(detections_for_terminal_classifier)} of {num_pre_filter} " "detections to the classifier" ) - Classifier = CLASSIFIER_CHOICES[str(data.pipeline)] classifier: APIMothClassifier = Classifier( source_images=source_images, - detections=moth_detections, + detections=detections_for_terminal_classifier, batch_size=settings.classification_batch_size, num_workers=settings.num_workers, # single=True if len(filtered_detections) == 1 else False, single=True, # @TODO solve issues with reading images in multiprocessing example_config_param=data.config.example_config_param, terminal=True, + # critera=data.config.criteria, # @TODO another approach to intermediate filter models ) classifier.run() end_time = time.time() @@ -258,18 +283,18 @@ async def process(data: PipelineRequest) -> PipelineResponse: algorithms_used[classifier.get_key()] = make_algorithm_response(classifier) # Return all detections, including those that were not classified as moths - all_detections = classifier.results + non_moth_detections + detections_to_return += classifier.results logger.info( f"Processed {len(source_images)} images in {seconds_elapsed:.2f} seconds" ) - logger.info(f"Returning {len(all_detections)} detections") + logger.info(f"Returning {len(detections_to_return)} detections") # print(all_detections) # If the number of detections is greater than 100, its suspicious. Log it. - if len(all_detections) > 100: + if len(detections_to_return) > 100: logger.warning( - f"Detected {len(all_detections)} detections. " + f"Detected {len(detections_to_return)} detections. " "This is suspicious and may contain duplicates." ) @@ -277,7 +302,7 @@ async def process(data: PipelineRequest) -> PipelineResponse: pipeline=data.pipeline, algorithms=algorithms_used, source_images=source_image_results, - detections=all_detections, + detections=detections_to_return, total_time=seconds_elapsed, ) return response diff --git a/trapdata/api/schemas.py b/trapdata/api/schemas.py index 6099932..e0dbb33 100644 --- a/trapdata/api/schemas.py +++ b/trapdata/api/schemas.py @@ -153,6 +153,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel): {"label": "Not a moth", "index": 1, "gbif_key": 5678}, ] ], + repr=False, # Too long to display in the repr ) labels: list[str] = pydantic.Field( default_factory=list, @@ -161,6 +162,7 @@ class AlgorithmCategoryMapResponse(pydantic.BaseModel): "the model." ), examples=[["Moth", "Not a moth"]], + repr=False, # Too long to display in the repr ) version: str | None = pydantic.Field( default=None, diff --git a/trapdata/api/tests/test_api.py b/trapdata/api/tests/test_api.py index d535a51..c49e6b5 100644 --- a/trapdata/api/tests/test_api.py +++ b/trapdata/api/tests/test_api.py @@ -10,6 +10,8 @@ PipelineRequest, PipelineResponse, app, + make_algorithm_response, + make_pipeline_config_response, ) from trapdata.api.schemas import PipelineConfigRequest, SourceImageRequest from trapdata.api.tests.image_server import StaticFileTestServer @@ -63,11 +65,10 @@ def test_pipeline_request(self): source_images=self.get_test_images(num=2), ) with self.file_server: - response = self.client.post( - "/pipeline/process", json=pipeline_request.model_dump() - ) + response = self.client.post("/process", json=pipeline_request.model_dump()) assert response.status_code == 200 - PipelineResponse(**response.json()) + results = PipelineResponse(**response.json()) + return results def test_config_num_classification_predictions(self): """ @@ -124,3 +125,51 @@ def _send_request(max_predictions_per_classification: int | None): _send_request(max_predictions_per_classification=1) _send_request(max_predictions_per_classification=None) + + def test_pipeline_config_with_binary_classifier(self): + BinaryClassifier = CLASSIFIER_CHOICES["moth_binary"] + BinaryClassifierResponse = make_algorithm_response(BinaryClassifier) + + SpeciesClassifier = CLASSIFIER_CHOICES["quebec_vermont_moths_2023"] + SpeciesClassifierResponse = make_algorithm_response(SpeciesClassifier) + + # Test using a pipeline that finishes with a full species classifier + pipeline_config = make_pipeline_config_response(SpeciesClassifier) + + self.assertEqual(len(pipeline_config.algorithms), 3) + self.assertEqual( + pipeline_config.algorithms[-1].key, SpeciesClassifierResponse.key + ) + self.assertEqual( + pipeline_config.algorithms[1].key, BinaryClassifierResponse.key + ) + + # Test using a pipeline that finishes only with a binary classifier + pipeline_config_binary_only = make_pipeline_config_response(BinaryClassifier) + + self.assertEqual(len(pipeline_config_binary_only.algorithms), 2) + self.assertEqual( + pipeline_config_binary_only.algorithms[-1].key, BinaryClassifierResponse.key + ) + # self.assertTrue(pipeline_config_binary_only.algorithms[-1].terminal) + + def test_processing_with_only_binary_classifier(self): + binary_algorithm_key = "moth_binary" + binary_algorithm = CLASSIFIER_CHOICES[binary_algorithm_key] + pipeline_request = PipelineRequest( + pipeline=PipelineChoice[binary_algorithm_key], + source_images=self.get_test_images(num=2), + ) + with self.file_server: + response = self.client.post("/process", json=pipeline_request.model_dump()) + assert response.status_code == 200 + results = PipelineResponse(**response.json()) + + for detection in results.detections: + for classification in detection.classifications: + assert classification.algorithm.key == binary_algorithm_key + assert classification.terminal + assert classification.labels + assert len(classification.labels) == binary_algorithm.num_classes + assert classification.scores + assert len(classification.scores) == binary_algorithm.num_classes diff --git a/trapdata/api/tests/test_models.py b/trapdata/api/tests/test_models.py index b2d23bc..4b74f29 100644 --- a/trapdata/api/tests/test_models.py +++ b/trapdata/api/tests/test_models.py @@ -74,10 +74,16 @@ def make_image(): def get_test_images( - subdirs: typing.Iterable[str] = ("vermont", "panama"), limit: int = 6 + subdirs: typing.Iterable[str] = ("vermont", "panama"), + limit: int = 6, + with_urls: bool = False, ) -> list[SourceImage]: return [ - SourceImage(id=str(img["path"].name), filepath=img["path"]) + SourceImage( + id=str(img["path"].name), + filepath=img["path"], + url=img["url"] if with_urls else None, + ) for subdir in subdirs for img in find_images(pathlib.Path(TEST_IMAGES_BASE_PATH) / subdir) ][:limit] From 6f0101e0decaa9f849e476cd4f1cc867f27dabe7 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 26 Jan 2025 14:16:07 -0800 Subject: [PATCH 2/2] fix: update tests & cleanup console output --- trapdata/api/schemas.py | 3 +++ trapdata/api/tests/test_api.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/trapdata/api/schemas.py b/trapdata/api/schemas.py index e0dbb33..7083d64 100644 --- a/trapdata/api/schemas.py +++ b/trapdata/api/schemas.py @@ -82,6 +82,7 @@ class ClassificationResponse(pydantic.BaseModel): "classification in the response. Use the category map from the algorithm " "to get the full list of labels and metadata." ), + repr=False, # Too long to display in the repr ) scores: list[float] = pydantic.Field( default_factory=list, @@ -89,6 +90,7 @@ class ClassificationResponse(pydantic.BaseModel): "The calibrated probabilities for each class label, most commonly " "the softmax output." ), + repr=False, # Too long to display in the repr ) logits: list[float] = pydantic.Field( default_factory=list, @@ -96,6 +98,7 @@ class ClassificationResponse(pydantic.BaseModel): "The raw logits output by the model, before any calibration or " "normalization." ), + repr=False, # Too long to display in the repr ) inference_time: float | None = None algorithm: AlgorithmReference diff --git a/trapdata/api/tests/test_api.py b/trapdata/api/tests/test_api.py index c49e6b5..6f4b9c1 100644 --- a/trapdata/api/tests/test_api.py +++ b/trapdata/api/tests/test_api.py @@ -127,14 +127,19 @@ def _send_request(max_predictions_per_classification: int | None): _send_request(max_predictions_per_classification=None) def test_pipeline_config_with_binary_classifier(self): - BinaryClassifier = CLASSIFIER_CHOICES["moth_binary"] + binary_classifier_pipeline_choice = "moth_binary" + BinaryClassifier = CLASSIFIER_CHOICES[binary_classifier_pipeline_choice] BinaryClassifierResponse = make_algorithm_response(BinaryClassifier) - SpeciesClassifier = CLASSIFIER_CHOICES["quebec_vermont_moths_2023"] + species_classifier_pipeline_choice = "quebec_vermont_moths_2023" + SpeciesClassifier = CLASSIFIER_CHOICES[species_classifier_pipeline_choice] SpeciesClassifierResponse = make_algorithm_response(SpeciesClassifier) # Test using a pipeline that finishes with a full species classifier - pipeline_config = make_pipeline_config_response(SpeciesClassifier) + pipeline_config = make_pipeline_config_response( + SpeciesClassifier, + slug=species_classifier_pipeline_choice, + ) self.assertEqual(len(pipeline_config.algorithms), 3) self.assertEqual( @@ -145,7 +150,9 @@ def test_pipeline_config_with_binary_classifier(self): ) # Test using a pipeline that finishes only with a binary classifier - pipeline_config_binary_only = make_pipeline_config_response(BinaryClassifier) + pipeline_config_binary_only = make_pipeline_config_response( + BinaryClassifier, slug=binary_classifier_pipeline_choice + ) self.assertEqual(len(pipeline_config_binary_only.algorithms), 2) self.assertEqual(