Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions ami/jobs/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import pydantic
from drf_spectacular.utils import OpenApiParameter


class QueuedTaskAcknowledgment(pydantic.BaseModel):
    """Acknowledgment for a single result that was queued for background processing."""

    # NATS reply subject that correlates this acknowledgment back to the
    # original task handed out by the /tasks/ endpoint.
    reply_subject: str
    # Queueing outcome for this result (presumably "queued" on success —
    # confirm the exact vocabulary against the view that builds these entries).
    status: str
    # Celery task id of the background job that will process this result,
    # included so callers can trace the result through the worker logs.
    task_id: str


ids_only_param = OpenApiParameter(
name="ids_only",
description="Return only job IDs instead of full objects",
Expand All @@ -13,10 +23,3 @@
required=False,
type=bool,
)

# OpenAPI documentation for the "batch" query parameter: how many queued
# tasks a processing service wants to retrieve in one poll.
batch_param = OpenApiParameter(
    name="batch",
    description="Number of tasks to retrieve",
    required=False,
    type=int,
)
52 changes: 52 additions & 0 deletions ami/jobs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
)
from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
from ami.ml.models import Pipeline
from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult, ProcessingServiceClientInfo
from ami.ml.serializers import PipelineNestedSerializer

from .models import Job, JobLogs, JobProgress, MLJob
from .schemas import QueuedTaskAcknowledgment


class JobProjectNestedSerializer(DefaultSerializer):
Expand Down Expand Up @@ -163,3 +165,53 @@ class MinimalJobSerializer(DefaultSerializer):
class Meta:
model = Job
fields = ["id", "pipeline_slug"]


class MLJobTasksRequestSerializer(serializers.Serializer):
    """POST /jobs/{id}/tasks/ — request body sent by a processing service to fetch work.

    The processing service polls this endpoint to get tasks (images) to process.
    Each task is a PipelineProcessingTask with an image URL and a NATS reply subject.
    """

    # Number of tasks the worker wants to reserve in this poll; must be >= 1
    # and is required (a missing value is a 400, not an implicit default).
    batch_size = serializers.IntegerField(min_value=1, required=True)
    # Optional metadata identifying the polling worker; absent requests
    # validate to None rather than an empty object.
    client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None)


class MLJobTasksResponseSerializer(serializers.Serializer):
    """POST /jobs/{id}/tasks/ — response body returned to the processing service.

    Contains a list of tasks (PipelineProcessingTask dicts) for the worker to process.
    Each task includes an image URL, task ID, and reply_subject for result correlation.
    Returns an empty list when no tasks are available or the job is not active.
    """

    # List of reserved tasks; empty when the queue has nothing to hand out.
    # NOTE(review): default=[] is a shared mutable default — DRF deep-copies
    # declared fields per serializer instance so this is likely benign, but
    # confirm SchemaField does not mutate it; a callable default would be safer.
    tasks = SchemaField(schema=list[PipelineProcessingTask], default=[])


class MLJobResultsRequestSerializer(serializers.Serializer):
    """POST /jobs/{id}/result/ — request body sent by a processing service to deliver results.

    "Request" here refers to the HTTP request to Antenna, not a request for work.
    The processing service has finished processing tasks and is posting its results
    (successes or errors) back. Each PipelineTaskResult contains a reply_subject
    (correlating back to the original task) and a result payload that is either a
    PipelineResultsResponse (success) or PipelineResultsError (failure).
    """

    # Required list of completed-task payloads; no default, so a body without
    # a "results" key fails validation (bare lists are rejected by design).
    results = SchemaField(schema=list[PipelineTaskResult])
    # Optional metadata identifying the worker that produced these results.
    client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None)


