Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions trapdata/antenna/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def run_benchmark(
num_workers: int,
batch_size: int,
gpu_batch_size: int,
service_name: str,
send_acks: bool = True,
) -> None:
"""Run the benchmark with the specified parameters.
Expand All @@ -65,7 +64,6 @@ def run_benchmark(
num_workers: Number of DataLoader workers
batch_size: Batch size for API requests
gpu_batch_size: GPU batch size for DataLoader
service_name: Processing service name
"""
# Create settings object
settings = Settings()
Expand All @@ -81,14 +79,12 @@ def run_benchmark(
print(f" API batch size: {batch_size}")
print(f" GPU batch size: {gpu_batch_size}")
print(f" Num workers: {num_workers}")
print(f" Service name: {service_name}")
print()

# Create dataloader
dataloader = get_rest_dataloader(
job_id=job_id,
settings=settings,
processing_service_name=service_name,
)

# Initialize ResultPoster for sending acknowledgments
Expand Down Expand Up @@ -141,7 +137,6 @@ def run_benchmark(
auth_token=auth_token,
job_id=job_id,
results=ack_results,
processing_service_name=service_name,
)
total_acks_sent += len(ack_results)

Expand All @@ -164,7 +159,6 @@ def run_benchmark(
auth_token=auth_token,
job_id=job_id,
results=error_results,
processing_service_name=service_name,
)
total_acks_sent += len(error_results)

Expand Down Expand Up @@ -275,12 +269,6 @@ def main() -> int:
parser.add_argument(
"--gpu-batch-size", type=int, default=16, help="GPU batch size for DataLoader"
)
parser.add_argument(
"--service-name",
type=str,
default="Performance Test",
help="Processing service name",
)
parser.add_argument(
"--skip-acks",
action="store_false",
Expand All @@ -303,7 +291,6 @@ def main() -> int:
num_workers=args.num_workers,
batch_size=args.batch_size,
gpu_batch_size=args.gpu_batch_size,
service_name=args.service_name,
send_acks=args.skip_acks,
)
return 0
Expand Down
21 changes: 11 additions & 10 deletions trapdata/antenna/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from trapdata.antenna.schemas import (
AntennaJobsListResponse,
AntennaResultPostResponse,
AntennaTaskResult,
AntennaTaskResults,
JobDispatchMode,
)
from trapdata.api.utils import get_http_session
Expand All @@ -30,17 +32,15 @@ def get_jobs(
base_url: str,
auth_token: str,
pipeline_slugs: list[str],
processing_service_name: str,
) -> list[tuple[int, str]]:
"""Fetch job ids from the API for the given pipelines in a single request.

Calls: GET {base_url}/jobs?pipeline__slug__in=<slugs>&ids_only=1&processing_service_name=<name>
Calls: GET {base_url}/jobs?pipeline__slug__in=<slugs>&ids_only=1

Args:
base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
pipeline_slugs: List of pipeline slugs to filter jobs
processing_service_name: Name of the processing service

Returns:
List of (job_id, pipeline_slug) tuples (possibly empty) on success or error.
Expand All @@ -54,7 +54,6 @@ def get_jobs(
"pipeline__slug__in": ",".join(pipeline_slugs),
"ids_only": 1,
"incomplete_only": 1,
"processing_service_name": processing_service_name,
"dispatch_mode": JobDispatchMode.ASYNC_API, # Only fetch async_api jobs
}

Expand All @@ -77,7 +76,6 @@ def post_batch_results(
auth_token: str,
job_id: int,
results: list[AntennaTaskResult],
processing_service_name: str,
) -> bool:
"""
Post batch results back to the API.
Expand All @@ -87,20 +85,23 @@ def post_batch_results(
auth_token: API authentication token
job_id: Job ID
results: List of AntennaTaskResult objects
processing_service_name: Name of the processing service

Returns:
True if successful, False otherwise
"""
url = f"{base_url.rstrip('/')}/jobs/{job_id}/result/"
payload = [r.model_dump(mode="json") for r in results]
payload = AntennaTaskResults(results=results)

with get_http_session(auth_token) as session:
try:
params = {"processing_service_name": processing_service_name}
response = session.post(url, json=payload, params=params, timeout=60)
response = session.post(
url, json=payload.model_dump(mode="json"), timeout=60
)
response.raise_for_status()
logger.debug(f"Successfully posted {len(results)} results to {url}")
result = AntennaResultPostResponse.model_validate(response.json())
logger.debug(
f"Posted {len(results)} results to job {job_id}: {result.results_queued} queued"
)
return True
except requests.RequestException as e:
logger.error(f"Failed to post results to {url}: {e}")
Expand Down
24 changes: 9 additions & 15 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
├──────────────────────────────────────────────────────────────────┤
│ DataLoader workers (num_workers subprocesses) │
│ Each subprocess runs its own RESTDataset.__iter__ loop: │
│ 1. GET /tasks → fetch batch of task metadata from Antenna │
│ 1. POST /tasks → fetch batch of task metadata from Antenna │
│ 2. Download images (threaded, see below) │
│ 3. Yield individual (image_tensor, metadata) rows │
│ The DataLoader collates rows into GPU-sized batches. │
Expand Down Expand Up @@ -76,6 +76,7 @@
from trapdata.antenna.schemas import (
AntennaPipelineProcessingTask,
AntennaTasksListResponse,
AntennaTasksRequest,
)
from trapdata.api.utils import get_http_session
from trapdata.common.logs import logger
Expand All @@ -97,8 +98,8 @@ class RESTDataset(torch.utils.data.IterableDataset):
independently fetches different tasks from the shared queue.

With DataLoader num_workers > 0 (I/O subprocesses, not AMI instances):
Subprocess 1: GET /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: GET /tasks → receives [5,6,7,8], removed from queue
Subprocess 1: POST /tasks → receives [1,2,3,4], removed from queue
Subprocess 2: POST /tasks → receives [5,6,7,8], removed from queue
No duplicates, safe for parallel processing
"""

Expand All @@ -109,7 +110,6 @@ def __init__(
job_id: int,
batch_size: int = 1,
image_transforms: torchvision.transforms.Compose | None = None,
processing_service_name: str = "",
):
"""
Initialize the REST dataset.
Expand All @@ -120,15 +120,13 @@ def __init__(
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
processing_service_name: Name of the processing service
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.processing_service_name = processing_service_name

# These are created lazily in _ensure_sessions() because they contain
# unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
Expand Down Expand Up @@ -170,15 +168,14 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
Raises:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {
"batch": self.batch_size,
"processing_service_name": self.processing_service_name,
}
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks/"
request_body = AntennaTasksRequest(batch_size=self.batch_size)

self._ensure_sessions()
assert self._api_session is not None
response = self._api_session.get(url, params=params, timeout=30)
response = self._api_session.post(
url, json=request_body.model_dump(), timeout=30
)
response.raise_for_status()

# Parse and validate response with Pydantic
Expand Down Expand Up @@ -410,7 +407,6 @@ def _no_op_collate_fn(batch: list[dict]) -> dict:
def get_rest_dataloader(
job_id: int,
settings: "Settings",
processing_service_name: str,
) -> torch.utils.data.DataLoader:
"""Create a DataLoader that fetches tasks from Antenna API.

Expand All @@ -427,14 +423,12 @@ def get_rest_dataloader(
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
processing_service_name=processing_service_name,
)

return torch.utils.data.DataLoader(
Expand Down
14 changes: 3 additions & 11 deletions trapdata/antenna/result_posting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

Usage:
poster = ResultPoster(max_pending=5)
poster.post_async(base_url, auth_token, job_id, results, service_name)
poster.post_async(base_url, auth_token, job_id, results)
metrics = poster.get_metrics()
poster.shutdown()
"""
Expand All @@ -23,7 +23,6 @@
import time
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass
from typing import Optional

from trapdata.antenna.client import post_batch_results
from trapdata.common.logs import logger
Expand Down Expand Up @@ -61,7 +60,7 @@ class ResultPoster:

Example:
poster = ResultPoster(max_pending=10)
poster.post_async(base_url, auth_token, job_id, results, service_name)
poster.post_async(base_url, auth_token, job_id, results)
metrics = poster.get_metrics()
poster.shutdown()
"""
Expand All @@ -86,7 +85,6 @@ def post_async(
auth_token: str,
job_id: int,
results: list,
processing_service_name: str,
) -> None:
"""Post results asynchronously with backpressure control.

Expand All @@ -98,7 +96,6 @@ def post_async(
auth_token: API authentication token
job_id: Job ID for the results
results: List of result objects to post
processing_service_name: Name of the processing service
"""
# Clean up completed futures and update metrics
self._cleanup_completed_futures()
Expand Down Expand Up @@ -140,7 +137,6 @@ def post_async(
auth_token,
job_id,
results,
processing_service_name,
start_time,
)
self.pending_futures.append(future)
Expand All @@ -156,7 +152,6 @@ def _post_with_timing(
auth_token: str,
job_id: int,
results: list,
processing_service_name: str,
start_time: float,
) -> bool:
"""Internal method that times the post operation and updates metrics.
Expand All @@ -166,16 +161,13 @@ def _post_with_timing(
auth_token: API authentication token
job_id: Job ID for the results
results: List of result objects to post
processing_service_name: Name of the processing service
start_time: Timestamp when the post was initiated

Returns:
True if successful, False otherwise
"""
try:
success = post_batch_results(
base_url, auth_token, job_id, results, processing_service_name
)
success = post_batch_results(base_url, auth_token, job_id, results)
elapsed_time = time.time() - start_time

with self._metrics_lock:
Expand Down
25 changes: 24 additions & 1 deletion trapdata/antenna/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,14 @@ class AntennaJobsListResponse(pydantic.BaseModel):
results: list[AntennaJobListItem]


class AntennaTasksRequest(pydantic.BaseModel):
    """Request body for POST /api/v2/jobs/{job_id}/tasks/.

    Sent by a worker (see RESTDataset._fetch_tasks) to request a batch of
    pending processing tasks for a job from the Antenna API.
    """

    # Number of tasks requested per API call; must be strictly positive.
    batch_size: int = pydantic.Field(gt=0)


class AntennaTasksListResponse(pydantic.BaseModel):
"""Response from Antenna API GET /api/v2/jobs/{job_id}/tasks."""
"""Response from Antenna API POST /api/v2/jobs/{job_id}/tasks/."""

tasks: list[AntennaPipelineProcessingTask]

Expand All @@ -60,6 +66,23 @@ class AntennaTaskResults(pydantic.BaseModel):
results: list[AntennaTaskResult] = pydantic.Field(default_factory=list)


class QueuedTaskAcknowledgment(pydantic.BaseModel):
    """Acknowledgment for a single result queued for background processing.

    One entry is returned per queued result in the ``tasks`` list of the
    POST /jobs/{job_id}/result/ response.
    """

    # Subject/channel on which the server will publish the processing
    # outcome — presumably a message-queue reply subject; confirm against
    # the Antenna server implementation.
    reply_subject: str
    # Queued-state indicator for this individual result.
    status: str
    # Identifier of the task this acknowledgment corresponds to.
    task_id: str


class AntennaResultPostResponse(pydantic.BaseModel):
    """Response from POST /api/v2/jobs/{job_id}/result/.

    Validated by the client after posting batch results; summarizes how
    many results the server accepted and queued for background processing.
    """

    # Overall outcome reported by the server for the whole POST.
    status: str
    # The job the posted results belong to.
    job_id: int
    # Count of results accepted and queued by the server; logged client-side.
    results_queued: int
    # Per-result acknowledgments; may be absent/empty, hence default_factory.
    tasks: list[QueuedTaskAcknowledgment] = pydantic.Field(default_factory=list)


class AsyncPipelineRegistrationRequest(pydantic.BaseModel):
"""
Request to register pipelines from an async processing service
Expand Down
Loading
Loading