diff --git a/trapdata/antenna/benchmark.py b/trapdata/antenna/benchmark.py new file mode 100644 index 0000000..898b9ce --- /dev/null +++ b/trapdata/antenna/benchmark.py @@ -0,0 +1,306 @@ +"""Benchmarking utilities for Antenna API data loading and result posting. + +This module provides a command-line benchmark tool for testing the performance +of the Antenna API data loading pipeline with asynchronous result posting. +The benchmark fetches batches from the API, processes acknowledgments, and +provides detailed performance metrics. + +Usage: + python -m trapdata.antenna.benchmark --job-id 123 --base-url http://localhost:8000/api/v2 + +Key metrics tracked: +- Images per second (total and successful) +- Batch processing rate +- Acknowledgment posting rate +- Result posting success/failure rates +- Queue utilization metrics +""" + +import argparse +import os +import sys +import time + +from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.result_posting import ResultPoster +from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError +from trapdata.common.logs import logger +from trapdata.common.utils import log_time +from trapdata.settings import Settings + + +def create_empty_result(reply_subject: str, image_id: str) -> AntennaTaskResult: + """Create an empty/acknowledgment result for a task. + + Args: + reply_subject: Subject for the reply + image_id: ID of the image being acknowledged + + Returns: + AntennaTaskResult with error acknowledgment + """ + result = AntennaTaskResultError( + error=f"Acknowledgment for image {image_id}", + image_id=image_id, + ) + return AntennaTaskResult(reply_subject=reply_subject, result=result) + + +def run_benchmark( + job_id: int, + base_url: str, + auth_token: str, + num_workers: int, + batch_size: int, + gpu_batch_size: int, + service_name: str, +) -> None: + """Run the benchmark with the specified parameters. + + Args: + job_id: Job ID to process + base_url: Antenna API base URL + auth_token: API authentication token + num_workers: Number of DataLoader workers + batch_size: Batch size for API requests + gpu_batch_size: GPU batch size for DataLoader + service_name: Processing service name + """ + # Create settings object + settings = Settings() + settings.antenna_api_base_url = base_url + settings.antenna_api_auth_token = auth_token + settings.antenna_api_batch_size = batch_size + settings.localization_batch_size = gpu_batch_size + settings.num_workers = num_workers + + print(f"Starting performance test for job {job_id}") + print("Configuration:") + print(f" Base URL: {base_url}") + print(f" API batch size: {batch_size}") + print(f" GPU batch size: {gpu_batch_size}") + print(f" Num workers: {num_workers}") + print(f" Service name: {service_name}") + print() + + # Create dataloader + dataloader = get_rest_dataloader( + job_id=job_id, + settings=settings, + processing_service_name=service_name, + ) + + # Initialize ResultPoster for sending acknowledgments + result_poster = ResultPoster(max_pending=10) + + # Performance metrics + total_batches = 0 + total_images = 0 + total_successful_images = 0 + total_failed_images = 0 + total_acks_sent = 0 + start_time = time.time() + last_report_time = start_time + report_interval = 10 # Report every 10 seconds + + print("Starting data consumption with acknowledgments...") + try: + _, t = log_time() + for batch_idx, batch in enumerate(dataloader): + _, t = t( + f"Fetched batch {batch_idx} with {len(batch['reply_subjects'])} items" + ) + current_time = time.time() + total_batches += 1 + + # Count images in this batch + batch_failed = len(batch["failed_items"]) + # Successful items are those with reply_subjects that are not in failed_items + batch_successful = len(batch["reply_subjects"]) + + total_images += batch_size + total_successful_images += batch_successful + total_failed_images += batch_failed + + # Send acknowledgments for successful items + if batch_successful > 0: + ack_results = [] + for i, (reply_subject, image_id) in enumerate( + zip(batch["reply_subjects"], batch["image_ids"]) + ): + if i < batch_successful: # Only for successful items + ack_result = create_empty_result(reply_subject, image_id) + ack_results.append(ack_result) + + logger.info(f"Sending {len(ack_results)} acknowledgment(s)") + if ack_results: + # Send acknowledgments asynchronously + result_poster.post_async( + base_url=base_url, + auth_token=auth_token, + job_id=job_id, + results=ack_results, + processing_service_name=service_name, + ) + total_acks_sent += len(ack_results) + + # Send error results for failed items + if batch_failed > 0: + error_results = [] + for failed_item in batch["failed_items"]: + error_result = AntennaTaskResult( + reply_subject=failed_item["reply_subject"], + result=AntennaTaskResultError( + error=failed_item.get("error", "Image loading failed"), + image_id=failed_item["image_id"], + ), + ) + error_results.append(error_result) + + if error_results: + result_poster.post_async( + base_url=base_url, + auth_token=auth_token, + job_id=job_id, + results=error_results, + processing_service_name=service_name, + ) + total_acks_sent += len(error_results) + + # Report progress periodically + if current_time - last_report_time >= report_interval: + elapsed = current_time - start_time + images_per_sec = total_images / elapsed if elapsed > 0 else 0 + successful_per_sec = ( + total_successful_images / elapsed if elapsed > 0 else 0 + ) + acks_per_sec = total_acks_sent / elapsed if elapsed > 0 else 0 + + # Get ResultPoster metrics + post_metrics = result_poster.get_metrics() + + print( + f"Progress: {total_batches} batches, {total_images} images " + f"({total_successful_images} success, {total_failed_images} failed) " + f"- {images_per_sec:.1f} img/s, {successful_per_sec:.1f} success/s, " + f"{acks_per_sec:.1f} acks/s" + ) + print( + f" Posts: {post_metrics.successful_posts} success, " + f"{post_metrics.failed_posts} failed, " + f"{post_metrics.success_rate:.1f}% success rate" + ) + last_report_time = current_time + _, t = log_time() + + except KeyboardInterrupt: + print("\nStopped by user") + except Exception as e: + print(f"\nError occurred: {e}") + logger.error(f"DataLoader benchmark error: {e}") + finally: + # Wait for all pending result posts to complete + print("Waiting for pending result posts to complete...") + result_poster.wait_for_all_posts() + result_poster.shutdown() + + # Final statistics + end_time = time.time() + total_elapsed = end_time - start_time + final_post_metrics = result_poster.get_metrics() + + print("\n" + "=" * 70) + print("PERFORMANCE SUMMARY") + print("=" * 70) + print(f"Total time: {total_elapsed:.2f} seconds") + print(f"Total batches: {total_batches}") + print(f"Total images: {total_images}") + print(f"Successful images: {total_successful_images}") + print(f"Failed images: {total_failed_images}") + print(f"Acknowledgments sent: {total_acks_sent}") + + if total_elapsed > 0: + images_per_sec = total_images / total_elapsed + successful_per_sec = total_successful_images / total_elapsed + batches_per_sec = total_batches / total_elapsed + acks_per_sec = total_acks_sent / total_elapsed + + print("\nThroughput:") + print(f" {images_per_sec:.2f} images/second (total)") + print(f" {successful_per_sec:.2f} images/second (successful)") + print(f" {batches_per_sec:.2f} batches/second") + print(f" {acks_per_sec:.2f} acknowledgments/second") + + if total_images > 0: + success_rate = (total_successful_images / total_images) * 100 + print(f"\nSuccess rate: {success_rate:.1f}%") + + print("\nResult Posting Metrics:") + print(f" Total posts: {final_post_metrics.total_posts}") + print(f" Successful posts: {final_post_metrics.successful_posts}") + print(f" Failed posts: {final_post_metrics.failed_posts}") + print(f" Post success rate: {final_post_metrics.success_rate:.1f}%") + if final_post_metrics.total_posts > 0: + avg_post_time = ( + final_post_metrics.total_post_time / final_post_metrics.total_posts + ) + print(f" Average post time: {avg_post_time:.3f} seconds") + print(f" Max queue size: {final_post_metrics.max_queue_size}") + + print("=" * 70) + print("Performance benchmark completed") + print("=" * 70) + + +def main() -> int: + """Main entry point for the benchmark CLI.""" + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Benchmark dataloader performance with acknowledgements" + ) + parser.add_argument("--job-id", type=int, required=True, help="Job ID to process") + parser.add_argument( + "--base-url", + type=str, + default="http://localhost:8000/api/v2", + help="Antenna API base URL", + ) + parser.add_argument( + "--num-workers", type=int, default=2, help="Number of DataLoader workers" + ) + parser.add_argument( + "--batch-size", type=int, default=16, help="Batch size for API requests" + ) + parser.add_argument( + "--gpu-batch-size", type=int, default=16, help="GPU batch size for DataLoader" + ) + parser.add_argument( + "--service-name", + type=str, + default="Performance Test", + help="Processing service name", + ) + + args = parser.parse_args() + + # Get auth token from environment + auth_token = os.getenv("AMI_ANTENNA_API_AUTH_TOKEN", "") + if not auth_token: + print("ERROR: AMI_ANTENNA_API_AUTH_TOKEN environment variable not set") + return 1 + + # Run the benchmark + run_benchmark( + job_id=args.job_id, + base_url=args.base_url, + auth_token=auth_token, + num_workers=args.num_workers, + batch_size=args.batch_size, + gpu_batch_size=args.gpu_batch_size, + service_name=args.service_name, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py index d86c504..0bdbfc9 100644 --- a/trapdata/antenna/client.py +++ b/trapdata/antenna/client.py @@ -100,7 +100,7 @@ def post_batch_results( params = {"processing_service_name": processing_service_name} response = session.post(url, json=payload, params=params, timeout=60) response.raise_for_status() - logger.info(f"Successfully posted {len(results)} results to {url}") + logger.debug(f"Successfully posted {len(results)} results to {url}") return True except requests.RequestException as e: logger.error(f"Failed to post results to {url}: {e}") diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index bb041f5..2d21dec 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -37,7 +37,7 @@ How many images the GPU processes at once (detection). Larger = more GPU memory. These are full-resolution images (~4K). - num_workers (default 2) + num_workers (default 4) DataLoader subprocesses. Each independently fetches tasks and downloads images. More workers = more images prefetched for the GPU, at the cost of CPU/RAM. With 0 workers, fetching and @@ -286,7 +286,7 @@ def __iter__(self): if not tasks: # Queue is empty - job complete - logger.info( + logger.debug( f"Worker {worker_id}: No more tasks for job {self.job_id}" ) break @@ -317,7 +317,7 @@ def __iter__(self): row["error"] = "; ".join(errors) if errors else None yield row - logger.info(f"Worker {worker_id}: Iterator finished") + logger.debug(f"Worker {worker_id}: Iterator finished") except Exception as e: logger.error(f"Worker {worker_id}: Exception in iterator: {e}") raise diff --git a/trapdata/antenna/result_posting.py b/trapdata/antenna/result_posting.py new file mode 100644 index 0000000..16207ff --- /dev/null +++ b/trapdata/antenna/result_posting.py @@ -0,0 +1,270 @@ +"""Asynchronous result posting utilities for Antenna API. + +This module provides utilities for posting batch results to the Antenna API with +backpressure control and comprehensive metrics tracking. The main class, ResultPoster, +manages asynchronous posting to improve worker throughput by overlapping network I/O +with compute operations. + +Key features: +- Asynchronous posting using ThreadPoolExecutor +- Configurable backpressure control to prevent unbounded memory usage +- Comprehensive metrics tracking (success/failure rates, timing, queue size) +- Graceful shutdown with timeout handling +- Thread-safe operations + +Usage: + poster = ResultPoster(max_pending=5) + poster.post_async(base_url, auth_token, job_id, results, service_name) + metrics = poster.get_metrics() + poster.shutdown() +""" + +import threading +import time +from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait +from dataclasses import dataclass +from typing import Optional + +from trapdata.antenna.client import post_batch_results +from trapdata.common.logs import logger + + +@dataclass +class ResultPostMetrics: + """Metrics for tracking result posting performance.""" + + total_posts: int = 0 + successful_posts: int = 0 + failed_posts: int = 0 + total_post_time: float = 0.0 + max_queue_size: int = 0 + + @property + def success_rate(self) -> float: + """Calculate success rate as a percentage.""" + return ( + (self.successful_posts / self.total_posts * 100) + if self.total_posts > 0 + else 0.0 + ) + + +class ResultPoster: + """Manages asynchronous posting of batch results with backpressure control. + + This class provides asynchronous result posting to improve throughput by allowing + the worker to continue processing while previous results are posted in background + threads. It includes backpressure control to prevent unbounded memory usage. + + Args: + max_pending: Maximum number of concurrent posts before blocking (default: 5) + + Example: + poster = ResultPoster(max_pending=10) + poster.post_async(base_url, auth_token, job_id, results, service_name) + metrics = poster.get_metrics() + poster.shutdown() + """ + + def __init__( + self, + max_pending: int = 5, + future_timeout: float = 30.0, + ): + self.max_pending = max_pending + self.future_timeout = future_timeout # Timeout for individual future waits + self.executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="result_poster" + ) + self.pending_futures: list[Future] = [] + self.metrics = ResultPostMetrics() + self._metrics_lock = threading.Lock() + + def post_async( + self, + base_url: str, + auth_token: str, + job_id: int, + results: list, + processing_service_name: str, + ) -> None: + """Post results asynchronously with backpressure control. + + This method will block if there are too many pending posts to prevent + unbounded memory usage and provide backpressure. + + Args: + base_url: Antenna API base URL + auth_token: API authentication token + job_id: Job ID for the results + results: List of result objects to post + processing_service_name: Name of the processing service + """ + # Clean up completed futures and update metrics + self._cleanup_completed_futures() + + # Apply backpressure: wait for pending posts to complete if we're at the limit + while len(self.pending_futures) >= self.max_pending: + # Wait for at least one future to complete + done_futures, _ = wait( + self.pending_futures, + return_when=FIRST_COMPLETED, + timeout=self.future_timeout, + ) + for completed_future in done_futures: + self.pending_futures.remove(completed_future) + if not done_futures: + # Force cleanup by cancelling all pending futures and clearing the list. + # This means that potentially losing some posts, but the server tasks pipeline has built-in retries. + for future in self.pending_futures: + future.cancel() + self.pending_futures.clear() + logger.warning( + "Timeout waiting for pending result posts. " + "Cleared pending futures to prevent blocking indefinitely." + ) + + # Update queue size metric + current_queue_size = ( + len(self.pending_futures) + 1 + ) # +1 for the post we're about to submit + self.metrics.max_queue_size = max( + self.metrics.max_queue_size, current_queue_size + ) + + # Submit new post + start_time = time.time() + future = self.executor.submit( + self._post_with_timing, + base_url, + auth_token, + job_id, + results, + processing_service_name, + start_time, + ) + self.pending_futures.append(future) + self.metrics.total_posts += 1 + + logger.debug( + f"Submitted result post for job {job_id}, {current_queue_size + 1} pending" + ) + + def _post_with_timing( + self, + base_url: str, + auth_token: str, + job_id: int, + results: list, + processing_service_name: str, + start_time: float, + ) -> bool: + """Internal method that times the post operation and updates metrics. + + Args: + base_url: Antenna API base URL + auth_token: API authentication token + job_id: Job ID for the results + results: List of result objects to post + processing_service_name: Name of the processing service + start_time: Timestamp when the post was initiated + + Returns: + True if successful, False otherwise + """ + try: + success = post_batch_results( + base_url, auth_token, job_id, results, processing_service_name + ) + elapsed_time = time.time() - start_time + + with self._metrics_lock: + self.metrics.total_post_time += elapsed_time + if success: + self.metrics.successful_posts += 1 + else: + self.metrics.failed_posts += 1 + logger.warning( + f"Result post failed for job {job_id} after {elapsed_time:.2f}s" + ) + + return success + except Exception as e: + elapsed_time = time.time() - start_time + with self._metrics_lock: + self.metrics.total_post_time += elapsed_time + self.metrics.failed_posts += 1 + logger.error(f"Exception during result post for job {job_id}: {e}") + return False + + def _cleanup_completed_futures(self) -> None: + """Remove completed futures from the pending list.""" + self.pending_futures = [f for f in self.pending_futures if not f.done()] + + def wait_for_all_posts( + self, + min_timeout: float = 60, + per_post_timeout: float = 30, + ) -> None: + """Wait for all pending posts to complete before shutting down. + + Args: + min_timeout: Minimum timeout regardless of pending count (default: 60) + per_post_timeout: Additional timeout per pending post (default: 30) + """ + if not self.pending_futures: + return + + pending_count = len(self.pending_futures) + timeout = max(min_timeout, pending_count * per_post_timeout) + logger.info( + f"Waiting for {pending_count} pending result posts with dynamic timeout {timeout}s" + ) + start_time = time.time() + + for future in self.pending_futures: + remaining_timeout = None + elapsed = time.time() - start_time + remaining_timeout = max(0, timeout - elapsed) + if remaining_timeout == 0: + logger.warning("Timeout waiting for pending posts, some may be lost") + break + + try: + future.result(timeout=remaining_timeout) + except Exception as e: + logger.warning(f"Pending result post failed during shutdown: {e}") + + self._cleanup_completed_futures() + + # Check if any posts were abandoned and emit error + posts_after_wait = len(self.pending_futures) + if posts_after_wait > 0: + metrics = self.metrics + logger.error( + f"Failed to complete all result posts before timeout. " + f"{posts_after_wait} posts were abandoned. " + f"Post metrics - Total: {metrics.total_posts}, " + f"Successful: {metrics.successful_posts}, " + f"Failed: {metrics.failed_posts}, " + f"Success rate: {metrics.success_rate:.1f}%, " + f"Max queue size: {metrics.max_queue_size}" + ) + + def get_metrics(self) -> ResultPostMetrics: + """Get current metrics. + + Returns: + Current ResultPostMetrics object with performance data + """ + self._cleanup_completed_futures() + return self.metrics + + def shutdown(self) -> None: + """Shutdown the executor""" + if self.pending_futures: + raise RuntimeError( + f"Cannot shutdown ResultPoster with {len(self.pending_futures)} pending posts. " + "Call wait_for_all_posts() before shutdown to ensure all posts are completed." + ) + self.executor.shutdown(wait=False) diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 0d61237..68b1c81 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -10,6 +10,7 @@ from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results from trapdata.antenna.datasets import get_rest_dataloader +from trapdata.antenna.result_posting import ResultPoster from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError from trapdata.api.api import CLASSIFIER_CHOICES, should_filter_detections from trapdata.api.models.classification import MothClassifierBinary @@ -23,6 +24,7 @@ from trapdata.common.utils import log_time from trapdata.settings import Settings, read_settings +MAX_PENDING_POSTS = 5 # Maximum number of concurrent result posts before blocking SLEEP_TIME_SECONDS = 5 @@ -229,12 +231,6 @@ def _process_job( classifier_class = CLASSIFIER_CHOICES[pipeline] use_binary_filter = should_filter_detections(classifier_class) binary_filter = None - if use_binary_filter: - binary_filter = MothClassifierBinary( - source_images=[], - detections=[], - terminal=False, - ) if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -242,217 +238,247 @@ def _process_job( total_detection_time = 0.0 total_classification_time = 0.0 - total_save_time = 0.0 total_dl_time = 0.0 all_detections = [] _, t = log_time() - - for i, batch in enumerate(loader): - dt, t = t("Finished loading batch") - total_dl_time += dt - if not batch: - logger.warning(f"Batch {i + 1} is empty, skipping") - continue - - # Defer instantiation of detector and classifier until we have data - if not classifier: - classifier = classifier_class(source_images=[], detections=[]) - detector = APIMothDetector([]) - assert detector is not None, "Detector not initialized" - assert classifier is not None, "Classifier not initialized" - detector.reset([]) - did_work = True - - # Extract data from dictionary batch - images = batch.get("images", []) - image_ids = batch.get("image_ids", []) - reply_subjects = batch.get("reply_subjects", [None] * len(images)) - image_urls = batch.get("image_urls", [None] * len(images)) - - batch_results: list[AntennaTaskResult] = [] - - try: - # Validate all arrays have same length before zipping - if len(image_ids) != len(images): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" - ) - if len(image_ids) != len(reply_subjects) or len(image_ids) != len( - image_urls - ): - raise ValueError( - f"Length mismatch: image_ids ({len(image_ids)}), " - f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" - ) - - # Track start time for this batch - batch_start_time = datetime.datetime.now() - - logger.info(f"Processing worker batch {i + 1} ({len(images)} images)") - # output is dict of "boxes", "labels", "scores" - batch_output = [] - if len(images) > 0: - batch_output = detector.predict_batch(images) - - items += len(batch_output) - logger.info(f"Total items processed so far: {items}") - batch_output = list(detector.post_process_batch(batch_output)) - - # Convert image_ids to list if needed - if isinstance(image_ids, (np.ndarray, torch.Tensor)): - image_ids = image_ids.tolist() - - # TODO CGJS: Add seconds per item calculation for both detector and classifier - detector.save_results( - item_ids=image_ids, - batch_output=batch_output, - seconds_per_item=0, - ) - dt, t = t("Finished detection") - total_detection_time += dt - - # Group detections by image_id - image_detections: dict[str, list[DetectionResponse]] = { - img_id: [] for img_id in image_ids - } - image_tensors = dict(zip(image_ids, images, strict=True)) - - # Apply binary classification filter if needed - detector_results = detector.results - - if use_binary_filter: - assert binary_filter is not None, "Binary filter not initialized" - detections_for_terminal_classifier, detections_to_return = ( - _apply_binary_classification( - binary_filter, detector_results, image_tensors, image_detections + result_poster: ResultPoster | None = None + try: + for i, batch in enumerate(loader): + cls_time = 0.0 + det_time = 0.0 + load_time, t = t() + total_dl_time += load_time + if not batch: + logger.warning(f"Batch {i + 1} is empty, skipping") + continue + + # Defer instantiation of poster, detector and classifiers until we have data + if not classifier: + classifier = classifier_class(source_images=[], detections=[]) + detector = APIMothDetector([]) + result_poster = ResultPoster(max_pending=MAX_PENDING_POSTS) + + if use_binary_filter: + binary_filter = MothClassifierBinary( + source_images=[], + detections=[], + terminal=False, ) - ) - else: - # No binary filtering, send all detections to terminal classifier - detections_for_terminal_classifier = detector_results - detections_to_return = [] - - # Run terminal classifier on filtered detections - classifier.reset(detections_for_terminal_classifier) - to_pil = torchvision.transforms.ToPILImage() - classify_transforms = classifier.get_transforms() - - # Collect and transform all crops for batched classification - crops = [] - valid_indices = [] - for idx, dresp in enumerate(detections_for_terminal_classifier): - image_tensor = image_tensors[dresp.source_image_id] - bbox = dresp.bbox - y1, y2 = int(bbox.y1), int(bbox.y2) - x1, x2 = int(bbox.x1), int(bbox.x2) - if y1 >= y2 or x1 >= x2: - logger.warning( - f"Skipping detection {idx} with invalid bbox: " - f"({x1},{y1})->({x2},{y2})" + + assert detector is not None, "Detector not initialized" + assert classifier is not None, "Classifier not initialized" + assert result_poster is not None, "ResultPoster not initialized" + assert not ( + use_binary_filter and binary_filter is None + ), "Binary filter not initialized" + detector.reset([]) + did_work = True + + # Extract data from dictionary batch + images = batch.get("images", []) + image_ids = batch.get("image_ids", []) + reply_subjects = batch.get("reply_subjects", [None] * len(images)) + image_urls = batch.get("image_urls", [None] * len(images)) + + batch_results: list[AntennaTaskResult] = [] + try: + # Validate all arrays have same length before zipping + if len(image_ids) != len(images): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}) != images ({len(images)})" ) - continue - crop = image_tensor[:, y1:y2, x1:x2] - crop_pil = to_pil(crop) - crop_transformed = classify_transforms(crop_pil) - crops.append(crop_transformed) - valid_indices.append(idx) - - if crops: - batched_crops = torch.stack(crops) - classifier_out = classifier.predict_batch(batched_crops) - classifier_out = classifier.post_process_batch(classifier_out) - - for crop_i, idx in enumerate(valid_indices): - dresp = detections_for_terminal_classifier[idx] - detection = classifier.update_detection_classification( - seconds_per_item=0, - image_id=dresp.source_image_id, - detection_idx=idx, - predictions=classifier_out[crop_i], + if len(image_ids) != len(reply_subjects) or len(image_ids) != len( + image_urls + ): + raise ValueError( + f"Length mismatch: image_ids ({len(image_ids)}), " + f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})" ) - image_detections[dresp.source_image_id].append(detection) - all_detections.append(detection) - # Add non-moth detections to all_detections - all_detections.extend(detections_to_return) + # Track start time for this batch + batch_start_time = datetime.datetime.now() - ct, t = t("Finished classification") - total_classification_time += ct + # output is dict of "boxes", "labels", "scores" + batch_output = [] + if len(images) > 0: + batch_output = detector.predict_batch(images) - # Calculate batch processing time - batch_end_time = datetime.datetime.now() - batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + items += len(batch_output) + batch_output = list(detector.post_process_batch(batch_output)) - # Post results back to the API with PipelineResponse for each image - batch_results.clear() - for reply_subject, image_id, image_url in zip( - reply_subjects, image_ids, image_urls, strict=True - ): - # Create SourceImageResponse for this image - source_image = SourceImageResponse(id=image_id, url=image_url) + # Convert image_ids to list if needed + if isinstance(image_ids, (np.ndarray, torch.Tensor)): + image_ids = image_ids.tolist() - # Create PipelineResultsResponse - pipeline_response = PipelineResultsResponse( - pipeline=pipeline, - source_images=[source_image], - detections=image_detections[image_id], - total_time=batch_elapsed - / len(image_ids), # Approximate time per image + # TODO CGJS: Add seconds per item calculation for both detector and classifier + detector.save_results( + item_ids=image_ids, + batch_output=batch_output, + seconds_per_item=0, ) - - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=pipeline_response, + det_time, t = t() + total_detection_time += det_time + + # Group detections by image_id + image_detections: dict[str, list[DetectionResponse]] = { + img_id: [] for img_id in image_ids + } + image_tensors = dict(zip(image_ids, images, strict=True)) + + # Apply binary classification filter if needed + detector_results = detector.results + + if use_binary_filter: + assert binary_filter is not None, "Binary filter not initialized" + detections_for_terminal_classifier, detections_to_return = ( + _apply_binary_classification( + binary_filter, + detector_results, + image_tensors, + image_detections, + ) ) - ) - except Exception as e: - logger.error(f"Batch {i + 1} failed during processing: {e}", exc_info=True) - # Report errors back to Antenna so tasks aren't stuck in the queue - batch_results = [] - for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True): - batch_results.append( - AntennaTaskResult( - reply_subject=reply_subject, - result=AntennaTaskResultError( - error=f"Batch processing error: {e}", - image_id=image_id, - ), + else: + # No binary filtering, send all detections to terminal classifier + detections_for_terminal_classifier = detector_results + detections_to_return = [] + + # Run terminal classifier on filtered detections + classifier.reset(detections_for_terminal_classifier) + to_pil = torchvision.transforms.ToPILImage() + classify_transforms = classifier.get_transforms() + + # Collect and transform all crops for batched classification + crops = [] + valid_indices = [] + for idx, dresp in enumerate(detections_for_terminal_classifier): + image_tensor = image_tensors[dresp.source_image_id] + bbox = dresp.bbox + y1, y2 = int(bbox.y1), int(bbox.y2) + x1, x2 = int(bbox.x1), int(bbox.x2) + if y1 >= y2 or x1 >= x2: + logger.warning( + f"Skipping detection {idx} with invalid bbox: " + f"({x1},{y1})->({x2},{y2})" + ) + continue + crop = image_tensor[:, y1:y2, x1:x2] + crop_pil = to_pil(crop) + crop_transformed = classify_transforms(crop_pil) + crops.append(crop_transformed) + valid_indices.append(idx) + + if crops: + batched_crops = torch.stack(crops) + classifier_out = classifier.predict_batch(batched_crops) + classifier_out = classifier.post_process_batch(classifier_out) + + for crop_i, idx in enumerate(valid_indices): + dresp = detections_for_terminal_classifier[idx] + detection = classifier.update_detection_classification( + seconds_per_item=0, + image_id=dresp.source_image_id, + detection_idx=idx, + predictions=classifier_out[crop_i], + ) + image_detections[dresp.source_image_id].append(detection) + all_detections.append(detection) + + cls_time, t = t() + total_classification_time += cls_time + # Add non-moth detections to all_detections + all_detections.extend(detections_to_return) + + # Calculate batch processing time + batch_end_time = datetime.datetime.now() + batch_elapsed = (batch_end_time - batch_start_time).total_seconds() + + # Post results back to the API with PipelineResponse for each image + batch_results.clear() + for reply_subject, image_id, image_url in zip( + reply_subjects, image_ids, image_urls, strict=True + ): + # Create SourceImageResponse for this image + source_image = SourceImageResponse(id=image_id, url=image_url) + + # Create PipelineResultsResponse + pipeline_response = PipelineResultsResponse( + pipeline=pipeline, + source_images=[source_image], + detections=image_detections[image_id], + total_time=batch_elapsed + / len(image_ids), # Approximate time per image ) - ) - failed_items = batch.get("failed_items") - if failed_items: - for failed_item in failed_items: - batch_results.append( - AntennaTaskResult( - reply_subject=failed_item.get("reply_subject"), - result=AntennaTaskResultError( - error=failed_item.get("error", "Unknown error"), - image_id=failed_item.get("image_id"), - ), + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=pipeline_response, + ) ) + except Exception as e: + logger.error( + f"Batch {i + 1} failed during processing: {e}", exc_info=True ) + # Report errors back to Antenna so tasks aren't stuck in the queue + batch_results = [] + for reply_subject, image_id in zip( + reply_subjects, image_ids, strict=True + ): + batch_results.append( + AntennaTaskResult( + reply_subject=reply_subject, + result=AntennaTaskResultError( + error=f"Batch processing error: {e}", + image_id=image_id, + ), + ) + ) - success = post_batch_results( - settings.antenna_api_base_url, - settings.antenna_api_auth_token, - job_id, - batch_results, - processing_service_name, - ) - st, t = t("Finished posting results") + failed_items = batch.get("failed_items") + if failed_items: + for failed_item in failed_items: + batch_results.append( + AntennaTaskResult( + reply_subject=failed_item.get("reply_subject"), + result=AntennaTaskResultError( + error=failed_item.get("error", "Unknown error"), + image_id=failed_item.get("image_id"), + ), + ) + ) - if not success: - logger.error( - f"Failed to post {len(batch_results)} results for job {job_id} to " - f"{settings.antenna_api_base_url}. Batch processing data lost." + # Post results asynchronously (non-blocking) + result_poster.post_async( + settings.antenna_api_base_url, + settings.antenna_api_auth_token, + job_id, + batch_results, + processing_service_name, + ) + _, t = log_time() # reset time to measure batch load time + logger.info( + f"Finished batch {i + 1}. Total items: {items}, Classification time: {cls_time:.2f}s, Detection time: {det_time:.2f}s, Load time: {load_time:.2f}s" ) - total_save_time += st + if result_poster: + # Wait for all async posts to complete before finishing the job + logger.info("Waiting for all pending result posts to complete...") + result_poster.wait_for_all_posts(min_timeout=60, per_post_timeout=30) - logger.info( - f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time}, " - f"classification time: {total_classification_time}, dl time: {total_dl_time}, save time: {total_save_time}" - ) - return did_work + # Get final metrics + post_metrics = result_poster.get_metrics() + + logger.info( + f"Done, detections: {len(all_detections)}. Detecting time: {total_detection_time:.2f}s, " + f"classification time: {total_classification_time:.2f}s, dl time: {total_dl_time:.2f}s, " + f"result posts: {post_metrics.total_posts} " + f"(success: {post_metrics.successful_posts}, failed: {post_metrics.failed_posts}, " + f"success rate: {post_metrics.success_rate:.1f}%, avg post time: " + f"{post_metrics.total_post_time / post_metrics.total_posts if post_metrics.total_posts > 0 else 0:.2f}s, " + f"max queue size: {post_metrics.max_queue_size})" + ) + return did_work + finally: + if result_poster: + result_poster.shutdown() diff --git a/trapdata/settings.py b/trapdata/settings.py index 54324c3..d37c34d 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -35,7 +35,7 @@ class Settings(BaseSettings): classification_threshold: float = 0.6 localization_batch_size: int = 8 classification_batch_size: int = 20 - num_workers: int = 2 + num_workers: int = 4 # Antenna API worker settings antenna_api_base_url: str = "http://localhost:8000/api/v2"