class MLJobResultsResponseSerializer(serializers.Serializer):
    """POST /jobs/{id}/result/ — acknowledgment returned to the processing service.

    Confirms receipt and indicates how many results were queued for background
    processing via Celery. Individual task entries include their Celery task_id
    for traceability.
    """

    # Overall outcome of the submission (e.g. "accepted" — exact values are
    # set by the view building this response; confirm against views.py).
    status = serializers.CharField()
    # Primary key of the job these results belong to.
    job_id = serializers.IntegerField()
    # Count of results successfully handed to Celery for background processing.
    results_queued = serializers.IntegerField()
    # Per-result acknowledgments (reply_subject / status / Celery task_id).
    # NOTE(review): default=[] is a shared mutable default — likely benign
    # given DRF's per-instance field deep-copy, but verify.
    tasks = SchemaField(schema=list[QueuedTaskAcknowledgment], default=[])
61 changes: 29 additions & 32 deletions ami/jobs/tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,8 @@ def _task_batch_helper(self, value: Any, expected_status: int):
queue_images_to_nats(job, images)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": value}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch_size": value}, format="json")
self.assertEqual(resp.status_code, expected_status)
return resp.json()

Expand Down Expand Up @@ -523,10 +521,8 @@ def test_tasks_endpoint_without_pipeline(self):
)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")

self.assertEqual(resp.status_code, 400)
self.assertIn("pipeline", resp.json()[0].lower())
Expand All @@ -537,23 +533,23 @@ def test_result_endpoint_stub(self):
job = self._create_ml_job("Job for results test", pipeline)

self.client.force_authenticate(user=self.user)
result_url = reverse_with_params(
"api:job-result", args=[job.pk], params={"project_id": self.project.pk, "batch": 1}
)
result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk})

result_data = [
{
"reply_subject": "test.reply.1",
"result": {
"pipeline": "test-pipeline",
"algorithms": {},
"total_time": 1.5,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
result_data = {
"results": [
{
"reply_subject": "test.reply.1",
"result": {
"pipeline": "test-pipeline",
"algorithms": {},
"total_time": 1.5,
"source_images": [],
"detections": [],
"errors": None,
},
}
]
}

resp = self.client.post(result_url, result_data, format="json")

Expand All @@ -572,16 +568,19 @@ def test_result_endpoint_validation(self):
result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk})

# Test with missing reply_subject
invalid_data = [{"result": {"pipeline": "test"}}]
invalid_data = {"results": [{"result": {"pipeline": "test"}}]}
resp = self.client.post(result_url, invalid_data, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("reply_subject", resp.json()[0].lower())

# Test with missing result
invalid_data = [{"reply_subject": "test.reply"}]
invalid_data = {"results": [{"reply_subject": "test.reply"}]}
resp = self.client.post(result_url, invalid_data, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("result", resp.json()[0].lower())

# Test with bare list (no longer accepted)
bare_list = [{"reply_subject": "test.reply", "result": {"pipeline": "test"}}]
resp = self.client.post(result_url, bare_list, format="json")
self.assertEqual(resp.status_code, 400)


class TestJobDispatchModeFiltering(APITestCase):
Expand Down Expand Up @@ -722,9 +721,7 @@ def test_tasks_endpoint_rejects_non_async_jobs(self):
)

self.client.force_authenticate(user=self.user)
tasks_url = reverse_with_params(
"api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1}
)
resp = self.client.get(tasks_url)
tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk})
resp = self.client.post(tasks_url, {"batch_size": 1}, format="json")
self.assertEqual(resp.status_code, 400)
self.assertIn("async_api", resp.json()[0].lower())
22 changes: 12 additions & 10 deletions ami/jobs/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,16 +384,18 @@ def test_result_endpoint_with_error_result(self, mock_apply_async):
self.client.force_authenticate(user=self.user)
result_url = reverse_with_params("api:job-result", args=[self.job.pk], params={"project_id": self.project.pk})

# Create error result data
result_data = [
{
"reply_subject": "test.reply.error.1",
"result": {
"error": "Image processing timeout",
"image_id": str(self.image.pk),
},
}
]
# Create error result data (wrapped format)
result_data = {
"results": [
{
"reply_subject": "test.reply.error.1",
"result": {
"error": "Image processing timeout",
"image_id": str(self.image.pk),
},
}
]
}

# POST error result to API
resp = self.client.post(result_url, result_data, format="json")
Expand Down
79 changes: 41 additions & 38 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging

