Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions trapdata/antenna/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def run_benchmark(
batch_size: int,
gpu_batch_size: int,
service_name: str,
send_acks: bool = True,
) -> None:
"""Run the benchmark with the specified parameters.

Expand Down Expand Up @@ -132,8 +133,8 @@ def run_benchmark(
ack_result = create_empty_result(reply_subject, image_id)
ack_results.append(ack_result)

logger.info(f"Sending {len(ack_results)} acknowledgment(s)")
if ack_results:
if ack_results and send_acks:
logger.info(f"Sending {len(ack_results)} acknowledgment(s)")
# Send acknowledgments asynchronously
result_poster.post_async(
base_url=base_url,
Expand All @@ -157,7 +158,7 @@ def run_benchmark(
)
error_results.append(error_result)

if error_results:
if error_results and send_acks:
result_poster.post_async(
base_url=base_url,
auth_token=auth_token,
Expand Down Expand Up @@ -280,6 +281,11 @@ def main() -> int:
default="Performance Test",
help="Processing service name",
)
parser.add_argument(
"--skip-acks",
action="store_false",
help="Skip sending acknowledgments for processed images",
)

args = parser.parse_args()

Expand All @@ -298,6 +304,7 @@ def main() -> int:
batch_size=args.batch_size,
gpu_batch_size=args.gpu_batch_size,
service_name=args.service_name,
send_acks=args.skip_acks,
)
return 0

Expand Down
71 changes: 63 additions & 8 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
localization_batch_size (default 8)
How many images the GPU processes at once (detection). Larger =
more GPU memory. These are full-resolution images (~4K).
        Async workers use antenna_api_batch_size for this.

num_workers (default 4)
DataLoader subprocesses. Each independently fetches tasks and
Expand Down Expand Up @@ -254,7 +255,7 @@ def __iter__(self):

Each API fetch returns a batch of tasks. Images for the entire batch
are downloaded concurrently using threads (see _load_images_threaded),
then yielded one at a time for the DataLoader to collate.
    then yielded as a single pre-collated batch.

Yields:
Dictionary containing:
Expand Down Expand Up @@ -293,7 +294,7 @@ def __iter__(self):

# Download all images concurrently
image_map = self._load_images_threaded(tasks)

pre_batch = []
for task in tasks:
image_tensor = image_map.get(task.image_id)
errors = []
Expand All @@ -315,7 +316,11 @@ def __iter__(self):
}
if errors:
row["error"] = "; ".join(errors) if errors else None
yield row
pre_batch.append(row)
batch = rest_collate_fn(
pre_batch
) # Collate before yielding to GPU process
yield batch

logger.debug(f"Worker {worker_id}: Iterator finished")
except Exception as e:
Expand Down Expand Up @@ -360,7 +365,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
# Collate successful items
if successful:
result = {
"images": [item["image"] for item in successful],
"images": torch.stack([item["image"] for item in successful]),
"reply_subjects": [item["reply_subject"] for item in successful],
"image_ids": [item["image_id"] for item in successful],
"image_urls": [item.get("image_url") for item in successful],
Expand All @@ -377,6 +382,17 @@ def rest_collate_fn(batch: list[dict]) -> dict:
return result


def _no_op_collate_fn(batch: list[dict]) -> dict:
"""
A no-op collate function that unwraps a single-element batch.

This can be used when the dataset already returns batches in the desired format,
and no further collation is needed. It simply returns the input list of dicts
without modification.
"""
return batch[0]


def get_rest_dataloader(
job_id: int,
settings: "Settings",
Expand All @@ -395,8 +411,7 @@ def get_rest_dataloader(
job_id: Job ID to fetch tasks for
settings: Settings object. Relevant fields:
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call)
- localization_batch_size (images per GPU batch)
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
Expand All @@ -410,7 +425,47 @@ def get_rest_dataloader(

return torch.utils.data.DataLoader(
dataset,
batch_size=settings.localization_batch_size,
batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here
num_workers=settings.num_workers,
collate_fn=rest_collate_fn,
collate_fn=_no_op_collate_fn,
pin_memory=True,
persistent_workers=settings.num_workers > 0,
prefetch_factor=4 if settings.num_workers > 0 else None,
)


class CUDAPrefetcher:
    """Wrap a DataLoader and overlap host-to-device copies with compute.

    While the consumer works on the current batch, the next batch is
    copied to ``device`` on a dedicated CUDA side stream, hiding transfer
    latency. Only ``torch.Tensor`` values in the batch dict are moved;
    all other values pass through unchanged. Requires a CUDA device.
    """

    def __init__(self, loader: torch.utils.data.DataLoader, device: torch.device):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.device = device
        self.next_batch = None
        self._preload()

    def _preload(self):
        """Fetch the next batch and start its async copy on the side stream."""
        try:
            batch = next(self.loader)
        except StopIteration:
            # Exhausted: signal __next__ to stop.
            self.next_batch = None
            return

        with torch.cuda.stream(self.stream):
            # non_blocking=True lets the copy overlap compute; this is only
            # truly async when the source tensors are in pinned memory
            # (the dataloader is expected to use pin_memory=True).
            self.next_batch = {
                k: (
                    v.to(self.device, non_blocking=True)
                    if isinstance(v, torch.Tensor)
                    else v
                )
                for k, v in batch.items()
            }

    def __iter__(self):
        return self

    def __next__(self):
        # Block the consumer stream until the prefetch copy has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        if batch is None:
            raise StopIteration
        # Tell the caching allocator these tensors are now used on the
        # current stream; without record_stream, memory allocated on the
        # side stream could be freed and reused while the current stream
        # still reads it (see PyTorch CUDA stream semantics).
        for v in batch.values():
            if isinstance(v, torch.Tensor):
                v.record_stream(torch.cuda.current_stream())
        self._preload()
        return batch
19 changes: 15 additions & 4 deletions trapdata/antenna/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ def test_multiple_batches(self):
dataset = self._make_dataset(job_id=4, batch_size=2)
rows = list(dataset)

# Should get all 3 images (batch1: 2 images, batch2: 1 image)
assert len(rows) == 3
assert all(r["image"] is not None for r in rows)
# Dataset now yields pre-collated batches: batch1 (2 images), batch2 (1 image)
assert len(rows) == 2
total_images = sum(len(r["image_ids"]) for r in rows)
assert total_images == 3
assert all(r["images"] is not None for r in rows)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -272,6 +274,7 @@ def test_empty_queue(self):
100,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is False
Expand Down Expand Up @@ -300,6 +303,7 @@ def test_processes_batch_with_real_inference(self):
101,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

# Validate processing succeeded
Expand Down Expand Up @@ -339,6 +343,7 @@ def test_handles_failed_items(self):
102,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

posted_results = antenna_api_server.get_posted_results(102)
Expand Down Expand Up @@ -375,6 +380,7 @@ def test_mixed_batch_success_and_failures(self):
103,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is True
Expand Down Expand Up @@ -475,7 +481,11 @@ def test_full_workflow_with_real_inference(self):

# Step 3: Process job
result = _process_job(
pipeline_slug, 200, self._make_settings(), "Test Worker"
pipeline_slug,
200,
self._make_settings(),
"Test Worker",
device=torch.device("cpu"),
)
assert result is True

Expand Down Expand Up @@ -527,6 +537,7 @@ def test_multiple_batches_processed(self):
201,
self._make_settings(),
"Test Service",
device=torch.device("cpu"),
)

assert result is True
Expand Down
Loading