diff --git a/trapdata/antenna/tests/test_memory_leak.py b/trapdata/antenna/tests/test_memory_leak.py new file mode 100644 index 0000000..7857f59 --- /dev/null +++ b/trapdata/antenna/tests/test_memory_leak.py @@ -0,0 +1,126 @@ +"""Memory leak regression test for _process_job batch processing. + +Verifies that RSS does not grow unboundedly across batches by using the +on_batch_complete callback to sample memory after each batch. + +Uses the same test infrastructure as test_worker.py (mock Antenna API, +StaticFileTestServer, real ML inference). +""" + +import os +import pathlib +from unittest import TestCase +from unittest.mock import MagicMock + +import pytest +from fastapi.testclient import TestClient + +from trapdata.antenna.schemas import AntennaPipelineProcessingTask +from trapdata.antenna.tests import antenna_api_server +from trapdata.antenna.tests.antenna_api_server import app as antenna_app +from trapdata.antenna.worker import _process_job +from trapdata.api.tests.image_server import StaticFileTestServer +from trapdata.api.tests.utils import get_test_image_urls, patch_antenna_api_requests +from trapdata.tests import TEST_IMAGES_BASE_PATH + + +def _get_rss_mb() -> float: + """Current RSS in MB, read from /proc/self/statm (Linux-only).""" + with open("/proc/self/statm") as f: + pages = int(f.read().split()[1]) # resident pages + return pages * os.sysconf("SC_PAGE_SIZE") / (1024 * 1024) + + +class TestMemoryLeak(TestCase): + """Regression test: RSS must not grow linearly with batch count.""" + + @classmethod + def setUpClass(cls): + cls.test_images_dir = pathlib.Path(TEST_IMAGES_BASE_PATH) + cls.file_server = StaticFileTestServer(cls.test_images_dir) + cls.file_server.start() + cls.antenna_client = TestClient(antenna_app) + + @classmethod + def tearDownClass(cls): + cls.file_server.stop() + + def setUp(self): + antenna_api_server.reset() + + def _make_settings(self): + settings = MagicMock() + settings.antenna_api_base_url = "http://testserver/api/v2" + settings.antenna_api_auth_token = "test-token" + settings.antenna_api_batch_size = 2 + settings.num_workers = 0 + settings.localization_batch_size = 2 + return settings + + @pytest.mark.slow + def test_rss_stable_across_batches(self): + """RSS should not grow more than 150 MB across 25+ batches. + + With the old code, all_detections accumulated ~220K DetectionResponse + objects over a large job, growing RSS by ~4 GB/hr. After the fix, + each batch's intermediates go out of scope in _process_batch(). + + The 150 MB threshold accounts for normal PyTorch/CUDA allocator + fragmentation and memory pool behavior, which is not a true leak. + """ + # Create 50 tasks by cycling through the 3 available test images + image_urls = get_test_image_urls( + self.file_server, self.test_images_dir, subdir="vermont", num=3 + ) + num_tasks = 50 + tasks = [ + AntennaPipelineProcessingTask( + id=f"task_{i}", + image_id=f"img_{i}", + image_url=image_urls[i % len(image_urls)], + reply_subject=f"reply_{i}", + ) + for i in range(num_tasks) + ] + antenna_api_server.setup_job(job_id=999, tasks=tasks) + + # Collect RSS samples via callback + rss_samples: list[float] = [] + + def on_batch(batch_num: int, items: int): + rss_samples.append(_get_rss_mb()) + + with patch_antenna_api_requests(self.antenna_client): + result = _process_job( + "quebec_vermont_moths_2023", + 999, + self._make_settings(), + on_batch_complete=on_batch, + ) + + assert result is True + assert ( + len(rss_samples) >= 10 + ), f"Expected at least 10 batches, got {len(rss_samples)}" + + # Compare RSS at end vs after first 2 batches (allow model warmup) + warmup_rss = rss_samples[2] + final_rss = rss_samples[-1] + growth_mb = final_rss - warmup_rss + + print(f"\nMemory profile ({len(rss_samples)} batches):") + print(f" After warmup (batch 2): {warmup_rss:.1f} MB") + print(f" Final (batch {len(rss_samples) - 1}): {final_rss:.1f} MB") + print(f" Growth: {growth_mb:.1f} MB") + for i, rss in enumerate(rss_samples): + print(f" Batch {i}: {rss:.1f} MB") + + # Threshold: 150 MB accounts for PyTorch/CUDA allocator pools and + # Python memory fragmentation — not a true leak. Before the fix, + # all_detections accumulated every DetectionResponse across all batches. + # At scale (31K images, ~7 detections/image), that was ~220K objects = GB. + assert growth_mb < 150, ( + f"RSS grew {growth_mb:.1f} MB across {len(rss_samples)} batches " + f"(warmup={warmup_rss:.1f} MB, final={final_rss:.1f} MB). " + f"Likely memory leak in batch processing." + ) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index bba7905..7773c50 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -1,7 +1,10 @@ """Worker loop for processing jobs from Antenna API.""" +from __future__ import annotations + import datetime import time +from collections.abc import Callable import numpy as np import torch @@ -112,11 +115,175 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): time.sleep(SLEEP_TIME_SECONDS) +def _process_batch( + batch: dict, + batch_num: int, + detector: APIMothDetector, + classifier, + pipeline: str, +) -> tuple[int, int, list[AntennaTaskResult], float, float]: + """Process a single batch of images through detection and classification. + + All large intermediates (image_tensors, crops, batched_crops) are local to this + function and freed by reference counting when it returns, preventing memory leaks. + + Returns: + (items_processed, detections_count, batch_results, detect_time, classify_time) + """ + images = batch.get("images", []) + image_ids = batch.get("image_ids", []) + reply_subjects = batch.get("reply_subjects", [None] * len(images)) + image_urls = batch.get("image_urls", [None] * len(images)) + + batch_results: list[AntennaTaskResult] = [] + + try: + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" + ) + if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" + ) + + batch_start_time = datetime.datetime.now() + + logger.info(f"Processing worker batch {batch_num + 1} ({len(images)} images)") + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) + + n_items = len(batch_output) + batch_output = list(detector.post_process_batch(batch_output)) + + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() + + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, + ) + detect_time = (datetime.datetime.now() - batch_start_time).total_seconds() + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + classifier.reset(detector.results) + to_pil = torchvision.transforms.ToPILImage() + classify_transforms = classifier.get_transforms() + + # Collect and transform all crops for batched classification + crops = [] + valid_indices = [] + n_detections = 0 + for idx, dresp in enumerate(detector.results): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + y1, y2 = int(bbox.y1), int(bbox.y2) + x1, x2 = int(bbox.x1), int(bbox.x2) + if y1 >= y2 or x1 >= x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({x1},{y1})->({x2},{y2})" + ) + continue + crop = image_tensor[:, y1:y2, x1:x2] + crop_pil = to_pil(crop) + crop_transformed = classify_transforms(crop_pil) + crops.append(crop_transformed) + valid_indices.append(idx) + + classify_start = datetime.datetime.now() + if crops: + batched_crops = torch.stack(crops) + classifier_out = classifier.predict_batch(batched_crops) + classifier_out = classifier.post_process_batch(classifier_out) + + for crop_i, idx in enumerate(valid_indices): + dresp = detector.results[idx] + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[crop_i], + ) + image_detections[dresp.source_image_id].append(detection) + n_detections += 1 + + classify_time = (datetime.datetime.now() - classify_start).total_seconds() + + # Calculate batch processing time + batch_elapsed = (datetime.datetime.now() - batch_start_time).total_seconds() + + # Build results for each image + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + source_image = SourceImageResponse(id=image_id, url=image_url) + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed / len(image_ids), + ) + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) + ) + except Exception as e: + logger.error( + f"Batch {batch_num + 1} failed during processing: {e}", exc_info=True + ) + batch_results = [] + for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=AntennaTaskResultError( + error=f"Batch processing error: {e}", + image_id=image_id, + ), + ) + ) + return 0, 0, batch_results, 0.0, 0.0 + + # Append results for failed image downloads + failed_items = batch.get("failed_items") + if failed_items: + for failed_item in failed_items: + batch_results.append( + AntennaTaskResult( + reply_subject=failed_item.get("reply_subject"), + result=AntennaTaskResultError( + error=failed_item.get("error", "Unknown error"), + image_id=failed_item.get("image_id"), + ), + ) + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return n_items, n_detections, batch_results, detect_time, classify_time + + @torch.no_grad() def _process_job( pipeline: str, job_id: int, settings: Settings, + on_batch_complete: Callable | None = None, ) -> bool: """Run the worker to process images from the REST API queue. @@ -124,6 +291,8 @@ def _process_job( pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024) job_id: Job ID to process settings: Settings object with antenna_api_* configuration + on_batch_complete: Optional callback invoked after each batch, with kwargs + batch_num (int) and items (int, cumulative items processed so far). Returns: True if any work was done, False otherwise """ @@ -140,7 +309,7 @@ def _process_job( total_classification_time = 0.0 total_save_time = 0.0 total_dl_time = 0.0 - all_detections = [] + total_detections = 0 _, t = log_time() for i, batch in enumerate(loader): @@ -160,157 +329,17 @@ def _process_job( detector.reset([]) did_work = True - # Extract data from dictionary batch - images = batch.get("images", []) - image_ids = batch.get("image_ids", []) - reply_subjects = batch.get("reply_subjects", [None] * len(images)) - image_urls = batch.get("image_urls", [None] * len(images)) - - batch_results: list[AntennaTaskResult] = [] - - try: - # Validate all arrays have same length before zipping - if len(image_ids) != len(images): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" - ) - if len(image_ids) != len(reply_subjects) or len(image_ids) != len( - image_urls - ): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}), " - f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" - ) - - # Track start time for this batch - batch_start_time = datetime.datetime.now() - - logger.info(f"Processing worker batch {i + 1} ({len(images)} images)") - # output is dict of "boxes", "labels", "scores" - batch_output = [] - if len(images) > 0: - batch_output = detector.predict_batch(images) - - items += len(batch_output) - logger.info(f"Total items processed so far: {items}") - batch_output = list(detector.post_process_batch(batch_output)) - - # Convert image_ids to list if needed - if isinstance(image_ids, (np.ndarray, torch.Tensor)): - image_ids = image_ids.tolist() - - # TODO CGJS: Add seconds per item calculation for both detector and classifier - detector.save_results( - item_ids=image_ids, - batch_output=batch_output, - seconds_per_item=0, - ) - dt, t = t("Finished detection") - total_detection_time += dt - - # Group detections by image_id - image_detections: dict[str, list[DetectionResponse]] = { - img_id: [] for img_id in image_ids - } - image_tensors = dict(zip(image_ids, images, strict=True)) - - classifier.reset(detector.results) - to_pil = torchvision.transforms.ToPILImage() - classify_transforms = classifier.get_transforms() - - # Collect and transform all crops for batched classification - crops = [] - valid_indices = [] - for idx, dresp in enumerate(detector.results): - image_tensor = image_tensors[dresp.source_image_id] - bbox = dresp.bbox - y1, y2 = int(bbox.y1), int(bbox.y2) - x1, x2 = int(bbox.x1), int(bbox.x2) - if y1 >= y2 or x1 >= x2: - logger.warning( - f"Skipping detection {idx} with invalid bbox: " - f"({x1},{y1})->({x2},{y2})" - ) - continue - crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = classify_transforms(crop_pil) - crops.append(crop_transformed) - valid_indices.append(idx) - - if crops: - batched_crops = torch.stack(crops) - classifier_out = classifier.predict_batch(batched_crops) - classifier_out = classifier.post_process_batch(classifier_out) - - for crop_i, idx in enumerate(valid_indices): - dresp = detector.results[idx] - detection = classifier.update_detection_classification( - seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[crop_i], - ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) - - ct, t = t("Finished classification") - total_classification_time += ct - - # Calculate batch processing time - batch_end_time = datetime.datetime.now() - batch_elapsed = (batch_end_time - batch_start_time).total_seconds() - - # Post results back to the API with PipelineResponse for each image - batch_results.clear() - for reply_subject, image_id, image_url in zip( - reply_subjects, image_ids, image_urls, strict=True - ): - # Create SourceImageResponse for this image - source_image = SourceImageResponse(id=image_id, url=image_url) - - # Create PipelineResultsResponse - pipeline_response = PipelineResultsResponse( - pipeline=pipeline, - source_images=[source_image], - detections=image_detections[image_id], - total_time=batch_elapsed - / len(image_ids), # Approximate time per image - ) - - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=pipeline_response, - ) - ) - except Exception as e: - logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) - # Report errors back to Antenna so tasks aren't stuck in the queue - batch_results = [] - for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=AntennaTaskResultError( - error=f"Batch processing error: {e}", - image_id=image_id, - ), - ) - ) - - failed_items = batch.get("failed_items") - if failed_items: - for failed_item in failed_items: - batch_results.append( - AntennaTaskResult( - reply_subject=failed_item.get("reply_subject"), - result=AntennaTaskResultError( - error=failed_item.get("error", "Unknown error"), - image_id=failed_item.get("image_id"), - ), - ) - ) + n_items, n_detections, batch_results, det_t, cls_t = _process_batch( + batch, + i, + detector, + classifier, + pipeline, + ) + items += n_items + total_detections += n_detections + total_detection_time += det_t + total_classification_time += cls_t success = post_batch_results( settings.antenna_api_base_url, @@ -328,8 +357,11 @@ def _process_job( total_save_time += st + if on_batch_complete: + on_batch_complete(batch_num=i, items=items) + logger.info( - f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time}, " + f"Done, detections: {total_detections}. Detecting time: {total_detection_time}, " f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}" ) return did_work diff --git a/trapdata/api/api.py b/trapdata/api/api.py index 47f34fe..6265dff 100644 --- a/trapdata/api/api.py +++ b/trapdata/api/api.py @@ -13,7 +13,7 @@ from ..common.logs import logger # noqa: F401 from . import settings -from .models.classification import ( +from .models.classification import ( # noqa: E501 (disabled: Singapore - no category map) APIMothClassifier, InsectOrderClassifier, MothClassifierBinary, @@ -22,8 +22,12 @@ MothClassifierPanama2024, MothClassifierQuebecVermont, MothClassifierTuringAnguilla, + MothClassifierTuringAnguillaV02, MothClassifierTuringCostaRica, + MothClassifierTuringJapan, MothClassifierTuringKenyaUganda, + MothClassifierTuringMadagascar, + MothClassifierTuringThailand, MothClassifierUKDenmark, ) from .models.localization import APIMothDetector @@ -59,7 +63,12 @@ async def lifespan(app: fastapi.FastAPI): "uk_denmark_moths_2023": MothClassifierUKDenmark, "costa_rica_moths_turing_2024": MothClassifierTuringCostaRica, "anguilla_moths_turing_2024": MothClassifierTuringAnguilla, + "anguilla_moths_turing_v02_2024": MothClassifierTuringAnguillaV02, + # "singapore_moths_turing_2024": MothClassifierTuringSingapore, # disabled: category map not available + "thailand_moths_turing_2024": MothClassifierTuringThailand, + "madagascar_moths_turing_2024": MothClassifierTuringMadagascar, "kenya-uganda_moths_turing_2024": MothClassifierTuringKenyaUganda, + "japan_moths_turing_2024": MothClassifierTuringJapan, "global_moths_2024": MothClassifierGlobal, "moth_binary": MothClassifierBinary, "insect_orders_2025": InsectOrderClassifier, diff --git a/trapdata/api/models/classification.py b/trapdata/api/models/classification.py index e604f3c..6709fbb 100644 --- a/trapdata/api/models/classification.py +++ b/trapdata/api/models/classification.py @@ -15,8 +15,13 @@ PanamaMothSpeciesClassifierMixedResolution2023, QuebecVermontMothSpeciesClassifier2024, TuringAnguillaSpeciesClassifier, + TuringAnguillaV02SpeciesClassifier, TuringCostaRicaSpeciesClassifier, + TuringJapanSpeciesClassifier, TuringKenyaUgandaSpeciesClassifier, + TuringMadagascarSpeciesClassifier, + TuringSingaporeSpeciesClassifier, + TuringThailandSpeciesClassifier, UKDenmarkMothSpeciesClassifier2024, ) @@ -219,12 +224,39 @@ class MothClassifierTuringAnguilla(APIMothClassifier, TuringAnguillaSpeciesClass pass +class MothClassifierTuringAnguillaV02( + APIMothClassifier, TuringAnguillaV02SpeciesClassifier +): + pass + + +class MothClassifierTuringJapan(APIMothClassifier, TuringJapanSpeciesClassifier): + pass + + class MothClassifierTuringKenyaUganda( APIMothClassifier, TuringKenyaUgandaSpeciesClassifier ): pass +class MothClassifierTuringMadagascar( + APIMothClassifier, TuringMadagascarSpeciesClassifier +): + pass + + +class MothClassifierTuringThailand(APIMothClassifier, TuringThailandSpeciesClassifier): + pass + + +# Disabled: category map not available at this time +class MothClassifierTuringSingapore( + APIMothClassifier, TuringSingaporeSpeciesClassifier +): + pass + + class MothClassifierGlobal(APIMothClassifier, GlobalMothSpeciesClassifier): pass diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 21459d1..5c400c2 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -422,6 +422,19 @@ class TuringAnguillaSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turi ) +class TuringAnguillaV02SpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): + name = "Turing Anguilla Species Classifier v02 (160 classes)" + description = "Trained on 19th November 2024 by Turing team using Resnet50 model. 160 classes." + weights_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" + "turing-anguilla_v02_resnet50_2024-11-19-19-17_state.pt" + ) + labels_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" + "02_anguilla_data_category_map_160cls.json" + ) + + class TuringKenyaUgandaSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): name = "Turing Kenya and Uganda Species Classifier" description = "Trained on 19th November 2024 by Turing team using Resnet50 model." @@ -435,6 +448,37 @@ class TuringKenyaUgandaSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_T ) +class TuringThailandSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): + name = "Turing Thailand Species Classifier" + description = "Trained on 11th November 2024 by Turing team using Resnet50 model." + weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/turing-thailand_v01_resnet50_2024-11-21-16-28_state.pt" + labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/01_thailand_data_category_map.json" + + +class TuringMadagascarSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): + name = "Turing Madagascar Species Classifier" + description = "Trained on 11th November 2024 by Turing team using Resnet50 model." + weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/turing-madagascar_v01_resnet50_2024-07-01-13-01_state.pt" + labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/01_madagascar_data_category_map.json" + + +class TuringJapanSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): + name = "Turing Japan Species Classifier" + description = "Trained on 19th November 2024 by Turing team using Resnet50 model." + weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/turing-japan_v01_resnet50_2024-11-22-17-22_state.pt" + labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/01_japan_data_category_map.json" + + +# NOTE: Singapore category map (02_singapore_data_category_map.json) is not available +# in the object store. Weights are uploaded but pipeline is disabled until the +# category map is sourced and uploaded. +class TuringSingaporeSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): + name = "Turing Singapore Species Classifier" + description = "Trained on 21st November 2024 by Turing team using Resnet50 model." + weights_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/turing-singapore_v02_resnet50_2024-11-21-19-58_state.pt" + labels_path = "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/02_singapore_data_category_map.json" + + class TuringUKSpeciesClassifier(SpeciesClassifier, Resnet50Classifier_Turing): name = "Turing UK Species Classifier" description = "Trained on 13th May 2024 by Turing team using Resnet50 model."