diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py index 898b9ce..483391e 100644 --- a/trapdata/antenna/benchmark.py +++ b/trapdata/antenna/benchmark.py @@ -54,6 +54,7 @@ def run_benchmark( batch_size: int, gpu_batch_size: int, service_name: str, + send_acks: bool = True, ) -> None: """Run the benchmark with the specified parameters. @@ -132,8 +133,8 @@ def run_benchmark( ack_result = create_empty_result(reply_subject, image_id) ack_results.append(ack_result) - logger.info(f"Sending {len(ack_results)} acknowledgment(s)") - if ack_results: + if ack_results and send_acks: + logger.info(f"Sending {len(ack_results)} acknowledgment(s)") # Send acknowledgments asynchronously result_poster.post_async( base_url=base_url, @@ -157,7 +158,7 @@ def run_benchmark( ) error_results.append(error_result) - if error_results: + if error_results and send_acks: result_poster.post_async( base_url=base_url, auth_token=auth_token, @@ -280,6 +281,11 @@ def main() -> int: default="Performance Test", help="Processing service name", ) + parser.add_argument( + "--skip-acks", + action="store_false", + help="Skip sending acknowledgments for processed images", + ) args = parser.parse_args() @@ -298,6 +304,7 @@ def main() -> int: batch_size=args.batch_size, gpu_batch_size=args.gpu_batch_size, service_name=args.service_name, + send_acks=args.skip_acks, ) return 0 diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 2d21dec..5450592 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -36,6 +36,7 @@ localization_batch_size (default 8) How many images the GPU processes at once (detection). Larger = more GPU memory. These are full-resolution images (~4K). + Async worker use antennna_api_batch_size for this. num_workers (default 4) DataLoader subprocesses. Each independently fetches tasks and @@ -254,7 +255,7 @@ def __iter__(self): Each API fetch returns a batch of tasks. Images for the entire batch are downloaded concurrently using threads (see _load_images_threaded), - then yielded one at a time for the DataLoader to collate. + then as a pre-collated batch. Yields: Dictionary containing: @@ -293,7 +294,7 @@ def __iter__(self): # Download all images concurrently image_map = self._load_images_threaded(tasks) - + pre_batch = [] for task in tasks: image_tensor = image_map.get(task.image_id) errors = [] @@ -315,7 +316,11 @@ def __iter__(self): } if errors: row["error"] = "; ".join(errors) if errors else None - yield row + pre_batch.append(row) + batch = rest_collate_fn( + pre_batch + ) # Collate before yielding to GPU process + yield batch logger.debug(f"Worker {worker_id}: Iterator finished") except Exception as e: @@ -360,7 +365,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: # Collate successful items if successful: result = { - "images": [item["image"] for item in successful], + "images": torch.stack([item["image"] for item in successful]), "reply_subjects": [item["reply_subject"] for item in successful], "image_ids": [item["image_id"] for item in successful], "image_urls": [item.get("image_url") for item in successful], @@ -377,6 +382,17 @@ def rest_collate_fn(batch: list[dict]) -> dict: return result +def _no_op_collate_fn(batch: list[dict]) -> dict: + """ + A no-op collate function that unwraps a single-element batch. + + This can be used when the dataset already returns batches in the desired format, + and no further collation is needed. It simply returns the input list of dicts + without modification. + """ + return batch[0] + + def get_rest_dataloader( job_id: int, settings: "Settings", @@ -395,8 +411,7 @@ def get_rest_dataloader( job_id: Job ID to fetch tasks for settings: Settings object. Relevant fields: - antenna_api_base_url / antenna_api_auth_token - - antenna_api_batch_size (tasks per API call) - - localization_batch_size (images per GPU batch) + - antenna_api_batch_size (tasks per API call and GPU batch size) - num_workers (DataLoader subprocesses) - processing_service_name (name of this worker) """ @@ -410,7 +425,47 @@ def get_rest_dataloader( return torch.utils.data.DataLoader( dataset, - batch_size=settings.localization_batch_size, + batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here num_workers=settings.num_workers, - collate_fn=rest_collate_fn, + collate_fn=_no_op_collate_fn, + pin_memory=True, + persistent_workers=settings.num_workers > 0, + prefetch_factor=4 if settings.num_workers > 0 else None, ) + + +class CUDAPrefetcher: + def __init__(self, loader: torch.utils.data.DataLoader, device: torch.device): + self.loader = iter(loader) + self.stream = torch.cuda.Stream() + self.device = device + self.next_batch = None + self._preload() + + def _preload(self): + try: + batch = next(self.loader) + except StopIteration: + self.next_batch = None + return + + with torch.cuda.stream(self.stream): + self.next_batch = { + k: ( + v.to(self.device, non_blocking=True) + if isinstance(v, torch.Tensor) + else v + ) + for k, v in batch.items() + } + + def __iter__(self): + return self + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.next_batch + if batch is None: + raise StopIteration + self._preload() + return batch diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 1cf920e..7fda2d2 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -191,9 +191,11 @@ def test_multiple_batches(self): dataset = self._make_dataset(job_id=4, batch_size=2) rows = list(dataset) - # Should get all 3 images (batch1: 2 images, batch2: 1 image) - assert len(rows) == 3 - assert all(r["image"] is not None for r in rows) + # Dataset now yields pre-collated batches: batch1 (2 images), batch2 (1 image) + assert len(rows) == 2 + total_images = sum(len(r["image_ids"]) for r in rows) + assert total_images == 3 + assert all(r["images"] is not None for r in rows) # --------------------------------------------------------------------------- @@ -272,6 +274,7 @@ def test_empty_queue(self): 100, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is False @@ -300,6 +303,7 @@ def test_processes_batch_with_real_inference(self): 101, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) # Validate processing succeeded @@ -339,6 +343,7 @@ def test_handles_failed_items(self): 102, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) posted_results = antenna_api_server.get_posted_results(102) @@ -375,6 +380,7 @@ def test_mixed_batch_success_and_failures(self): 103, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is True @@ -475,7 +481,11 @@ def test_full_workflow_with_real_inference(self): # Step 3: Process job result = _process_job( - pipeline_slug, 200, self._make_settings(), "Test Worker" + pipeline_slug, + 200, + self._make_settings(), + "Test Worker", + device=torch.device("cpu"), ) assert result is True @@ -527,6 +537,7 @@ def test_multiple_batches_processed(self): 201, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is True diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 68b1c81..343f9d8 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -6,10 +6,9 @@ import numpy as np import torch import torch.multiprocessing as mp -import torchvision -from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results -from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.client import get_full_service_name, get_jobs +from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader from trapdata.antenna.result_posting import ResultPoster from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections @@ -77,7 +76,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): pipelines: List of pipeline slugs to poll for jobs. """ settings = read_settings() - + device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available() and torch.cuda.device_count() > 0: torch.cuda.set_device(gpu_id) logger.info( @@ -112,6 +111,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): job_id=job_id, settings=settings, processing_service_name=full_service_name, + device=device, ) any_jobs = any_jobs or any_work_done except Exception as e: @@ -150,7 +150,6 @@ def _apply_binary_classification( # Process binary classification crops binary_crops = [] binary_valid_indices = [] - to_pil = torchvision.transforms.ToPILImage() binary_transforms = binary_filter.get_transforms() for idx, dresp in enumerate(detector_results): @@ -165,8 +164,7 @@ def _apply_binary_classification( ) continue crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = binary_transforms(crop_pil) + crop_transformed = binary_transforms(crop) binary_crops.append(crop_transformed) binary_valid_indices.append(idx) @@ -207,6 +205,7 @@ def _process_job( job_id: int, settings: Settings, processing_service_name: str, + device: torch.device, ) -> bool: """Run the worker to process images from the REST API queue. @@ -215,6 +214,7 @@ def _process_job( job_id: Job ID to process settings: Settings object with antenna_api_* configuration processing_service_name: Name of the processing service + device: The device to use for processing Returns: True if any work was done, False otherwise """ @@ -242,8 +242,17 @@ def _process_job( all_detections = [] _, t = log_time() result_poster: ResultPoster | None = None + # Conditionally use CUDA prefetcher; fall back to plain iterator on CPU + if torch.cuda.is_available(): + batch_source = CUDAPrefetcher( + loader, device + ) # __init__ already calls preload() + else: + batch_source = iter(loader) + + _, t_total = log_time() try: - for i, batch in enumerate(loader): + for i, batch in enumerate(batch_source): cls_time = 0.0 det_time = 0.0 load_time, t = t() @@ -300,7 +309,10 @@ def _process_job( # output is dict of "boxes", "labels", "scores" batch_output = [] + to_gpu_time = 0.0 if len(images) > 0: + images = images.to(detector.device) + to_gpu_time, t = t() batch_output = detector.predict_batch(images) items += len(batch_output) @@ -330,13 +342,14 @@ def _process_job( if use_binary_filter: assert binary_filter is not None, "Binary filter not initialized" - detections_for_terminal_classifier, detections_to_return = ( - _apply_binary_classification( - binary_filter, - detector_results, - image_tensors, - image_detections, - ) + ( + detections_for_terminal_classifier, + detections_to_return, + ) = _apply_binary_classification( + binary_filter, + detector_results, + image_tensors, + image_detections, ) else: # No binary filtering, send all detections to terminal classifier @@ -345,7 +358,6 @@ def _process_job( # Run terminal classifier on filtered detections classifier.reset(detections_for_terminal_classifier) - to_pil = torchvision.transforms.ToPILImage() classify_transforms = classifier.get_transforms() # Collect and transform all crops for batched classification @@ -363,8 +375,7 @@ def _process_job( ) continue crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = classify_transforms(crop_pil) + crop_transformed = classify_transforms(crop) crops.append(crop_transformed) valid_indices.append(idx) @@ -417,9 +428,7 @@ def _process_job( ) ) except Exception as e: - logger.error( - f"Batch {i + 1} failed during processing: {e}", exc_info=True - ) + logger.error(f"Batch {i + 1} failed during processing: {e}") # Report errors back to Antenna so tasks aren't stuck in the queue batch_results = [] for reply_subject, image_id in zip( @@ -456,10 +465,15 @@ def _process_job( batch_results, processing_service_name, ) - _, t = log_time() # reset time to measure batch load time + batch_total, t_total = t_total() logger.info( - f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" + f"Total: {batch_total/max(len(images), 1):.2f}s/image, Classification time: {cls_time:.2f}s, " + f"Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s, to GPU time: {to_gpu_time:.2f}s, " ) + ( + _, + t, + ) = log_time() # reset before next() call to measure next batch's load time if result_poster: # Wait for all async posts to complete before finishing the job diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 21459d1..9ced939 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -10,6 +10,25 @@ from .base import InferenceBaseClass, imagenet_normalization +class MaybeTensor: + """Convert PIL Image to tensor if the input is not already a tensor. + + Allows classification transforms to be used with both PIL images + (e.g. ``ami api`` / database paths) and tensors (e.g. GPU pipeline in + antenna/worker.py) without a redundant GPU->CPU->GPU round-trip. + """ + + _to_tensor = torchvision.transforms.ToTensor() + + def __call__(self, x): + if isinstance(x, torch.Tensor): + return x + return self._to_tensor(x) + + def __repr__(self): + return f"{self.__class__.__name__}()" + + class ClassificationIterableDatabaseDataset(torch.utils.data.IterableDataset): def __init__(self, queue, image_transforms, batch_size=4): super().__init__() @@ -70,7 +89,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + MaybeTensor(), self.normalization, ] ) @@ -150,7 +169,7 @@ def get_transforms(self): [ # self._pad_to_square(), torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -189,7 +208,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -237,7 +256,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ self._pad_to_square(), - torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Resize( (self.input_size, self.input_size), antialias=True # type: ignore ), diff --git a/trapdata/settings.py b/trapdata/settings.py index d37c34d..b07e043 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -41,7 +41,7 @@ class Settings(BaseSettings): antenna_api_base_url: str = "http://localhost:8000/api/v2" antenna_api_auth_token: str = "" antenna_service_name: str = "AMI Data Companion" - antenna_api_batch_size: int = 16 + antenna_api_batch_size: int = 24 @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v):