From a5325ad59e9146d41db98051080039044d6be05a Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Mon, 2 Mar 2026 14:34:21 -0800 Subject: [PATCH 1/8] WIP: GPU utilization fixes --- trapdata/antenna/datasets.py | 61 +++++++++++++++++++++++++--- trapdata/antenna/worker.py | 37 +++++++++++------ trapdata/ml/models/classification.py | 8 ++-- trapdata/settings.py | 4 +- 4 files changed, 87 insertions(+), 23 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 2d21dec0..715ae39f 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -293,7 +293,7 @@ def __iter__(self): # Download all images concurrently image_map = self._load_images_threaded(tasks) - + pre_batch = [] for task in tasks: image_tensor = image_map.get(task.image_id) errors = [] @@ -315,7 +315,11 @@ def __iter__(self): } if errors: row["error"] = "; ".join(errors) if errors else None - yield row + pre_batch.append(row) + batch = rest_collate_fn( + pre_batch + ) # Collate before yielding to GPU process + yield batch logger.debug(f"Worker {worker_id}: Iterator finished") except Exception as e: @@ -360,7 +364,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: # Collate successful items if successful: result = { - "images": [item["image"] for item in successful], + "images": torch.stack([item["image"] for item in successful]), "reply_subjects": [item["reply_subject"] for item in successful], "image_ids": [item["image_id"] for item in successful], "image_urls": [item.get("image_url") for item in successful], @@ -377,6 +381,17 @@ def rest_collate_fn(batch: list[dict]) -> dict: return result +def _no_op_collate_fn(batch: list[dict]) -> dict: + """ + A no-op collate function that returns the batch as-is. + + This can be used when the dataset already returns batches in the desired format, + and no further collation is needed. It simply returns the input list of dicts + without modification. + """ + return batch[0] + + def get_rest_dataloader( job_id: int, settings: "Settings", @@ -410,7 +425,43 @@ def get_rest_dataloader( return torch.utils.data.DataLoader( dataset, - batch_size=settings.localization_batch_size, + # batch_size=settings.localization_batch_size, + batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here num_workers=settings.num_workers, - collate_fn=rest_collate_fn, + collate_fn=_no_op_collate_fn, + pin_memory=True, + persistent_workers=True if settings.num_workers > 0 else False, + prefetch_factor=4, ) + + +class CUDAPrefetcher: + def __init__(self, loader): + self.loader = iter(loader) + self.stream = torch.cuda.Stream() + self.next_batch = None + self.preload() + + def preload(self): + try: + batch = next(self.loader) + except StopIteration: + self.next_batch = None + return + + with torch.cuda.stream(self.stream): + self.next_batch = { + k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } + + def __iter__(self): + return self + + def __next__(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.next_batch + if batch is None: + raise StopIteration + self.preload() + return batch diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 68b1c813..6473f1c7 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -9,7 +9,7 @@ import torchvision from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results -from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader from trapdata.antenna.result_posting import ResultPoster from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections @@ -150,7 +150,7 @@ def _apply_binary_classification( # Process binary classification crops binary_crops = [] binary_valid_indices = [] - to_pil = torchvision.transforms.ToPILImage() + # to_pil = torchvision.transforms.ToPILImage() binary_transforms = binary_filter.get_transforms() for idx, dresp in enumerate(detector_results): @@ -165,8 +165,9 @@ def _apply_binary_classification( ) continue crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = binary_transforms(crop_pil) + # crop_pil = to_pil(crop) + # crop_transformed = binary_transforms(crop_pil) + crop_transformed = binary_transforms(crop) binary_crops.append(crop_transformed) binary_valid_indices.append(idx) @@ -242,8 +243,13 @@ def _process_job( all_detections = [] _, t = log_time() result_poster: ResultPoster | None = None + prefetcher = CUDAPrefetcher(loader) # if torch.cuda.is_available() else None try: - for i, batch in enumerate(loader): + prefetcher.preload() + i, batch = 0, next(prefetcher) + _, t_total = log_time() # reset total time for this batch + # for i, batch in enumerate(loader): + while batch is not None: cls_time = 0.0 det_time = 0.0 load_time, t = t() @@ -300,7 +306,10 @@ def _process_job( # output is dict of "boxes", "labels", "scores" batch_output = [] + to_gpu_time = 0.0 if len(images) > 0: + images = images.to(detector.device) + to_gpu_time, t = t() batch_output = detector.predict_batch(images) items += len(batch_output) @@ -345,7 +354,7 @@ def _process_job( # Run terminal classifier on filtered detections classifier.reset(detections_for_terminal_classifier) - to_pil = torchvision.transforms.ToPILImage() + # to_pil = torchvision.transforms.ToPILImage() classify_transforms = classifier.get_transforms() # Collect and transform all crops for batched classification @@ -363,8 +372,9 @@ def _process_job( ) continue crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = classify_transforms(crop_pil) + # crop_pil = to_pil(crop) + # crop_transformed = classify_transforms(crop_pil) + crop_transformed = classify_transforms(crop) crops.append(crop_transformed) valid_indices.append(idx) @@ -417,9 +427,7 @@ def _process_job( ) ) except Exception as e: - logger.error( - f"Batch {i + 1} failed during processing: {e}", exc_info=True - ) + logger.error(f"Batch {i + 1} failed during processing: {e}") # Report errors back to Antenna so tasks aren't stuck in the queue batch_results = [] for reply_subject, image_id in zip( @@ -457,9 +465,12 @@ def _process_job( processing_service_name, ) _, t = log_time() # reset time to measure batch load time + batch_total, t_total = t_total() logger.info( - f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" + f"Total: {batch_total/(len(images)):.2f}s/image, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s, to GPU time: {to_gpu_time:.2f}s, " ) + batch = next(prefetcher) + i += 1 if result_poster: # Wait for all async posts to complete before finishing the job @@ -479,6 +490,8 @@ def _process_job( f"max queue size: {post_metrics.max_queue_size})" ) return did_work + except StopIteration: + pass finally: if result_poster: result_poster.shutdown() diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 21459d1e..3aa96cb3 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -70,7 +70,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + # torchvision.transforms.ToTensor(), self.normalization, ] ) @@ -150,7 +150,7 @@ def get_transforms(self): [ # self._pad_to_square(), torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + # torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -189,7 +189,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - torchvision.transforms.ToTensor(), + # torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -237,7 +237,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ self._pad_to_square(), - torchvision.transforms.ToTensor(), + # torchvision.transforms.ToTensor(), torchvision.transforms.Resize( (self.input_size, self.input_size), antialias=True # type: ignore ), diff --git a/trapdata/settings.py b/trapdata/settings.py index d37c34d3..2c17c19b 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -33,7 +33,7 @@ class Settings(BaseSettings): default=ml.models.DEFAULT_FEATURE_EXTRACTOR ) classification_threshold: float = 0.6 - localization_batch_size: int = 8 + localization_batch_size: int = 32 classification_batch_size: int = 20 num_workers: int = 4 @@ -41,7 +41,7 @@ class Settings(BaseSettings): antenna_api_base_url: str = "http://localhost:8000/api/v2" antenna_api_auth_token: str = "" antenna_service_name: str = "AMI Data Companion" - antenna_api_batch_size: int = 16 + antenna_api_batch_size: int = 24 @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v): From cf86915303baf5d3545411cdcb660605ef4df956 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 4 Mar 2026 11:17:36 -0800 Subject: [PATCH 2/8] Add skip-args support --- trapdata/antenna/benchmark.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py index 898b9ce6..483391ee 100644 --- a/trapdata/antenna/benchmark.py +++ b/trapdata/antenna/benchmark.py @@ -54,6 +54,7 @@ def run_benchmark( batch_size: int, gpu_batch_size: int, service_name: str, + send_acks: bool = True, ) -> None: """Run the benchmark with the specified parameters. @@ -132,8 +133,8 @@ def run_benchmark( ack_result = create_empty_result(reply_subject, image_id) ack_results.append(ack_result) - logger.info(f"Sending {len(ack_results)} acknowledgment(s)") - if ack_results: + if ack_results and send_acks: + logger.info(f"Sending {len(ack_results)} acknowledgment(s)") # Send acknowledgments asynchronously result_poster.post_async( base_url=base_url, @@ -157,7 +158,7 @@ def run_benchmark( ) error_results.append(error_result) - if error_results: + if error_results and send_acks: result_poster.post_async( base_url=base_url, auth_token=auth_token, @@ -280,6 +281,11 @@ def main() -> int: default="Performance Test", help="Processing service name", ) + parser.add_argument( + "--skip-acks", + action="store_false", + help="Skip sending acknowledgments for processed images", + ) args = parser.parse_args() @@ -298,6 +304,7 @@ def main() -> int: batch_size=args.batch_size, gpu_batch_size=args.gpu_batch_size, service_name=args.service_name, + send_acks=args.skip_acks, ) return 0 From cbd5b672dfa2ce50fc4cd87242a6cd1c6241ac43 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Mar 2026 14:32:36 -0800 Subject: [PATCH 3/8] Implemenrt MaybeTensor --- trapdata/ml/models/classification.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index 3aa96cb3..9ced9395 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -10,6 +10,25 @@ from .base import InferenceBaseClass, imagenet_normalization +class MaybeTensor: + """Convert PIL Image to tensor if the input is not already a tensor. + + Allows classification transforms to be used with both PIL images + (e.g. ``ami api`` / database paths) and tensors (e.g. GPU pipeline in + antenna/worker.py) without a redundant GPU->CPU->GPU round-trip. + """ + + _to_tensor = torchvision.transforms.ToTensor() + + def __call__(self, x): + if isinstance(x, torch.Tensor): + return x + return self._to_tensor(x) + + def __repr__(self): + return f"{self.__class__.__name__}()" + + class ClassificationIterableDatabaseDataset(torch.utils.data.IterableDataset): def __init__(self, queue, image_transforms, batch_size=4): super().__init__() @@ -70,7 +89,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - # torchvision.transforms.ToTensor(), + MaybeTensor(), self.normalization, ] ) @@ -150,7 +169,7 @@ def get_transforms(self): [ # self._pad_to_square(), torchvision.transforms.Resize((self.input_size, self.input_size)), - # torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -189,7 +208,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), - # torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Normalize(mean, std), ] ) @@ -237,7 +256,7 @@ def get_transforms(self): return torchvision.transforms.Compose( [ self._pad_to_square(), - # torchvision.transforms.ToTensor(), + MaybeTensor(), torchvision.transforms.Resize( (self.input_size, self.input_size), antialias=True # type: ignore ), From 8e8b3cdbd6a71bd32cb32388ca820c0b677e736b Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Mar 2026 14:37:32 -0800 Subject: [PATCH 4/8] Update comments --- trapdata/antenna/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 715ae39f..f2084b34 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -254,7 +254,7 @@ def __iter__(self): Each API fetch returns a batch of tasks. Images for the entire batch are downloaded concurrently using threads (see _load_images_threaded), - then yielded one at a time for the DataLoader to collate. + then as a pre-collated batch. Yields: Dictionary containing: @@ -383,7 +383,7 @@ def rest_collate_fn(batch: list[dict]) -> dict: def _no_op_collate_fn(batch: list[dict]) -> dict: """ - A no-op collate function that returns the batch as-is. + A no-op collate function that unwraps a single-element batch. This can be used when the dataset already returns batches in the desired format, and no further collation is needed. It simply returns the input list of dicts From 57dd5bcfa268873209ba954daaf9176f29260a83 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Mar 2026 14:50:25 -0800 Subject: [PATCH 5/8] Make this work on CPU too --- trapdata/antenna/datasets.py | 10 ++++---- trapdata/antenna/worker.py | 50 +++++++++++++++++------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index f2084b34..94827e3f 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -430,8 +430,8 @@ def get_rest_dataloader( num_workers=settings.num_workers, collate_fn=_no_op_collate_fn, pin_memory=True, - persistent_workers=True if settings.num_workers > 0 else False, - prefetch_factor=4, + persistent_workers=settings.num_workers > 0, + **({"prefetch_factor": 4} if settings.num_workers > 0 else {}), ) @@ -440,9 +440,9 @@ def __init__(self, loader): self.loader = iter(loader) self.stream = torch.cuda.Stream() self.next_batch = None - self.preload() + self._preload() - def preload(self): + def _preload(self): try: batch = next(self.loader) except StopIteration: @@ -463,5 +463,5 @@ def __next__(self): batch = self.next_batch if batch is None: raise StopIteration - self.preload() + self._preload() return batch diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 6473f1c7..d275007a 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -6,9 +6,8 @@ import numpy as np import torch import torch.multiprocessing as mp -import torchvision -from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results +from trapdata.antenna.client import get_full_service_name, get_jobs from trapdata.antenna.datasets import CUDAPrefetcher, get_rest_dataloader from trapdata.antenna.result_posting import ResultPoster from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError @@ -150,7 +149,6 @@ def _apply_binary_classification( # Process binary classification crops binary_crops = [] binary_valid_indices = [] - # to_pil = torchvision.transforms.ToPILImage() binary_transforms = binary_filter.get_transforms() for idx, dresp in enumerate(detector_results): @@ -165,8 +163,6 @@ def _apply_binary_classification( ) continue crop = image_tensor[:, y1:y2, x1:x2] - # crop_pil = to_pil(crop) - # crop_transformed = binary_transforms(crop_pil) crop_transformed = binary_transforms(crop) binary_crops.append(crop_transformed) binary_valid_indices.append(idx) @@ -243,13 +239,15 @@ def _process_job( all_detections = [] _, t = log_time() result_poster: ResultPoster | None = None - prefetcher = CUDAPrefetcher(loader) # if torch.cuda.is_available() else None + # Conditionally use CUDA prefetcher; fall back to plain iterator on CPU + if torch.cuda.is_available(): + batch_source = CUDAPrefetcher(loader) # __init__ already calls preload() + else: + batch_source = iter(loader) + + _, t_total = log_time() try: - prefetcher.preload() - i, batch = 0, next(prefetcher) - _, t_total = log_time() # reset total time for this batch - # for i, batch in enumerate(loader): - while batch is not None: + for i, batch in enumerate(batch_source): cls_time = 0.0 det_time = 0.0 load_time, t = t() @@ -339,13 +337,14 @@ def _process_job( if use_binary_filter: assert binary_filter is not None, "Binary filter not initialized" - detections_for_terminal_classifier, detections_to_return = ( - _apply_binary_classification( - binary_filter, - detector_results, - image_tensors, - image_detections, - ) + ( + detections_for_terminal_classifier, + detections_to_return, + ) = _apply_binary_classification( + binary_filter, + detector_results, + image_tensors, + image_detections, ) else: # No binary filtering, send all detections to terminal classifier @@ -354,7 +353,6 @@ def _process_job( # Run terminal classifier on filtered detections classifier.reset(detections_for_terminal_classifier) - # to_pil = torchvision.transforms.ToPILImage() classify_transforms = classifier.get_transforms() # Collect and transform all crops for batched classification @@ -372,8 +370,6 @@ def _process_job( ) continue crop = image_tensor[:, y1:y2, x1:x2] - # crop_pil = to_pil(crop) - # crop_transformed = classify_transforms(crop_pil) crop_transformed = classify_transforms(crop) crops.append(crop_transformed) valid_indices.append(idx) @@ -464,13 +460,15 @@ def _process_job( batch_results, processing_service_name, ) - _, t = log_time() # reset time to measure batch load time batch_total, t_total = t_total() logger.info( - f"Total: {batch_total/(len(images)):.2f}s/image, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s, to GPU time: {to_gpu_time:.2f}s, " + f"Total: {batch_total/max(len(images), 1):.2f}s/image, Classification time: {cls_time:.2f}s, " + f"Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s, to GPU time: {to_gpu_time:.2f}s, " ) - batch = next(prefetcher) - i += 1 + ( + _, + t, + ) = log_time() # reset before next() call to measure next batch's load time if result_poster: # Wait for all async posts to complete before finishing the job @@ -490,8 +488,6 @@ def _process_job( f"max queue size: {post_metrics.max_queue_size})" ) return did_work - except StopIteration: - pass finally: if result_poster: result_poster.shutdown() From 7f4b283989082164943d709bed1803d6d183a1a1 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Mar 2026 14:55:12 -0800 Subject: [PATCH 6/8] Use explicit device --- trapdata/antenna/datasets.py | 9 +++++++-- trapdata/antenna/worker.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 94827e3f..5bce1d0a 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -436,9 +436,10 @@ def get_rest_dataloader( class CUDAPrefetcher: - def __init__(self, loader): + def __init__(self, loader: torch.utils.data.DataLoader, device: torch.device): self.loader = iter(loader) self.stream = torch.cuda.Stream() + self.device = device self.next_batch = None self._preload() @@ -451,7 +452,11 @@ def _preload(self): with torch.cuda.stream(self.stream): self.next_batch = { - k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v + k: ( + v.to(self.device, non_blocking=True) + if isinstance(v, torch.Tensor) + else v + ) for k, v in batch.items() } diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index d275007a..343f9d88 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -76,7 +76,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): pipelines: List of pipeline slugs to poll for jobs. """ settings = read_settings() - + device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available() and torch.cuda.device_count() > 0: torch.cuda.set_device(gpu_id) logger.info( @@ -111,6 +111,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): job_id=job_id, settings=settings, processing_service_name=full_service_name, + device=device, ) any_jobs = any_jobs or any_work_done except Exception as e: @@ -204,6 +205,7 @@ def _process_job( job_id: int, settings: Settings, processing_service_name: str, + device: torch.device, ) -> bool: """Run the worker to process images from the REST API queue. @@ -212,6 +214,7 @@ def _process_job( job_id: Job ID to process settings: Settings object with antenna_api_* configuration processing_service_name: Name of the processing service + device: The device to use for processing Returns: True if any work was done, False otherwise """ @@ -241,7 +244,9 @@ def _process_job( result_poster: ResultPoster | None = None # Conditionally use CUDA prefetcher; fall back to plain iterator on CPU if torch.cuda.is_available(): - batch_source = CUDAPrefetcher(loader) # __init__ already calls preload() + batch_source = CUDAPrefetcher( + loader, device + ) # __init__ already calls preload() else: batch_source = iter(loader) From ad99763fd29e04dfc4742d772b3e7923f961eaaf Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 6 Mar 2026 15:39:57 -0800 Subject: [PATCH 7/8] Update worker test --- trapdata/antenna/tests/test_worker.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 1cf920ef..7fda2d2b 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -191,9 +191,11 @@ def test_multiple_batches(self): dataset = self._make_dataset(job_id=4, batch_size=2) rows = list(dataset) - # Should get all 3 images (batch1: 2 images, batch2: 1 image) - assert len(rows) == 3 - assert all(r["image"] is not None for r in rows) + # Dataset now yields pre-collated batches: batch1 (2 images), batch2 (1 image) + assert len(rows) == 2 + total_images = sum(len(r["image_ids"]) for r in rows) + assert total_images == 3 + assert all(r["images"] is not None for r in rows) # --------------------------------------------------------------------------- @@ -272,6 +274,7 @@ def test_empty_queue(self): 100, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is False @@ -300,6 +303,7 @@ def test_processes_batch_with_real_inference(self): 101, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) # Validate processing succeeded @@ -339,6 +343,7 @@ def test_handles_failed_items(self): 102, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) posted_results = antenna_api_server.get_posted_results(102) @@ -375,6 +380,7 @@ def test_mixed_batch_success_and_failures(self): 103, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is True @@ -475,7 +481,11 @@ def test_full_workflow_with_real_inference(self): # Step 3: Process job result = _process_job( - pipeline_slug, 200, self._make_settings(), "Test Worker" + pipeline_slug, + 200, + self._make_settings(), + "Test Worker", + device=torch.device("cpu"), ) assert result is True @@ -527,6 +537,7 @@ def test_multiple_batches_processed(self): 201, self._make_settings(), "Test Service", + device=torch.device("cpu"), ) assert result is True From e972ddbd2f1dda2fd5442cea99811c0e2eca294b Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Tue, 10 Mar 2026 10:13:10 -0700 Subject: [PATCH 8/8] Cleanup --- trapdata/antenna/datasets.py | 7 +++---- trapdata/settings.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 5bce1d0a..5450592a 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -36,6 +36,7 @@ localization_batch_size (default 8) How many images the GPU processes at once (detection). Larger = more GPU memory. These are full-resolution images (~4K). + Async worker use antennna_api_batch_size for this. num_workers (default 4) DataLoader subprocesses. Each independently fetches tasks and @@ -410,8 +411,7 @@ def get_rest_dataloader( job_id: Job ID to fetch tasks for settings: Settings object. Relevant fields: - antenna_api_base_url / antenna_api_auth_token - - antenna_api_batch_size (tasks per API call) - - localization_batch_size (images per GPU batch) + - antenna_api_batch_size (tasks per API call and GPU batch size) - num_workers (DataLoader subprocesses) - processing_service_name (name of this worker) """ @@ -425,13 +425,12 @@ def get_rest_dataloader( return torch.utils.data.DataLoader( dataset, - # batch_size=settings.localization_batch_size, batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here num_workers=settings.num_workers, collate_fn=_no_op_collate_fn, pin_memory=True, persistent_workers=settings.num_workers > 0, - **({"prefetch_factor": 4} if settings.num_workers > 0 else {}), + prefetch_factor=4 if settings.num_workers > 0 else None, ) diff --git a/trapdata/settings.py b/trapdata/settings.py index 2c17c19b..b07e0439 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -33,7 +33,7 @@ class Settings(BaseSettings): default=ml.models.DEFAULT_FEATURE_EXTRACTOR ) classification_threshold: float = 0.6 - localization_batch_size: int = 32 + localization_batch_size: int = 8 classification_batch_size: int = 20 num_workers: int = 4