From 91717de1003af42d4f3dfb5422633c25ef92b382 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 20 Feb 2026 10:52:07 -0800 Subject: [PATCH 1/7] PSv2: Async result posting, benchmarking --- trapdata/antenna/benchmark.py | 303 +++++++++++++++++++++++++++++ trapdata/antenna/client.py | 2 +- trapdata/antenna/datasets.py | 4 +- trapdata/antenna/result_posting.py | 239 +++++++++++++++++++++++ trapdata/antenna/worker.py | 52 +++-- trapdata/settings.py | 2 +- 6 files changed, 578 insertions(+), 24 deletions(-) create mode 100644 trapdata/antenna/benchmark.py create mode 100644 trapdata/antenna/result_posting.py diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py new file mode 100644 index 0000000..51bb16c --- /dev/null +++ b/trapdata/antenna/benchmark.py @@ -0,0 +1,303 @@ +"""Benchmarking utilities for Antenna API data loading and result posting. + +This module provides a command-line benchmark tool for testing the performance +of the Antenna API data loading pipeline with asynchronous result posting. +The benchmark fetches batches from the API, processes acknowledgments, and +provides detailed performance metrics. + +Usage: + python -m trapdata.antenna.benchmark --job-id 123 --base-url http://localhost:8000/api/v2 + +Key metrics tracked: +- Images per second (total and successful) +- Batch processing rate +- Acknowledgment posting rate +- Result posting success/failure rates +- Queue utilization metrics +""" + +import argparse +import os +import time + +from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.result_posting import ResultPoster +from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError +from trapdata.common.logs import logger +from trapdata.common.utils import log_time +from trapdata.settings import Settings + + +def create_empty_result(reply_subject: str, image_id: str) -> AntennaTaskResult: + """Create an empty/acknowledgment result for a task. + + Args: + reply_subject: Subject for the reply + image_id: ID of the image being acknowledged + + Returns: + AntennaTaskResult with error acknowledgment + """ + result = AntennaTaskResultError( + error=f"Acknowledgment for image {image_id}", + image_id=image_id, + ) + return AntennaTaskResult(reply_subject=reply_subject, result=result) + + +def run_benchmark( + job_id: int, + base_url: str, + auth_token: str, + num_workers: int, + batch_size: int, + gpu_batch_size: int, + service_name: str, +) -> None: + """Run the benchmark with the specified parameters. + + Args: + job_id: Job ID to process + base_url: Antenna API base URL + auth_token: API authentication token + num_workers: Number of DataLoader workers + batch_size: Batch size for API requests + gpu_batch_size: GPU batch size for DataLoader + service_name: Processing service name + """ + # Create settings object + settings = Settings() + settings.antenna_api_base_url = base_url + settings.antenna_api_auth_token = auth_token + settings.antenna_api_batch_size = batch_size + settings.localization_batch_size = gpu_batch_size + settings.num_workers = num_workers + + print(f"Starting performance test for job {job_id}") + print(f"Configuration:") + print(f" Base URL: {base_url}") + print(f" API batch size: {batch_size}") + print(f" GPU batch size: {gpu_batch_size}") + print(f" Num workers: {num_workers}") + print(f" Service name: {service_name}") + print() + + # Create dataloader + dataloader = get_rest_dataloader( + job_id=job_id, + settings=settings, + processing_service_name=service_name, + ) + + # Initialize ResultPoster for sending acknowledgments + result_poster = ResultPoster(max_pending=10) + + # Performance metrics + total_batches = 0 + total_images = 0 + total_successful_images = 0 + total_failed_images = 0 + total_acks_sent = 0 + start_time = time.time() + last_report_time = start_time + report_interval = 10 # Report every 10 seconds + + print("Starting data consumption with acknowledgments...") + try: + _, t = log_time() + for batch_idx, batch in enumerate(dataloader): + _, t = t( + f"Fetched batch {batch_idx} with {len(batch['reply_subjects'])} items" + ) + current_time = time.time() + total_batches += 1 + + # Count images in this batch + batch_size = len(batch["reply_subjects"]) + batch_failed = len(batch["failed_items"]) + batch_successful = batch_size - batch_failed + + total_images += batch_size + total_successful_images += batch_successful + total_failed_images += batch_failed + + # Send acknowledgments for successful items + if batch_successful > 0: + ack_results = [] + for i, (reply_subject, image_id) in enumerate( + zip(batch["reply_subjects"], batch["image_ids"]) + ): + if i < batch_successful: # Only for successful items + ack_result = create_empty_result(reply_subject, image_id) + ack_results.append(ack_result) + + logger.info(f"Sending {len(ack_results)} acknowledgment(s)") + if ack_results: + # Send acknowledgments asynchronously + result_poster.post_async( + base_url=base_url, + auth_token=auth_token, + job_id=job_id, + results=ack_results, + processing_service_name=service_name, + ) + total_acks_sent += len(ack_results) + + # Send error results for failed items + if batch_failed > 0: + error_results = [] + for failed_item in batch["failed_items"]: + error_result = AntennaTaskResult( + reply_subject=failed_item["reply_subject"], + result=AntennaTaskResultError( + error=failed_item.get("error", "Image loading failed"), + image_id=failed_item["image_id"], + ), + ) + error_results.append(error_result) + + if error_results: + result_poster.post_async( + base_url=base_url, + auth_token=auth_token, + job_id=job_id, + results=error_results, + processing_service_name=service_name, + ) + total_acks_sent += len(error_results) + + # Report progress periodically + if current_time - last_report_time >= report_interval: + elapsed = current_time - start_time + images_per_sec = total_images / elapsed if elapsed > 0 else 0 + successful_per_sec = ( + total_successful_images / elapsed if elapsed > 0 else 0 + ) + acks_per_sec = total_acks_sent / elapsed if elapsed > 0 else 0 + + # Get ResultPoster metrics + post_metrics = result_poster.get_metrics() + + print( + f"Progress: {total_batches} batches, {total_images} images " + f"({total_successful_images} success, {total_failed_images} failed) " + f"- {images_per_sec:.1f} img/s, {successful_per_sec:.1f} success/s, " + f"{acks_per_sec:.1f} acks/s" + ) + print( + f" Posts: {post_metrics.successful_posts} success, " + f"{post_metrics.failed_posts} failed, " + f"{post_metrics.success_rate:.1f}% success rate" + ) + last_report_time = current_time + _, t = log_time() + + except KeyboardInterrupt: + print("\nStopped by user") + except Exception as e: + print(f"\nError occurred: {e}") + logger.error(f"DataLoader benchmark error: {e}") + finally: + # Wait for all pending result posts to complete + print("Waiting for pending result posts to complete...") + result_poster.wait_for_all_posts(timeout=30) + result_poster.shutdown() + + # Final statistics + end_time = time.time() + total_elapsed = end_time - start_time + final_post_metrics = result_poster.get_metrics() + + print("\n" + "=" * 70) + print("PERFORMANCE SUMMARY") + print("=" * 70) + print(f"Total time: {total_elapsed:.2f} seconds") + print(f"Total batches: {total_batches}") + print(f"Total images: {total_images}") + print(f"Successful images: {total_successful_images}") + print(f"Failed images: {total_failed_images}") + print(f"Acknowledgments sent: {total_acks_sent}") + + if total_elapsed > 0: + images_per_sec = total_images / total_elapsed + successful_per_sec = total_successful_images / total_elapsed + batches_per_sec = total_batches / total_elapsed + acks_per_sec = total_acks_sent / total_elapsed + + print(f"\nThroughput:") + print(f" {images_per_sec:.2f} images/second (total)") + print(f" {successful_per_sec:.2f} images/second (successful)") + print(f" {batches_per_sec:.2f} batches/second") + print(f" {acks_per_sec:.2f} acknowledgments/second") + + if total_images > 0: + success_rate = (total_successful_images / total_images) * 100 + print(f"\nSuccess rate: {success_rate:.1f}%") + + print(f"\nResult Posting Metrics:") + print(f" Total posts: {final_post_metrics.total_posts}") + print(f" Successful posts: {final_post_metrics.successful_posts}") + print(f" Failed posts: {final_post_metrics.failed_posts}") + print(f" Post success rate: {final_post_metrics.success_rate:.1f}%") + if final_post_metrics.total_posts > 0: + avg_post_time = ( + final_post_metrics.total_post_time / final_post_metrics.total_posts + ) + print(f" Average post time: {avg_post_time:.3f} seconds") + print(f" Max queue size: {final_post_metrics.max_queue_size}") + + print("=" * 70) + print("Performance benchmark completed") + print("=" * 70) + + +def main(): + """Main entry point for the benchmark CLI.""" + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Benchmark dataloader performance with acknowledgements" + ) + parser.add_argument("--job-id", type=int, required=True, help="Job ID to process") + parser.add_argument( + "--base-url", + type=str, + default="http://localhost:8000/api/v2", + help="Antenna API base URL", + ) + parser.add_argument( + "--num-workers", type=int, default=2, help="Number of DataLoader workers" + ) + parser.add_argument( + "--batch-size", type=int, default=16, help="Batch size for API requests" + ) + parser.add_argument( + "--gpu-batch-size", type=int, default=16, help="GPU batch size for DataLoader" + ) + parser.add_argument( + "--service-name", + type=str, + default="Performance Test", + help="Processing service name", + ) + + args = parser.parse_args() + + # Get auth token from environment + auth_token = os.getenv("AMI_ANTENNA_API_AUTH_TOKEN", "") + if not auth_token: + print("Warning: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set") + + # Run the benchmark + run_benchmark( + job_id=args.job_id, + base_url=args.base_url, + auth_token=auth_token, + num_workers=args.num_workers, + batch_size=args.batch_size, + gpu_batch_size=args.gpu_batch_size, + service_name=args.service_name, + ) + + +if __name__ == "__main__": + main() diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py index bc97367..8681e21 100644 --- a/trapdata/antenna/client.py +++ b/trapdata/antenna/client.py @@ -98,7 +98,7 @@ def post_batch_results( params = {"processing_service_name": processing_service_name} response = session.post(url, json=payload, params=params, timeout=60) response.raise_for_status() - logger.info(f"Successfully posted {len(results)} results to {url}") + logger.debug(f"Successfully posted {len(results)} results to {url}") return True except requests.RequestException as e: logger.error(f"Failed to post results to {url}: {e}") diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index bb041f5..9dbdcd2 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -286,7 +286,7 @@ def __iter__(self): if not tasks: # Queue is empty - job complete - logger.info( + logger.debug( f"Worker {worker_id}: No more tasks for job {self.job_id}" ) break @@ -317,7 +317,7 @@ def __iter__(self): row["error"] = "; ".join(errors) if errors else None yield row - logger.info(f"Worker {worker_id}: Iterator finished") + logger.debug(f"Worker {worker_id}: Iterator finished") except Exception as e: logger.error(f"Worker {worker_id}: Exception in iterator: {e}") raise diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py new file mode 100644 index 0000000..5909033 --- /dev/null +++ b/trapdata/antenna/result_posting.py @@ -0,0 +1,239 @@ +"""Asynchronous result posting utilities for Antenna API. + +This module provides utilities for posting batch results to the Antenna API with +backpressure control and comprehensive metrics tracking. The main class, ResultPoster, +manages asynchronous posting to improve worker throughput by overlapping network I/O +with compute operations. + +Key features: +- Asynchronous posting using ThreadPoolExecutor +- Configurable backpressure control to prevent unbounded memory usage +- Comprehensive metrics tracking (success/failure rates, timing, queue size) +- Graceful shutdown with timeout handling +- Thread-safe operations + +Usage: + poster = ResultPoster(max_pending=5) + poster.post_async(base_url, auth_token, job_id, results, service_name) + metrics = poster.get_metrics() + poster.shutdown() +""" + +import time +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional + +from trapdata.antenna.client import post_batch_results +from trapdata.common.logs import logger + + +@dataclass +class ResultPostMetrics: + """Metrics for tracking result posting performance.""" + + total_posts: int = 0 + successful_posts: int = 0 + failed_posts: int = 0 + total_post_time: float = 0.0 + max_queue_size: int = 0 + + @property + def success_rate(self) -> float: + """Calculate success rate as a percentage.""" + return ( + (self.successful_posts / self.total_posts * 100) + if self.total_posts > 0 + else 0.0 + ) + + +class ResultPoster: + """Manages asynchronous posting of batch results with backpressure control. + + This class provides asynchronous result posting to improve throughput by allowing + the worker to continue processing while previous results are posted in background + threads. It includes backpressure control to prevent unbounded memory usage. + + Args: + max_pending: Maximum number of concurrent posts before blocking (default: 5) + + Example: + poster = ResultPoster(max_pending=10) + poster.post_async(base_url, auth_token, job_id, results, service_name) + metrics = poster.get_metrics() + poster.shutdown() + """ + + def __init__(self, max_pending: int = 5): + self.max_pending = max_pending + self.executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="result_poster" + ) + self.pending_futures: list[Future] = [] + self.metrics = ResultPostMetrics() + + def post_async( + self, + base_url: str, + auth_token: str, + job_id: int, + results: list, + processing_service_name: str, + ) -> None: + """Post results asynchronously with backpressure control. + + This method will block if there are too many pending posts to prevent + unbounded memory usage and provide backpressure. + + Args: + base_url: Antenna API base URL + auth_token: API authentication token + job_id: Job ID for the results + results: List of result objects to post + processing_service_name: Name of the processing service + """ + # Clean up completed futures and update metrics + self._cleanup_completed_futures() + + # Apply backpressure: wait for pending posts to complete if we're at the limit + while len(self.pending_futures) >= self.max_pending: + logger.debug( + f"At max pending posts ({self.max_pending}), waiting for completion..." + ) + # Wait for at least one future to complete + if self.pending_futures: + # Wait for the oldest pending post to complete + completed_future = self.pending_futures[0] + try: + completed_future.result(timeout=30) # 30 second timeout + except Exception as e: + logger.warning(f"Pending result post failed: {e}") + finally: + self._cleanup_completed_futures() + + # Update queue size metric + current_queue_size = len(self.pending_futures) + self.metrics.max_queue_size = max( + self.metrics.max_queue_size, current_queue_size + ) + + # Submit new post + start_time = time.time() + future = self.executor.submit( + self._post_with_timing, + base_url, + auth_token, + job_id, + results, + processing_service_name, + start_time, + ) + self.pending_futures.append(future) + self.metrics.total_posts += 1 + + logger.debug( + f"Submitted result post for job {job_id}, {current_queue_size + 1} pending" + ) + + def _post_with_timing( + self, + base_url: str, + auth_token: str, + job_id: int, + results: list, + processing_service_name: str, + start_time: float, + ) -> bool: + """Internal method that times the post operation and updates metrics. + + Args: + base_url: Antenna API base URL + auth_token: API authentication token + job_id: Job ID for the results + results: List of result objects to post + processing_service_name: Name of the processing service + start_time: Timestamp when the post was initiated + + Returns: + True if successful, False otherwise + """ + try: + success = post_batch_results( + base_url, auth_token, job_id, results, processing_service_name + ) + elapsed_time = time.time() - start_time + + # Update metrics (thread-safe since we're updating simple counters) + self.metrics.total_post_time += elapsed_time + if success: + self.metrics.successful_posts += 1 + else: + self.metrics.failed_posts += 1 + logger.warning( + f"Result post failed for job {job_id} after {elapsed_time:.2f}s" + ) + + return success + except Exception as e: + elapsed_time = time.time() - start_time + self.metrics.total_post_time += elapsed_time + self.metrics.failed_posts += 1 + logger.error(f"Exception during result post for job {job_id}: {e}") + return False + + def _cleanup_completed_futures(self) -> None: + """Remove completed futures from the pending list.""" + self.pending_futures = [f for f in self.pending_futures if not f.done()] + + def wait_for_all_posts(self, timeout: Optional[float] = None) -> None: + """Wait for all pending posts to complete before shutting down. + + Args: + timeout: Maximum time to wait for all posts to complete (seconds) + """ + if not self.pending_futures: + return + + logger.info( + f"Waiting for {len(self.pending_futures)} pending result posts to complete..." + ) + start_time = time.time() + + for future in self.pending_futures: + remaining_timeout = None + if timeout is not None: + elapsed = time.time() - start_time + remaining_timeout = max(0, timeout - elapsed) + if remaining_timeout == 0: + logger.warning( + "Timeout waiting for pending posts, some may be lost" + ) + break + + try: + future.result(timeout=remaining_timeout) + except Exception as e: + logger.warning(f"Pending result post failed during shutdown: {e}") + + self._cleanup_completed_futures() + + def get_metrics(self) -> ResultPostMetrics: + """Get current metrics. + + Returns: + Current ResultPostMetrics object with performance data + """ + self._cleanup_completed_futures() + return self.metrics + + def shutdown(self, wait: bool = True, timeout: Optional[float] = 30) -> None: + """Shutdown the executor and optionally wait for pending posts. + + Args: + wait: Whether to wait for pending posts to complete + timeout: Maximum time to wait for pending posts (seconds) + """ + if wait: + self.wait_for_all_posts(timeout=timeout) + self.executor.shutdown(wait=wait) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 5e95f2e..baf1c2a 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -10,6 +10,7 @@ from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.result_posting import ResultPoster from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError from trapdata.api.api import CLASSIFIER_CHOICES from trapdata.api.models.localization import APIMothDetector @@ -22,6 +23,7 @@ from trapdata.common.utils import log_time from trapdata.settings import Settings, read_settings +MAX_PENDING_POSTS = 5 # Maximum number of concurrent result posts before blocking SLEEP_TIME_SECONDS = 5 @@ -156,14 +158,17 @@ def _process_job( total_detection_time = 0.0 total_classification_time = 0.0 - total_save_time = 0.0 total_dl_time = 0.0 all_detections = [] _, t = log_time() + result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS) + for i, batch in enumerate(loader): - dt, t = t("Finished loading batch") - total_dl_time += dt + cls_time = 0.0 + det_time = 0.0 + load_time, t = t() + total_dl_time += load_time if not batch: logger.warning(f"Batch {i + 1} is empty, skipping") continue @@ -185,7 +190,6 @@ def _process_job( image_urls = batch.get("image_urls", [None] * len(images)) batch_results: list[AntennaTaskResult] = [] - try: # Validate all arrays have same length before zipping if len(image_ids) != len(images): @@ -203,14 +207,12 @@ def _process_job( # Track start time for this batch batch_start_time = datetime.datetime.now() - logger.info(f"Processing worker batch {i + 1} ({len(images)} images)") # output is dict of "boxes", "labels", "scores" batch_output = [] if len(images) > 0: batch_output = detector.predict_batch(images) items += len(batch_output) - logger.info(f"Total items processed so far: {items}") batch_output = list(detector.post_process_batch(batch_output)) # Convert image_ids to list if needed @@ -223,8 +225,8 @@ def _process_job( batch_output=batch_output, seconds_per_item=0, ) - dt, t = t("Finished detection") - total_detection_time += dt + det_time, t = t() + total_detection_time += det_time # Group detections by image_id image_detections: dict[str, list[DetectionResponse]] = { @@ -272,8 +274,8 @@ def _process_job( image_detections[dresp.source_image_id].append(detection) all_detections.append(detection) - ct, t = t("Finished classification") - total_classification_time += ct + cls_time, t = t() + total_classification_time += cls_time # Calculate batch processing time batch_end_time = datetime.datetime.now() @@ -330,25 +332,35 @@ def _process_job( ) ) - success = post_batch_results( + # Post results asynchronously (non-blocking) + result_poster.post_async( settings.antenna_api_base_url, settings.antenna_api_auth_token, job_id, batch_results, processing_service_name, ) - st, t = t("Finished posting results") + logger.info( + f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" + ) - if not success: - logger.error( - f"Failed to post {len(batch_results)} results for job {job_id} to " - f"{settings.antenna_api_base_url}. Batch processing data lost." - ) + # Wait for all async posts to complete before finishing the job + logger.info("Waiting for all pending result posts to complete...") + result_poster.wait_for_all_posts(timeout=60) # 60 second timeout for cleanup + + # Get final metrics + post_metrics = result_poster.get_metrics() - total_save_time += st + # Clean up the result poster + result_poster.shutdown(wait=False) # Already waited above logger.info( - f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time}, " - f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}" + f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time:.2f}s, " + f"classification time: {total_classification_time:.2f}s, dl time: {total_dl_time:.2f}s, " + f"result posts: {post_metrics.total_posts} " + f"(success: {post_metrics.successful_posts}, failed: {post_metrics.failed_posts}, " + f"success rate: {post_metrics.success_rate:.1f}%, avg post time: " + f"{post_metrics.total_post_time / post_metrics.total_posts if post_metrics.total_posts > 0 else 0:.2f}s, " + f"max queue size: {post_metrics.max_queue_size})" ) return did_work diff --git a/trapdata/settings.py b/trapdata/settings.py index 54324c3..d37c34d 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -35,7 +35,7 @@ class Settings(BaseSettings): classification_threshold: float = 0.6 localization_batch_size: int = 8 classification_batch_size: int = 20 - num_workers: int = 2 + num_workers: int = 4 # Antenna API worker settings antenna_api_base_url: str = "http://localhost:8000/api/v2" From ee1a8000b3ac4fa0c5ffbc5fd3cc4ec5407ea8c4 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 20 Feb 2026 11:41:27 -0800 Subject: [PATCH 2/7] CR feedback --- trapdata/antenna/benchmark.py | 12 +++--- trapdata/antenna/datasets.py | 2 +- trapdata/antenna/result_posting.py | 69 ++++++++++++++++++++---------- trapdata/antenna/worker.py | 3 +- 4 files changed, 55 insertions(+), 31 deletions(-) diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py index 51bb16c..a8df98e 100644 --- a/trapdata/antenna/benchmark.py +++ b/trapdata/antenna/benchmark.py @@ -74,7 +74,7 @@ def run_benchmark( settings.num_workers = num_workers print(f"Starting performance test for job {job_id}") - print(f"Configuration:") + print("Configuration:") print(f" Base URL: {base_url}") print(f" API batch size: {batch_size}") print(f" GPU batch size: {gpu_batch_size}") @@ -113,9 +113,9 @@ def run_benchmark( total_batches += 1 # Count images in this batch - batch_size = len(batch["reply_subjects"]) batch_failed = len(batch["failed_items"]) - batch_successful = batch_size - batch_failed + # Successful items are those with reply_subjects that are not in failed_items + batch_successful = len(batch["reply_subjects"]) total_images += batch_size total_successful_images += batch_successful @@ -200,7 +200,7 @@ def run_benchmark( finally: # Wait for all pending result posts to complete print("Waiting for pending result posts to complete...") - result_poster.wait_for_all_posts(timeout=30) + result_poster.wait_for_all_posts() result_poster.shutdown() # Final statistics @@ -224,7 +224,7 @@ def run_benchmark( batches_per_sec = total_batches / total_elapsed acks_per_sec = total_acks_sent / total_elapsed - print(f"\nThroughput:") + print("\nThroughput:") print(f" {images_per_sec:.2f} images/second (total)") print(f" {successful_per_sec:.2f} images/second (successful)") print(f" {batches_per_sec:.2f} batches/second") @@ -234,7 +234,7 @@ def run_benchmark( success_rate = (total_successful_images / total_images) * 100 print(f"\nSuccess rate: {success_rate:.1f}%") - print(f"\nResult Posting Metrics:") + print("\nResult Posting Metrics:") print(f" Total posts: {final_post_metrics.total_posts}") print(f" Successful posts: {final_post_metrics.successful_posts}") print(f" Failed posts: {final_post_metrics.failed_posts}") diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 9dbdcd2..2d21dec 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -37,7 +37,7 @@ How many images the GPU processes at once (detection). Larger = more GPU memory. These are full-resolution images (~4K). - num_workers (default 2) + num_workers (default 4) DataLoader subprocesses. Each independently fetches tasks and downloads images. More workers = more images prefetched for the GPU, at the cost of CPU/RAM. With 0 workers, fetching and diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index 5909033..c215053 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -19,6 +19,7 @@ poster.shutdown() """ +import threading import time from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass @@ -72,6 +73,7 @@ def __init__(self, max_pending: int = 5): ) self.pending_futures: list[Future] = [] self.metrics = ResultPostMetrics() + self._metrics_lock = threading.Lock() def post_async( self, @@ -164,21 +166,22 @@ def _post_with_timing( ) elapsed_time = time.time() - start_time - # Update metrics (thread-safe since we're updating simple counters) - self.metrics.total_post_time += elapsed_time - if success: - self.metrics.successful_posts += 1 - else: - self.metrics.failed_posts += 1 - logger.warning( - f"Result post failed for job {job_id} after {elapsed_time:.2f}s" - ) + with self._metrics_lock: + self.metrics.total_post_time += elapsed_time + if success: + self.metrics.successful_posts += 1 + else: + self.metrics.failed_posts += 1 + logger.warning( + f"Result post failed for job {job_id} after {elapsed_time:.2f}s" + ) return success except Exception as e: elapsed_time = time.time() - start_time - self.metrics.total_post_time += elapsed_time - self.metrics.failed_posts += 1 + with self._metrics_lock: + self.metrics.total_post_time += elapsed_time + self.metrics.failed_posts += 1 logger.error(f"Exception during result post for job {job_id}: {e}") return False @@ -186,30 +189,36 @@ def _cleanup_completed_futures(self) -> None: """Remove completed futures from the pending list.""" self.pending_futures = [f for f in self.pending_futures if not f.done()] - def wait_for_all_posts(self, timeout: Optional[float] = None) -> None: + def wait_for_all_posts( + self, + min_timeout: float = 60, + per_post_timeout: float = 30, + ) -> None: """Wait for all pending posts to complete before shutting down. Args: - timeout: Maximum time to wait for all posts to complete (seconds) + timeout: Maximum time to wait for all posts to complete (seconds). + If None, will be computed as max(min_timeout, pending_count * per_post_timeout) + min_timeout: Minimum timeout regardless of pending count (default: 60) + per_post_timeout: Additional timeout per pending post (default: 30) """ if not self.pending_futures: return + pending_count = len(self.pending_futures) + timeout = max(min_timeout, pending_count * per_post_timeout) logger.info( - f"Waiting for {len(self.pending_futures)} pending result posts to complete..." + f"Waiting for {pending_count} pending result posts with dynamic timeout {timeout}s" ) start_time = time.time() for future in self.pending_futures: remaining_timeout = None - if timeout is not None: - elapsed = time.time() - start_time - remaining_timeout = max(0, timeout - elapsed) - if remaining_timeout == 0: - logger.warning( - "Timeout waiting for pending posts, some may be lost" - ) - break + elapsed = time.time() - start_time + remaining_timeout = max(0, timeout - elapsed) + if remaining_timeout == 0: + logger.warning("Timeout waiting for pending posts, some may be lost") + break try: future.result(timeout=remaining_timeout) @@ -218,6 +227,20 @@ def wait_for_all_posts(self, timeout: Optional[float] = None) -> None: self._cleanup_completed_futures() + # Check if any posts were abandoned and emit error + posts_after_wait = len(self.pending_futures) + if posts_after_wait > 0: + metrics = self.metrics + logger.error( + f"Failed to complete all result posts before timeout. " + f"{posts_after_wait} posts were abandoned. " + f"Post metrics - Total: {metrics.total_posts}, " + f"Successful: {metrics.successful_posts}, " + f"Failed: {metrics.failed_posts}, " + f"Success rate: {metrics.success_rate:.1f}%, " + f"Max queue size: {metrics.max_queue_size}" + ) + def get_metrics(self) -> ResultPostMetrics: """Get current metrics. @@ -235,5 +258,5 @@ def shutdown(self, wait: bool = True, timeout: Optional[float] = 30) -> None: timeout: Maximum time to wait for pending posts (seconds) """ if wait: - self.wait_for_all_posts(timeout=timeout) + self.wait_for_all_posts(min_timeout=timeout) self.executor.shutdown(wait=wait) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index baf1c2a..d79e2eb 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -346,7 +346,8 @@ def _process_job( # Wait for all async posts to complete before finishing the job logger.info("Waiting for all pending result posts to complete...") - result_poster.wait_for_all_posts(timeout=60) # 60 second timeout for cleanup + + result_poster.wait_for_all_posts(min_timeout=60, per_post_timeout=30) # Get final metrics post_metrics = result_poster.get_metrics() From eb490aeb742b96def381e92a9f0646a015105b60 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 20 Feb 2026 12:20:13 -0800 Subject: [PATCH 3/7] Improve backpressure handling --- trapdata/antenna/benchmark.py | 9 +++++--- trapdata/antenna/result_posting.py | 37 +++++++++++++++++++----------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py index a8df98e..898b9ce 100644 --- a/trapdata/antenna/benchmark.py +++ b/trapdata/antenna/benchmark.py @@ -18,6 +18,7 @@ import argparse import os +import sys import time from trapdata.antenna.datasets import get_rest_dataloader @@ -251,7 +252,7 @@ def run_benchmark( print("=" * 70) -def main(): +def main() -> int: """Main entry point for the benchmark CLI.""" # Parse command line arguments parser = argparse.ArgumentParser( @@ -285,7 +286,8 @@ def main(): # Get auth token from environment auth_token = os.getenv("AMI_ANTENNA_API_AUTH_TOKEN", "") if not auth_token: - print("Warning: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set") + print("ERROR: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set") + return 1 # Run the benchmark run_benchmark( @@ -297,7 +299,8 @@ def main(): gpu_batch_size=args.gpu_batch_size, service_name=args.service_name, ) + return 0 if __name__ == "__main__": - main() + sys.exit(main()) diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index c215053..1c7247e 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -21,7 +21,7 @@ import threading import time -from concurrent.futures import Future, ThreadPoolExecutor +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait from dataclasses import dataclass from typing import Optional @@ -66,8 +66,13 @@ class ResultPoster: poster.shutdown() """ - def __init__(self, max_pending: int = 5): + def __init__( + self, + max_pending: int = 5, + future_timeout: float = 30.0, + ): self.max_pending = max_pending + self.future_timeout = future_timeout # Timeout for individual future waits self.executor = ThreadPoolExecutor( max_workers=2, thread_name_prefix="result_poster" ) @@ -100,19 +105,23 @@ def post_async( # Apply backpressure: wait for pending posts to complete if we're at the limit while len(self.pending_futures) >= self.max_pending: - logger.debug( - f"At max pending posts ({self.max_pending}), waiting for completion..." - ) # Wait for at least one future to complete - if self.pending_futures: - # Wait for the oldest pending post to complete - completed_future = self.pending_futures[0] - try: - completed_future.result(timeout=30) # 30 second timeout - except Exception as e: - logger.warning(f"Pending result post failed: {e}") - finally: - self._cleanup_completed_futures() + done_futures, _ = wait( + self.pending_futures, + return_when=FIRST_COMPLETED, + timeout=self.future_timeout, + ) + for completed_future in done_futures: + self.pending_futures.remove(completed_future) + if not done_futures: + # Force cleanup by cancelling all pending futures and clearing the list + for future in self.pending_futures: + future.cancel() + self.pending_futures.clear() + logger.warning( + "Timeout waiting for pending result posts. " + "Cleared pending futures to prevent blocking indefinitely." + ) # Update queue size metric current_queue_size = len(self.pending_futures) From 0b3061b23e79c5a09d40a8adca3efc84daa284ea Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Fri, 20 Feb 2026 13:15:59 -0800 Subject: [PATCH 4/7] cleanup --- trapdata/antenna/result_posting.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index 1c7247e..b689352 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -206,8 +206,6 @@ def wait_for_all_posts( """Wait for all pending posts to complete before shutting down. Args: - timeout: Maximum time to wait for all posts to complete (seconds). - If None, will be computed as max(min_timeout, pending_count * per_post_timeout) min_timeout: Minimum timeout regardless of pending count (default: 60) per_post_timeout: Additional timeout per pending post (default: 30) """ From 0583b14edf78a12ab5b18acd2ee039c40c78cd0d Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 15:12:41 -0800 Subject: [PATCH 5/7] Simplify shutdown --- trapdata/antenna/result_posting.py | 16 +++++++--------- trapdata/antenna/worker.py | 4 +--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index b689352..4d048ef 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -257,13 +257,11 @@ def get_metrics(self) -> ResultPostMetrics: self._cleanup_completed_futures() return self.metrics - def shutdown(self, wait: bool = True, timeout: Optional[float] = 30) -> None: - """Shutdown the executor and optionally wait for pending posts. - - Args: - wait: Whether to wait for pending posts to complete - timeout: Maximum time to wait for pending posts (seconds) - """ - if wait: - self.wait_for_all_posts(min_timeout=timeout) + def shutdown(self) -> None: + """Shutdown the executor""" + if self.pending_futures: + raise RuntimeError( + f"Cannot shutdown ResultPoster with {len(self.pending_futures)} pending posts. " + "Call wait_for_all_posts() before shutdown to ensure all posts are completed." + ) self.executor.shutdown(wait=wait) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index a50ae61..f018e82 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -454,9 +454,7 @@ def _process_job( # Get final metrics post_metrics = result_poster.get_metrics() - - # Clean up the result poster - result_poster.shutdown(wait=False) # Already waited above + result_poster.shutdown() logger.info( f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time:.2f}s, " From 51e5f5c6e2f219a6e93c905e3b20db7f3eaa4353 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 17:04:13 -0800 Subject: [PATCH 6/7] Fix missed issue --- trapdata/antenna/result_posting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index 4d048ef..cd3ce3b 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -264,4 +264,4 @@ def shutdown(self) -> None: f"Cannot shutdown ResultPoster with {len(self.pending_futures)} pending posts. " "Call wait_for_all_posts() before shutdown to ensure all posts are completed." ) - self.executor.shutdown(wait=wait) + self.executor.shutdown(wait=False) From f13c77c21e27ed011fbde75ea2515eaefa080b20 Mon Sep 17 00:00:00 2001 From: Carlos Garcia Jurado Suarez Date: Wed, 25 Feb 2026 17:45:35 -0800 Subject: [PATCH 7/7] CR feedback --- trapdata/antenna/result_posting.py | 7 +- trapdata/antenna/worker.py | 428 +++++++++++++++-------------- 2 files changed, 227 insertions(+), 208 deletions(-) diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py index cd3ce3b..16207ff 100644 --- a/trapdata/antenna/result_posting.py +++ b/trapdata/antenna/result_posting.py @@ -114,7 +114,8 @@ def post_async( for completed_future in done_futures: self.pending_futures.remove(completed_future) if not done_futures: - # Force cleanup by cancelling all pending futures and clearing the list + # Force cleanup by cancelling all pending futures and clearing the list. + # This means that potentially losing some posts, but the server tasks pipeline has built-in retries. for future in self.pending_futures: future.cancel() self.pending_futures.clear() @@ -124,7 +125,9 @@ def post_async( ) # Update queue size metric - current_queue_size = len(self.pending_futures) + current_queue_size = ( + len(self.pending_futures) + 1 + ) # +1 for the post we're about to submit self.metrics.max_queue_size = max( self.metrics.max_queue_size, current_queue_size ) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index f018e82..68b1c81 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -231,12 +231,6 @@ def _process_job( classifier_class = CLASSIFIER_CHOICES[pipeline] use_binary_filter = should_filter_detections(classifier_class) binary_filter = None - if use_binary_filter: - binary_filter = MothClassifierBinary( - source_images=[], - detections=[], - terminal=False, - ) if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -247,222 +241,244 @@ def _process_job( total_dl_time = 0.0 all_detections = [] _, t = log_time() + result_poster: ResultPoster | None = None + try: + for i, batch in enumerate(loader): + cls_time = 0.0 + det_time = 0.0 + load_time, t = t() + total_dl_time += load_time + if not batch: + logger.warning(f"Batch {i + 1} is empty, skipping") + continue + + # Defer instantiation of poster, detector and classifiers until we have data + if not classifier: + classifier = classifier_class(source_images=[], detections=[]) + detector = APIMothDetector([]) + result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS) + + if use_binary_filter: + binary_filter = MothClassifierBinary( + source_images=[], + detections=[], + terminal=False, + ) - result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS) - - for i, batch in enumerate(loader): - cls_time = 0.0 - det_time = 0.0 - load_time, t = t() - total_dl_time += load_time - if not batch: - logger.warning(f"Batch {i + 1} is empty, skipping") - continue - - # Defer instantiation of detector and classifier until we have data - if not classifier: - classifier = classifier_class(source_images=[], detections=[]) - detector = APIMothDetector([]) - assert detector is not None, "Detector not initialized" - assert classifier is not None, "Classifier not initialized" - detector.reset([]) - did_work = True - - # Extract data from dictionary batch - images = batch.get("images", []) - image_ids = batch.get("image_ids", []) - reply_subjects = batch.get("reply_subjects", [None] * len(images)) - image_urls = batch.get("image_urls", [None] * len(images)) - - batch_results: list[AntennaTaskResult] = [] - try: - # Validate all arrays have same length before zipping - if len(image_ids) != len(images): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" - ) - if len(image_ids) != len(reply_subjects) or len(image_ids) != len( - image_urls - ): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}), " - f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" - ) + assert detector is not None, "Detector not initialized" + assert classifier is not None, "Classifier not initialized" + assert result_poster is not None, "ResultPoster not initialized" + assert not ( + use_binary_filter and binary_filter is None + ), "Binary filter not initialized" + detector.reset([]) + did_work = True + + # Extract data from dictionary batch + images = batch.get("images", []) + image_ids = batch.get("image_ids", []) + reply_subjects = batch.get("reply_subjects", [None] * len(images)) + image_urls = batch.get("image_urls", [None] * len(images)) + + batch_results: list[AntennaTaskResult] = [] + try: + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" + ) + if len(image_ids) != len(reply_subjects) or len(image_ids) != len( + image_urls + ): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" + ) - # Track start time for this batch - batch_start_time = datetime.datetime.now() + # Track start time for this batch + batch_start_time = datetime.datetime.now() - # output is dict of "boxes", "labels", "scores" - batch_output = [] - if len(images) > 0: - batch_output = detector.predict_batch(images) + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) - items += len(batch_output) - batch_output = list(detector.post_process_batch(batch_output)) + items += len(batch_output) + batch_output = list(detector.post_process_batch(batch_output)) - # Convert image_ids to list if needed - if isinstance(image_ids, (np.ndarray, torch.Tensor)): - image_ids = image_ids.tolist() + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() - # TODO CGJS: Add seconds per item calculation for both detector and classifier - detector.save_results( - item_ids=image_ids, - batch_output=batch_output, - seconds_per_item=0, - ) - det_time, t = t() - total_detection_time += det_time - - # Group detections by image_id - image_detections: dict[str, list[DetectionResponse]] = { - img_id: [] for img_id in image_ids - } - image_tensors = dict(zip(image_ids, images, strict=True)) - - # Apply binary classification filter if needed - detector_results = detector.results - - if use_binary_filter: - assert binary_filter is not None, "Binary filter not initialized" - detections_for_terminal_classifier, detections_to_return = ( - _apply_binary_classification( - binary_filter, detector_results, image_tensors, image_detections - ) + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, ) - else: - # No binary filtering, send all detections to terminal classifier - detections_for_terminal_classifier = detector_results - detections_to_return = [] - - # Run terminal classifier on filtered detections - classifier.reset(detections_for_terminal_classifier) - to_pil = torchvision.transforms.ToPILImage() - classify_transforms = classifier.get_transforms() - - # Collect and transform all crops for batched classification - crops = [] - valid_indices = [] - for idx, dresp in enumerate(detections_for_terminal_classifier): - image_tensor = image_tensors[dresp.source_image_id] - bbox = dresp.bbox - y1, y2 = int(bbox.y1), int(bbox.y2) - x1, x2 = int(bbox.x1), int(bbox.x2) - if y1 >= y2 or x1 >= x2: - logger.warning( - f"Skipping detection {idx} with invalid bbox: " - f"({x1},{y1})->({x2},{y2})" + det_time, t = t() + total_detection_time += det_time + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + # Apply binary classification filter if needed + detector_results = detector.results + + if use_binary_filter: + assert binary_filter is not None, "Binary filter not initialized" + detections_for_terminal_classifier, detections_to_return = ( + _apply_binary_classification( + binary_filter, + detector_results, + image_tensors, + image_detections, + ) ) - continue - crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = classify_transforms(crop_pil) - crops.append(crop_transformed) - valid_indices.append(idx) - - if crops: - batched_crops = torch.stack(crops) - classifier_out = classifier.predict_batch(batched_crops) - classifier_out = classifier.post_process_batch(classifier_out) - - for crop_i, idx in enumerate(valid_indices): - dresp = detections_for_terminal_classifier[idx] - detection = classifier.update_detection_classification( - seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[crop_i], + else: + # No binary filtering, send all detections to terminal classifier + detections_for_terminal_classifier = detector_results + detections_to_return = [] + + # Run terminal classifier on filtered detections + classifier.reset(detections_for_terminal_classifier) + to_pil = torchvision.transforms.ToPILImage() + classify_transforms = classifier.get_transforms() + + # Collect and transform all crops for batched classification + crops = [] + valid_indices = [] + for idx, dresp in enumerate(detections_for_terminal_classifier): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + y1, y2 = int(bbox.y1), int(bbox.y2) + x1, x2 = int(bbox.x1), int(bbox.x2) + if y1 >= y2 or x1 >= x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({x1},{y1})->({x2},{y2})" + ) + continue + crop = image_tensor[:, y1:y2, x1:x2] + crop_pil = to_pil(crop) + crop_transformed = classify_transforms(crop_pil) + crops.append(crop_transformed) + valid_indices.append(idx) + + if crops: + batched_crops = torch.stack(crops) + classifier_out = classifier.predict_batch(batched_crops) + classifier_out = classifier.post_process_batch(classifier_out) + + for crop_i, idx in enumerate(valid_indices): + dresp = detections_for_terminal_classifier[idx] + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[crop_i], + ) + image_detections[dresp.source_image_id].append(detection) + all_detections.append(detection) + + cls_time, t = t() + total_classification_time += cls_time + # Add non-moth detections to all_detections + all_detections.extend(detections_to_return) + + # Calculate batch processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results.clear() + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed + / len(image_ids), # Approximate time per image ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) - - cls_time, t = t() - total_classification_time += cls_time - # Add non-moth detections to all_detections - all_detections.extend(detections_to_return) - - # Calculate batch processing time - batch_end_time = datetime.datetime.now() - batch_elapsed = (batch_end_time - batch_start_time).total_seconds() - - # Post results back to the API with PipelineResponse for each image - batch_results.clear() - for reply_subject, image_id, image_url in zip( - reply_subjects, image_ids, image_urls, strict=True - ): - # Create SourceImageResponse for this image - source_image = SourceImageResponse(id=image_id, url=image_url) - - # Create PipelineResultsResponse - pipeline_response = PipelineResultsResponse( - pipeline=pipeline, - source_images=[source_image], - detections=image_detections[image_id], - total_time=batch_elapsed - / len(image_ids), # Approximate time per image - ) - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=pipeline_response, + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) ) + except Exception as e: + logger.error( + f"Batch {i + 1} failed during processing: {e}", exc_info=True ) - except Exception as e: - logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) - # Report errors back to Antenna so tasks aren't stuck in the queue - batch_results = [] - for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=AntennaTaskResultError( - error=f"Batch processing error: {e}", - image_id=image_id, - ), + # Report errors back to Antenna so tasks aren't stuck in the queue + batch_results = [] + for reply_subject, image_id in zip( + reply_subjects, image_ids, strict=True + ): + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=AntennaTaskResultError( + error=f"Batch processing error: {e}", + image_id=image_id, + ), + ) ) - ) - failed_items = batch.get("failed_items") - if failed_items: - for failed_item in failed_items: - batch_results.append( - AntennaTaskResult( - reply_subject=failed_item.get("reply_subject"), - result=AntennaTaskResultError( - error=failed_item.get("error", "Unknown error"), - image_id=failed_item.get("image_id"), - ), + failed_items = batch.get("failed_items") + if failed_items: + for failed_item in failed_items: + batch_results.append( + AntennaTaskResult( + reply_subject=failed_item.get("reply_subject"), + result=AntennaTaskResultError( + error=failed_item.get("error", "Unknown error"), + image_id=failed_item.get("image_id"), + ), + ) ) - ) - - # Post results asynchronously (non-blocking) - result_poster.post_async( - settings.antenna_api_base_url, - settings.antenna_api_auth_token, - job_id, - batch_results, - processing_service_name, - ) - logger.info( - f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" - ) - # Wait for all async posts to complete before finishing the job - logger.info("Waiting for all pending result posts to complete...") + # Post results asynchronously (non-blocking) + result_poster.post_async( + settings.antenna_api_base_url, + settings.antenna_api_auth_token, + job_id, + batch_results, + processing_service_name, + ) + _, t = log_time() # reset time to measure batch load time + logger.info( + f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" + ) - result_poster.wait_for_all_posts(min_timeout=60, per_post_timeout=30) + if result_poster: + # Wait for all async posts to complete before finishing the job + logger.info("Waiting for all pending result posts to complete...") + result_poster.wait_for_all_posts(min_timeout=60, per_post_timeout=30) - # Get final metrics - post_metrics = result_poster.get_metrics() - result_poster.shutdown() + # Get final metrics + post_metrics = result_poster.get_metrics() - logger.info( - f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time:.2f}s, " - f"classification time: {total_classification_time:.2f}s, dl time: {total_dl_time:.2f}s, " - f"result posts: {post_metrics.total_posts} " - f"(success: {post_metrics.successful_posts}, failed: {post_metrics.failed_posts}, " - f"success rate: {post_metrics.success_rate:.1f}%, avg post time: " - f"{post_metrics.total_post_time / post_metrics.total_posts if post_metrics.total_posts > 0 else 0:.2f}s, " - f"max queue size: {post_metrics.max_queue_size})" - ) - return did_work + logger.info( + f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time:.2f}s, " + f"classification time: {total_classification_time:.2f}s, dl time: {total_dl_time:.2f}s, " + f"result posts: {post_metrics.total_posts} " + f"(success: {post_metrics.successful_posts}, failed: {post_metrics.failed_posts}, " + f"success rate: {post_metrics.success_rate:.1f}%, avg post time: " + f"{post_metrics.total_post_time / post_metrics.total_posts if post_metrics.total_posts > 0 else 0:.2f}s, " + f"max queue size: {post_metrics.max_queue_size})" + ) + return did_work + finally: + if result_poster: + result_poster.shutdown()