import kombu.exceptions
import nats.errors
import pydantic
from asgiref.sync import async_to_sync
from django.db.models import Q
from django.db.models.query import QuerySet
Expand All @@ -17,11 +17,16 @@

from ami.base.permissions import ObjectPermission
from ami.base.views import ProjectMixin
from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param
from ami.jobs.schemas import ids_only_param, incomplete_only_param
from ami.jobs.serializers import (
MLJobResultsRequestSerializer,
MLJobResultsResponseSerializer,
MLJobTasksRequestSerializer,
MLJobTasksResponseSerializer,
)
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.ml.schemas import PipelineTaskResult
from ami.utils.fields import url_boolean_param

from .models import Job, JobDispatchMode, JobState
Expand Down Expand Up @@ -238,24 +243,25 @@ def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

@extend_schema(
parameters=[batch_param],
responses={200: dict},
request=MLJobTasksRequestSerializer,
responses={200: MLJobTasksResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["get"], name="tasks")
@action(detail=True, methods=["post"], name="tasks")
def tasks(self, request, pk=None):
"""
Get tasks from the job queue.
Fetch tasks from the job queue (POST).

Returns task data with reply_subject for acknowledgment. External workers should:
1. Call this endpoint to get tasks
1. POST to this endpoint with {"batch_size": N}
2. Process the tasks
3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge
3. POST to /jobs/{id}/result/ with the results
"""
serializer = MLJobTasksRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
batch_size = serializer.validated_data["batch_size"]

job: Job = self.get_object()
try:
batch = IntegerField(required=True, min_value=1).clean(request.query_params.get("batch"))
except Exception as e:
raise ValidationError({"batch": str(e)}) from e

# Only async_api jobs have tasks fetchable from NATS
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
Expand All @@ -277,7 +283,7 @@ def tasks(self, request, pk=None):

async def get_tasks():
async with TaskQueueManager() as manager:
return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)]
return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch_size, timeout=0.5)]

try:
tasks = async_to_sync(get_tasks)()
Expand All @@ -287,35 +293,31 @@ async def get_tasks():

return Response({"tasks": tasks})

@extend_schema(
request=MLJobResultsRequestSerializer,
responses={200: MLJobResultsResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="result")
def result(self, request, pk=None):
"""
The request body should be a list of results: list[PipelineTaskResult]
Submit pipeline results.

This endpoint accepts a list of pipeline results and queues them for
background processing. Each result will be validated, saved to the database,
and acknowledged via NATS in a Celery task.
Accepts: {"results": [PipelineTaskResult, ...]}

Results are validated then queued for background processing via Celery.
"""

job = self.get_object()

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
else:
results = [request.data]
serializer = MLJobResultsRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
validated_results = serializer.validated_data["results"]

try:
# Pre-validate all results before enqueuing any tasks
# This prevents partial queueing and duplicate task processing
validated_results = []
for item in results:
task_result = PipelineTaskResult(**item)
validated_results.append(task_result)

# All validation passed, now queue all tasks
queued_tasks = []
for task_result in validated_results:
Expand All @@ -337,27 +339,28 @@ def result(self, request, pk=None):
)

logger.info(
f"Queued pipeline result processing for job {job.pk}, "
f"task_id: {task.id}, reply_subject: {reply_subject}"
"Queued pipeline result for job %s, task_id: %s, reply_subject: %s",
job.pk,
task.id,
reply_subject,
)

return Response(
{
"status": "accepted",
"job_id": job.pk,
"results_queued": len([t for t in queued_tasks if t["status"] == "queued"]),
"results_queued": len(queued_tasks),
"tasks": queued_tasks,
}
)
except pydantic.ValidationError as e:
raise ValidationError(f"Invalid result data: {e}") from e

except Exception as e:
logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}")
except (OSError, kombu.exceptions.KombuError) as e:
logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e)
return Response(
{
"status": "error",
"job_id": job.pk,
"detail": "Task queue temporarily unavailable",
},
status=500,
status=503,
)
Loading
Loading