From 039d714463d7336f303db9a733f03993b3ec3d8e Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 10:38:53 -0700 Subject: [PATCH 01/12] refactor: convert /tasks endpoint from GET to POST with TasksRequestSerializer Convert the /jobs/{id}/tasks/ endpoint from GET with query params to POST with request body. This allows for validated request data including optional client_info field, better for processing service workers that need to report their identity. Changes: - Add ami/ml/serializers_client_info.py with ClientInfoSerializer and get_client_info() helper - Add TasksRequestSerializer to ami/jobs/schemas.py for validating batch and client_info - Update /tasks action from GET to POST with serializer-based validation - Support wrapped format in /result endpoint: {"client_info": {...}, "results": [...]} - Update logger calls to use %s format (more efficient) - Update all tests to use POST instead of GET - Add TestClientInfoSerializer test class with validation tests No changes to _mark_pipeline_pull_services_seen() or permission_classes. Co-Authored-By: Claude --- ami/jobs/schemas.py | 8 +++++ ami/jobs/tests/test_jobs.py | 18 ++++------- ami/jobs/views.py | 49 ++++++++++++++++-------------- ami/ml/serializers_client_info.py | 50 +++++++++++++++++++++++++++++++ ami/ml/tests.py | 40 +++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 34 deletions(-) create mode 100644 ami/ml/serializers_client_info.py diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 0e1ea4ac7..5bd3dc2e9 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,4 +1,7 @@ from drf_spectacular.utils import OpenApiParameter +from rest_framework import serializers + +from ami.ml.serializers_client_info import ClientInfoSerializer ids_only_param = OpenApiParameter( name="ids_only", @@ -20,3 +23,8 @@ required=False, type=int, ) + + +class TasksRequestSerializer(serializers.Serializer): + batch = serializers.IntegerField(min_value=1, required=True) + client_info = ClientInfoSerializer(required=False) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 7f2607bfe..5b2f05ba7 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -489,10 +489,8 @@ def _task_batch_helper(self, value: Any, expected_status: int): queue_images_to_nats(job, images) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": value} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch": value}, format="json") self.assertEqual(resp.status_code, expected_status) return resp.json() @@ -523,10 +521,8 @@ def test_tasks_endpoint_without_pipeline(self): ) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("pipeline", resp.json()[0].lower()) @@ -722,9 +718,7 @@ def test_tasks_endpoint_rejects_non_async_jobs(self): ) self.client.force_authenticate(user=self.user) - tasks_url = reverse_with_params( - "api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk, "batch": 1} - ) - resp = self.client.get(tasks_url) + tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk}) + resp = self.client.post(tasks_url, {"batch": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("async_api", resp.json()[0].lower()) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 832e15f30..ab32da33d 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -17,7 +17,7 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import batch_param, ids_only_param, incomplete_only_param +from ami.jobs.schemas import TasksRequestSerializer, ids_only_param, incomplete_only_param from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet @@ -238,24 +238,24 @@ def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @extend_schema( - parameters=[batch_param], + request=TasksRequestSerializer, responses={200: dict}, ) - @action(detail=True, methods=["get"], name="tasks") + @action(detail=True, methods=["post"], name="tasks") def tasks(self, request, pk=None): """ - Get tasks from the job queue. + Fetch tasks from the job queue (POST). Returns task data with reply_subject for acknowledgment. External workers should: - 1. Call this endpoint to get tasks + 1. POST to this endpoint with {"batch": N, "client_info": {...}} 2. Process the tasks - 3. POST to /jobs/{id}/result/ with the reply_subject to acknowledge + 3. POST to /jobs/{id}/result/ with the results """ + serializer = TasksRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + batch = serializer.validated_data["batch"] + job: Job = self.get_object() - try: - batch = IntegerField(required=True, min_value=1).clean(request.query_params.get("batch")) - except Exception as e: - raise ValidationError({"batch": str(e)}) from e # Only async_api jobs have tasks fetchable from NATS if job.dispatch_mode != JobDispatchMode.ASYNC_API: @@ -290,11 +290,12 @@ async def get_tasks(): @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ - The request body should be a list of results: list[PipelineTaskResult] + Submit pipeline results. + + Accepts: {"client_info": {...}, "results": [PipelineTaskResult, ...]} + Or legacy: [PipelineTaskResult, ...] (bare list) - This endpoint accepts a list of pipeline results and queues them for - background processing. Each result will be validated, saved to the database, - and acknowledged via NATS in a Celery task. + Results are validated then queued for background processing via Celery. """ job = self.get_object() @@ -302,17 +303,19 @@ def result(self, request, pk=None): # Record heartbeat for async processing services on this pipeline _mark_pipeline_pull_services_seen(job) - # Validate request data is a list + # Accept both wrapped format and legacy bare list if isinstance(request.data, list): - results = request.data + raw_results = request.data + elif isinstance(request.data, dict) and "results" in request.data: + raw_results = request.data["results"] else: - results = [request.data] + raw_results = [request.data] try: # Pre-validate all results before enqueuing any tasks # This prevents partial queueing and duplicate task processing validated_results = [] - for item in results: + for item in raw_results: task_result = PipelineTaskResult(**item) validated_results.append(task_result) @@ -337,15 +340,17 @@ def result(self, request, pk=None): ) logger.info( - f"Queued pipeline result processing for job {job.pk}, " - f"task_id: {task.id}, reply_subject: {reply_subject}" + "Queued pipeline result for job %s, task_id: %s, reply_subject: %s", + job.pk, + task.id, + reply_subject, ) return Response( { "status": "accepted", "job_id": job.pk, - "results_queued": len([t for t in queued_tasks if t["status"] == "queued"]), + "results_queued": len(queued_tasks), "tasks": queued_tasks, } ) @@ -353,7 +358,7 @@ def result(self, request, pk=None): raise ValidationError(f"Invalid result data: {e}") from e except Exception as e: - logger.error(f"Failed to queue pipeline results for job {job.pk}: {e}") + logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e) return Response( { "status": "error", diff --git a/ami/ml/serializers_client_info.py b/ami/ml/serializers_client_info.py new file mode 100644 index 000000000..e4a804896 --- /dev/null +++ b/ami/ml/serializers_client_info.py @@ -0,0 +1,50 @@ +from rest_framework import serializers + + +class ClientInfoSerializer(serializers.Serializer): + """ + Validated client_info from processing service requests. + + Client-reported fields (all optional): + hostname, software, version, platform, pod_name, extra + + Server-observed fields (added by get_client_info(), not sent by client): + ip, user_agent + """ + + hostname = serializers.CharField(max_length=255, required=False, default="") + software = serializers.CharField(max_length=100, required=False, default="") + version = serializers.CharField(max_length=50, required=False, default="") + platform = serializers.CharField(max_length=100, required=False, default="") + pod_name = serializers.CharField(max_length=255, required=False, default="") + extra = serializers.DictField(required=False, default=dict) + + +def get_client_info(request) -> dict: + """ + Extract client_info from request body, merged with server-observed values. + + Server-observed fields (ip, user_agent) always come from the server and + cannot be spoofed by the client. + Client-reported fields come from request.data["client_info"] when provided. + Handles bare-list payloads (legacy /result format) gracefully. + """ + data = request.data if isinstance(request.data, dict) else {} + raw = data.get("client_info") or {} + serializer = ClientInfoSerializer(data=raw) + if serializer.is_valid(): + info = serializer.validated_data + else: + info = {} + + # Always overwrite server-observed fields to prevent client spoofing + info["ip"] = _get_client_ip(request) + info["user_agent"] = request.headers.get("user-agent", "") + return info + + +def _get_client_ip(request) -> str: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.META.get("REMOTE_ADDR", "unknown") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 36ba5b5f7..022a671b7 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1366,3 +1366,43 @@ def test_cleanup_removes_failed_set(self): # Verify all state is gone (get_progress returns None when total_key is deleted) progress = self.manager.get_progress("process") self.assertIsNone(progress) + + +class TestClientInfoSerializer(TestCase): + def test_valid_client_info(self): + from ami.ml.serializers_client_info import ClientInfoSerializer + + data = { + "hostname": "cedar-node-01", + "software": "ami-data-companion", + "version": "2.1.0", + "platform": "Linux x86_64", + } + s = ClientInfoSerializer(data=data) + self.assertTrue(s.is_valid(), s.errors) + self.assertEqual(s.validated_data["hostname"], "cedar-node-01") + + def test_empty_client_info_is_valid(self): + from ami.ml.serializers_client_info import ClientInfoSerializer + + s = ClientInfoSerializer(data={}) + self.assertTrue(s.is_valid(), s.errors) + + def test_extra_fields_in_extra_dict(self): + from ami.ml.serializers_client_info import ClientInfoSerializer + + data = { + "hostname": "node-01", + "extra": {"gpu": "A100", "cuda": "12.0"}, + } + s = ClientInfoSerializer(data=data) + self.assertTrue(s.is_valid(), s.errors) + self.assertEqual(s.validated_data["extra"]["gpu"], "A100") + + def test_hostname_max_length_enforced(self): + from ami.ml.serializers_client_info import ClientInfoSerializer + + data = {"hostname": "x" * 256} + s = ClientInfoSerializer(data=data) + self.assertFalse(s.is_valid()) + self.assertIn("hostname", s.errors) From 0f96b332f08a05ed6cb6036cb685915f6fbb4117 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 12:42:29 -0700 Subject: [PATCH 02/12] refactor: remove client_info from tasks PR, add review fixes - Remove ClientInfoSerializer module and tests (belongs in backend auth/identity branch) - Remove client_info field from TasksRequestSerializer - Remove unused batch_param OpenAPI parameter (dead code) - Add project_id_doc_param to /tasks/ extend_schema - Validate 'results' is a list in wrapped /result/ format - Clean up docstrings Co-Authored-By: Claude --- ami/jobs/schemas.py | 10 ------- ami/jobs/views.py | 7 +++-- ami/ml/serializers_client_info.py | 50 ------------------------------- ami/ml/tests.py | 40 ------------------------- 4 files changed, 5 insertions(+), 102 deletions(-) delete mode 100644 ami/ml/serializers_client_info.py diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 5bd3dc2e9..3d2fe65fa 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,8 +1,6 @@ from drf_spectacular.utils import OpenApiParameter from rest_framework import serializers -from ami.ml.serializers_client_info import ClientInfoSerializer - ids_only_param = OpenApiParameter( name="ids_only", description="Return only job IDs instead of full objects", @@ -17,14 +15,6 @@ type=bool, ) -batch_param = OpenApiParameter( - name="batch", - description="Number of tasks to retrieve", - required=False, - type=int, -) - class TasksRequestSerializer(serializers.Serializer): batch = serializers.IntegerField(min_value=1, required=True) - client_info = ClientInfoSerializer(required=False) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ab32da33d..770d40578 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -240,6 +240,7 @@ def list(self, request, *args, **kwargs): @extend_schema( request=TasksRequestSerializer, responses={200: dict}, + parameters=[project_id_doc_param], ) @action(detail=True, methods=["post"], name="tasks") def tasks(self, request, pk=None): @@ -247,7 +248,7 @@ def tasks(self, request, pk=None): Fetch tasks from the job queue (POST). Returns task data with reply_subject for acknowledgment. External workers should: - 1. POST to this endpoint with {"batch": N, "client_info": {...}} + 1. POST to this endpoint with {"batch": N} 2. Process the tasks 3. POST to /jobs/{id}/result/ with the results """ @@ -292,7 +293,7 @@ def result(self, request, pk=None): """ Submit pipeline results. - Accepts: {"client_info": {...}, "results": [PipelineTaskResult, ...]} + Accepts: {"results": [PipelineTaskResult, ...]} Or legacy: [PipelineTaskResult, ...] (bare list) Results are validated then queued for background processing via Celery. @@ -308,6 +309,8 @@ def result(self, request, pk=None): raw_results = request.data elif isinstance(request.data, dict) and "results" in request.data: raw_results = request.data["results"] + if not isinstance(raw_results, list): + raise ValidationError("'results' must be a list") else: raw_results = [request.data] diff --git a/ami/ml/serializers_client_info.py b/ami/ml/serializers_client_info.py deleted file mode 100644 index e4a804896..000000000 --- a/ami/ml/serializers_client_info.py +++ /dev/null @@ -1,50 +0,0 @@ -from rest_framework import serializers - - -class ClientInfoSerializer(serializers.Serializer): - """ - Validated client_info from processing service requests. - - Client-reported fields (all optional): - hostname, software, version, platform, pod_name, extra - - Server-observed fields (added by get_client_info(), not sent by client): - ip, user_agent - """ - - hostname = serializers.CharField(max_length=255, required=False, default="") - software = serializers.CharField(max_length=100, required=False, default="") - version = serializers.CharField(max_length=50, required=False, default="") - platform = serializers.CharField(max_length=100, required=False, default="") - pod_name = serializers.CharField(max_length=255, required=False, default="") - extra = serializers.DictField(required=False, default=dict) - - -def get_client_info(request) -> dict: - """ - Extract client_info from request body, merged with server-observed values. - - Server-observed fields (ip, user_agent) always come from the server and - cannot be spoofed by the client. - Client-reported fields come from request.data["client_info"] when provided. - Handles bare-list payloads (legacy /result format) gracefully. - """ - data = request.data if isinstance(request.data, dict) else {} - raw = data.get("client_info") or {} - serializer = ClientInfoSerializer(data=raw) - if serializer.is_valid(): - info = serializer.validated_data - else: - info = {} - - # Always overwrite server-observed fields to prevent client spoofing - info["ip"] = _get_client_ip(request) - info["user_agent"] = request.headers.get("user-agent", "") - return info - - -def _get_client_ip(request) -> str: - forwarded = request.headers.get("x-forwarded-for") - if forwarded: - return forwarded.split(",")[0].strip() - return request.META.get("REMOTE_ADDR", "unknown") diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 022a671b7..36ba5b5f7 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1366,43 +1366,3 @@ def test_cleanup_removes_failed_set(self): # Verify all state is gone (get_progress returns None when total_key is deleted) progress = self.manager.get_progress("process") self.assertIsNone(progress) - - -class TestClientInfoSerializer(TestCase): - def test_valid_client_info(self): - from ami.ml.serializers_client_info import ClientInfoSerializer - - data = { - "hostname": "cedar-node-01", - "software": "ami-data-companion", - "version": "2.1.0", - "platform": "Linux x86_64", - } - s = ClientInfoSerializer(data=data) - self.assertTrue(s.is_valid(), s.errors) - self.assertEqual(s.validated_data["hostname"], "cedar-node-01") - - def test_empty_client_info_is_valid(self): - from ami.ml.serializers_client_info import ClientInfoSerializer - - s = ClientInfoSerializer(data={}) - self.assertTrue(s.is_valid(), s.errors) - - def test_extra_fields_in_extra_dict(self): - from ami.ml.serializers_client_info import ClientInfoSerializer - - data = { - "hostname": "node-01", - "extra": {"gpu": "A100", "cuda": "12.0"}, - } - s = ClientInfoSerializer(data=data) - self.assertTrue(s.is_valid(), s.errors) - self.assertEqual(s.validated_data["extra"]["gpu"], "A100") - - def test_hostname_max_length_enforced(self): - from ami.ml.serializers_client_info import ClientInfoSerializer - - data = {"hostname": "x" * 256} - s = ClientInfoSerializer(data=data) - self.assertFalse(s.is_valid()) - self.assertIn("hostname", s.errors) From 4cb1b7ffd8dc606cae71f6989ded3698edb87289 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 17:52:21 -0700 Subject: [PATCH 03/12] refactor: use parse_obj_as for result validation Replace manual loop with pydantic.parse_obj_as() to validate the results list. This catches non-dict items (e.g. int, null) with a proper 400 ValidationError instead of a TypeError falling through to the generic 500 handler. Co-Authored-By: Claude --- ami/jobs/views.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 770d40578..1ba720ddb 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -317,10 +317,7 @@ def result(self, request, pk=None): try: # Pre-validate all results before enqueuing any tasks # This prevents partial queueing and duplicate task processing - validated_results = [] - for item in raw_results: - task_result = PipelineTaskResult(**item) - validated_results.append(task_result) + validated_results = pydantic.parse_obj_as(list[PipelineTaskResult], raw_results) # All validation passed, now queue all tasks queued_tasks = [] From fd2e705f9a94e7fcf9c2d7d4c38830fbf19b25bf Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 22:19:47 -0700 Subject: [PATCH 04/12] feat: add ProcessingServiceClientInfo schema Identity metadata for specific processing service instances. Allows tracking which worker/pod is making requests when a single ProcessingService has multiple workers. Co-Authored-By: Claude --- ami/ml/schemas.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 7449c59e6..77bf4875f 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -262,6 +262,20 @@ class PipelineProcessingTask(pydantic.BaseModel): # config: PipelineRequestConfigParameters | dict | None = None +class ProcessingServiceClientInfo(pydantic.BaseModel): + """Identity metadata for a specific processing service instance. + + A single ProcessingService may have multiple workers/pods. + This identifies which one is making the request. + """ + + hostname: str = "" + software: str = "" + version: str = "" + platform: str = "" + pod_name: str = "" + + class PipelineTaskResult(pydantic.BaseModel): """ The result from processing a single PipelineProcessingTask. From f9525595993ccaf9c938363431073e92ed4a8479 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 22:19:55 -0700 Subject: [PATCH 05/12] feat: add DRF serializers for /tasks/ and /result/ endpoints Add typed request/response serializers for both job endpoints: - TasksRequestSerializer (existing, now with client_info) - TasksResponseSerializer - PipelineResultsRequestSerializer (replaces manual isinstance checks) - PipelineResultsResponseSerializer Uses SchemaField to delegate item validation to Pydantic models, keeping DRF for the HTTP envelope and Pydantic for domain schemas. Co-Authored-By: Claude --- ami/jobs/schemas.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 3d2fe65fa..cf225fcbe 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,6 +1,9 @@ +from django_pydantic_field.rest_framework import SchemaField from drf_spectacular.utils import OpenApiParameter from rest_framework import serializers +from ami.ml.schemas import PipelineTaskResult, ProcessingServiceClientInfo + ids_only_param = OpenApiParameter( name="ids_only", description="Return only job IDs instead of full objects", @@ -17,4 +20,29 @@ class TasksRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ request body. Fetch tasks from the job queue.""" + batch = serializers.IntegerField(min_value=1, required=True) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class TasksResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ response body. Tasks returned to the processing service.""" + + tasks = serializers.ListField(child=serializers.DictField(), default=[]) + + +class PipelineResultsRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ request body. Submit pipeline results for processing.""" + + results = SchemaField(schema=list[PipelineTaskResult]) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class PipelineResultsResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ response body. Acknowledgment of queued results.""" + + status = serializers.CharField() + job_id = serializers.IntegerField() + results_queued = serializers.IntegerField() + tasks = serializers.ListField(child=serializers.DictField(), default=[]) From a0e3afae7d96238bfa74aec27262cae2960e266f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 22:20:02 -0700 Subject: [PATCH 06/12] refactor: use PipelineResultsRequestSerializer for /result/ endpoint Replace manual isinstance branching and parse_obj_as with PipelineResultsRequestSerializer. One canonical request shape: {"results": [...]}. Bare list and single-item formats removed. Tests updated to use wrapped format and assert bare lists are rejected. Co-Authored-By: Claude --- ami/jobs/tests/test_jobs.py | 39 +++++++++++++++++++++---------------- ami/jobs/views.py | 37 ++++++++++++++++------------------- 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index 5b2f05ba7..e385ebb0f 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -537,19 +537,21 @@ def test_result_endpoint_stub(self): "api:job-result", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} ) - result_data = [ - { - "reply_subject": "test.reply.1", - "result": { - "pipeline": "test-pipeline", - "algorithms": {}, - "total_time": 1.5, - "source_images": [], - "detections": [], - "errors": None, - }, - } - ] + result_data = { + "results": [ + { + "reply_subject": "test.reply.1", + "result": { + "pipeline": "test-pipeline", + "algorithms": {}, + "total_time": 1.5, + "source_images": [], + "detections": [], + "errors": None, + }, + } + ] + } resp = self.client.post(result_url, result_data, format="json") @@ -568,16 +570,19 @@ def test_result_endpoint_validation(self): result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk}) # Test with missing reply_subject - invalid_data = [{"result": {"pipeline": "test"}}] + invalid_data = {"results": [{"result": {"pipeline": "test"}}]} resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) - self.assertIn("reply_subject", resp.json()[0].lower()) # Test with missing result - invalid_data = [{"reply_subject": "test.reply"}] + invalid_data = {"results": [{"reply_subject": "test.reply"}]} resp = self.client.post(result_url, invalid_data, format="json") self.assertEqual(resp.status_code, 400) - self.assertIn("result", resp.json()[0].lower()) + + # Test with bare list (no longer accepted) + bare_list = [{"reply_subject": "test.reply", "result": {"pipeline": "test"}}] + resp = self.client.post(result_url, bare_list, format="json") + self.assertEqual(resp.status_code, 400) class TestJobDispatchModeFiltering(APITestCase): diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 1ba720ddb..a5e384bd0 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -2,7 +2,6 @@ import logging import nats.errors -import pydantic from asgiref.sync import async_to_sync from django.db.models import Q from django.db.models.query import QuerySet @@ -17,11 +16,17 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import TasksRequestSerializer, ids_only_param, incomplete_only_param +from ami.jobs.schemas import ( + PipelineResultsRequestSerializer, + PipelineResultsResponseSerializer, + TasksRequestSerializer, + TasksResponseSerializer, + ids_only_param, + incomplete_only_param, +) from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param from ami.main.api.views import DefaultViewSet -from ami.ml.schemas import PipelineTaskResult from ami.utils.fields import url_boolean_param from .models import Job, JobDispatchMode, JobState @@ -239,7 +244,7 @@ def list(self, request, *args, **kwargs): @extend_schema( request=TasksRequestSerializer, - responses={200: dict}, + responses={200: TasksResponseSerializer}, parameters=[project_id_doc_param], ) @action(detail=True, methods=["post"], name="tasks") @@ -288,13 +293,17 @@ async def get_tasks(): return Response({"tasks": tasks}) + @extend_schema( + request=PipelineResultsRequestSerializer, + responses={200: PipelineResultsResponseSerializer}, + parameters=[project_id_doc_param], + ) @action(detail=True, methods=["post"], name="result") def result(self, request, pk=None): """ Submit pipeline results. Accepts: {"results": [PipelineTaskResult, ...]} - Or legacy: [PipelineTaskResult, ...] (bare list) Results are validated then queued for background processing via Celery. """ @@ -304,21 +313,11 @@ def result(self, request, pk=None): # Record heartbeat for async processing services on this pipeline _mark_pipeline_pull_services_seen(job) - # Accept both wrapped format and legacy bare list - if isinstance(request.data, list): - raw_results = request.data - elif isinstance(request.data, dict) and "results" in request.data: - raw_results = request.data["results"] - if not isinstance(raw_results, list): - raise ValidationError("'results' must be a list") - else: - raw_results = [request.data] + serializer = PipelineResultsRequestSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + validated_results = serializer.validated_data["results"] try: - # Pre-validate all results before enqueuing any tasks - # This prevents partial queueing and duplicate task processing - validated_results = pydantic.parse_obj_as(list[PipelineTaskResult], raw_results) - # All validation passed, now queue all tasks queued_tasks = [] for task_result in validated_results: @@ -354,8 +353,6 @@ def result(self, request, pk=None): "tasks": queued_tasks, } ) - except pydantic.ValidationError as e: - raise ValidationError(f"Invalid result data: {e}") from e except Exception as e: logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e) From a4e1961330c0123f700f674f3c0e0cd4af4a3f12 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 1 Apr 2026 22:51:34 -0700 Subject: [PATCH 07/12] refactor: rename batch to batch_size in TasksRequestSerializer "batch" sounds like content, "batch_size" clearly communicates a requested quantity. Matches the ADC worker's field name. Co-Authored-By: Claude --- ami/jobs/schemas.py | 2 +- ami/jobs/tests/test_jobs.py | 10 ++++------ ami/jobs/views.py | 6 +++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index cf225fcbe..10697b044 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -22,7 +22,7 @@ class TasksRequestSerializer(serializers.Serializer): """POST /jobs/{id}/tasks/ request body. Fetch tasks from the job queue.""" - batch = serializers.IntegerField(min_value=1, required=True) + batch_size = serializers.IntegerField(min_value=1, required=True) client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) diff --git a/ami/jobs/tests/test_jobs.py b/ami/jobs/tests/test_jobs.py index e385ebb0f..7241b0a57 100644 --- a/ami/jobs/tests/test_jobs.py +++ b/ami/jobs/tests/test_jobs.py @@ -490,7 +490,7 @@ def _task_batch_helper(self, value: Any, expected_status: int): self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) - resp = self.client.post(tasks_url, {"batch": value}, format="json") + resp = self.client.post(tasks_url, {"batch_size": value}, format="json") self.assertEqual(resp.status_code, expected_status) return resp.json() @@ -522,7 +522,7 @@ def test_tasks_endpoint_without_pipeline(self): self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params("api:job-tasks", args=[job.pk], params={"project_id": self.project.pk}) - resp = self.client.post(tasks_url, {"batch": 1}, format="json") + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("pipeline", resp.json()[0].lower()) @@ -533,9 +533,7 @@ def test_result_endpoint_stub(self): job = self._create_ml_job("Job for results test", pipeline) self.client.force_authenticate(user=self.user) - result_url = reverse_with_params( - "api:job-result", args=[job.pk], params={"project_id": self.project.pk, "batch": 1} - ) + result_url = reverse_with_params("api:job-result", args=[job.pk], params={"project_id": self.project.pk}) result_data = { "results": [ @@ -724,6 +722,6 @@ def test_tasks_endpoint_rejects_non_async_jobs(self): self.client.force_authenticate(user=self.user) tasks_url = reverse_with_params("api:job-tasks", args=[sync_job.pk], params={"project_id": self.project.pk}) - resp = self.client.post(tasks_url, {"batch": 1}, format="json") + resp = self.client.post(tasks_url, {"batch_size": 1}, format="json") self.assertEqual(resp.status_code, 400) self.assertIn("async_api", resp.json()[0].lower()) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index a5e384bd0..4f173abac 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -253,13 +253,13 @@ def tasks(self, request, pk=None): Fetch tasks from the job queue (POST). Returns task data with reply_subject for acknowledgment. External workers should: - 1. POST to this endpoint with {"batch": N} + 1. POST to this endpoint with {"batch_size": N} 2. Process the tasks 3. POST to /jobs/{id}/result/ with the results """ serializer = TasksRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) - batch = serializer.validated_data["batch"] + batch_size = serializer.validated_data["batch_size"] job: Job = self.get_object() @@ -283,7 +283,7 @@ def tasks(self, request, pk=None): async def get_tasks(): async with TaskQueueManager() as manager: - return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)] + return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch_size, timeout=0.5)] try: tasks = async_to_sync(get_tasks)() From f07aa8501adf306a82f2bd8bd352738bf6711719 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 2 Apr 2026 00:17:33 -0700 Subject: [PATCH 08/12] refactor: rename serializers to MLJob* pattern, improve docstrings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename DRF serializers to MLJob{Action}{Direction}Serializer pattern: - TasksRequestSerializer → MLJobTasksRequestSerializer - TasksResponseSerializer → MLJobTasksResponseSerializer - PipelineResultsRequestSerializer → MLJobResultsRequestSerializer - PipelineResultsResponseSerializer → MLJobResultsResponseSerializer ProcessingServiceClientInfo: remove specific fields, use extra="allow" so processing services can send any useful identity key-value pairs. Add detailed docstrings clarifying each serializer's role in the processing service ↔ Antenna flow. Fix test_result_endpoint_with_error_result to use wrapped payload format. Co-Authored-By: Claude --- ami/jobs/schemas.py | 37 ++++++++++++++++++++++++++++-------- ami/jobs/tests/test_tasks.py | 22 +++++++++++---------- ami/jobs/views.py | 20 +++++++++---------- ami/ml/schemas.py | 20 +++++++++++-------- 4 files changed, 63 insertions(+), 36 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 10697b044..efff1df75 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -19,28 +19,49 @@ ) -class TasksRequestSerializer(serializers.Serializer): - """POST /jobs/{id}/tasks/ request body. Fetch tasks from the job queue.""" +class MLJobTasksRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — request body sent by a processing service to fetch work. + + The processing service polls this endpoint to get tasks (images) to process. + Each task is a PipelineProcessingTask with an image URL and a NATS reply subject. + """ batch_size = serializers.IntegerField(min_value=1, required=True) client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) -class TasksResponseSerializer(serializers.Serializer): - """POST /jobs/{id}/tasks/ response body. Tasks returned to the processing service.""" +class MLJobTasksResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — response body returned to the processing service. + + Contains a list of tasks (PipelineProcessingTask dicts) for the worker to process. + Each task includes an image URL, task ID, and reply_subject for result correlation. + Returns an empty list when no tasks are available or the job is not active. + """ tasks = serializers.ListField(child=serializers.DictField(), default=[]) -class PipelineResultsRequestSerializer(serializers.Serializer): - """POST /jobs/{id}/result/ request body. Submit pipeline results for processing.""" +class MLJobResultsRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — request body sent by a processing service to deliver results. + + "Request" here refers to the HTTP request to Antenna, not a request for work. + The processing service has finished processing tasks and is posting its results + (successes or errors) back. Each PipelineTaskResult contains a reply_subject + (correlating back to the original task) and a result payload that is either a + PipelineResultsResponse (success) or PipelineResultsError (failure). + """ results = SchemaField(schema=list[PipelineTaskResult]) client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) -class PipelineResultsResponseSerializer(serializers.Serializer): - """POST /jobs/{id}/result/ response body. Acknowledgment of queued results.""" +class MLJobResultsResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — acknowledgment returned to the processing service. + + Confirms receipt and indicates how many results were queued for background + processing via Celery. Individual task entries include their Celery task_id + for traceability. + """ status = serializers.CharField() job_id = serializers.IntegerField() diff --git a/ami/jobs/tests/test_tasks.py b/ami/jobs/tests/test_tasks.py index daf1b6ae6..d183dfb3c 100644 --- a/ami/jobs/tests/test_tasks.py +++ b/ami/jobs/tests/test_tasks.py @@ -384,16 +384,18 @@ def test_result_endpoint_with_error_result(self, mock_apply_async): self.client.force_authenticate(user=self.user) result_url = reverse_with_params("api:job-result", args=[self.job.pk], params={"project_id": self.project.pk}) - # Create error result data - result_data = [ - { - "reply_subject": "test.reply.error.1", - "result": { - "error": "Image processing timeout", - "image_id": str(self.image.pk), - }, - } - ] + # Create error result data (wrapped format) + result_data = { + "results": [ + { + "reply_subject": "test.reply.error.1", + "result": { + "error": "Image processing timeout", + "image_id": str(self.image.pk), + }, + } + ] + } # POST error result to API resp = self.client.post(result_url, result_data, format="json") diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 4f173abac..fcfb078a1 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -17,10 +17,10 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin from ami.jobs.schemas import ( - PipelineResultsRequestSerializer, - PipelineResultsResponseSerializer, - TasksRequestSerializer, - TasksResponseSerializer, + MLJobResultsRequestSerializer, + MLJobResultsResponseSerializer, + MLJobTasksRequestSerializer, + MLJobTasksResponseSerializer, ids_only_param, incomplete_only_param, ) @@ -243,8 +243,8 @@ def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @extend_schema( - request=TasksRequestSerializer, - responses={200: TasksResponseSerializer}, + request=MLJobTasksRequestSerializer, + responses={200: MLJobTasksResponseSerializer}, parameters=[project_id_doc_param], ) @action(detail=True, methods=["post"], name="tasks") @@ -257,7 +257,7 @@ def tasks(self, request, pk=None): 2. Process the tasks 3. POST to /jobs/{id}/result/ with the results """ - serializer = TasksRequestSerializer(data=request.data) + serializer = MLJobTasksRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) batch_size = serializer.validated_data["batch_size"] @@ -294,8 +294,8 @@ async def get_tasks(): return Response({"tasks": tasks}) @extend_schema( - request=PipelineResultsRequestSerializer, - responses={200: PipelineResultsResponseSerializer}, + request=MLJobResultsRequestSerializer, + responses={200: MLJobResultsResponseSerializer}, parameters=[project_id_doc_param], ) @action(detail=True, methods=["post"], name="result") @@ -313,7 +313,7 @@ def result(self, request, pk=None): # Record heartbeat for async processing services on this pipeline _mark_pipeline_pull_services_seen(job) - serializer = PipelineResultsRequestSerializer(data=request.data) + serializer = MLJobResultsRequestSerializer(data=request.data) serializer.is_valid(raise_exception=True) validated_results = serializer.validated_data["results"] diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index 77bf4875f..9322e4116 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -263,17 +263,21 @@ class PipelineProcessingTask(pydantic.BaseModel): class ProcessingServiceClientInfo(pydantic.BaseModel): - """Identity metadata for a specific processing service instance. + """Identity metadata sent by a processing service worker. - A single ProcessingService may have multiple workers/pods. - This identifies which one is making the request. + A single ProcessingService record in the database may have multiple + physical workers, pods, or machines running simultaneously. This model + lets the server distinguish between them for logging, debugging, and + eventually for per-worker health tracking. + + Fields are intentionally left open for now. Processing services can + send any key-value pairs they find useful (e.g. hostname, pod_name, + software version). The schema will be tightened once real-world usage + patterns emerge. """ - hostname: str = "" - software: str = "" - version: str = "" - platform: str = "" - pod_name: str = "" + class Config: + extra = "allow" class PipelineTaskResult(pydantic.BaseModel): From 77a246f93d6eb6f2b58c78cad6d453c3b4e95598 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 2 Apr 2026 11:47:55 -0700 Subject: [PATCH 09/12] fix(jobs): narrow result endpoint exception handler to broker errors Replace broad `except Exception` with specific `(OSError, KombuError)` to catch only Celery broker connection failures. Returns 503 with a descriptive message instead of a generic 500 that would swallow bugs. Co-Authored-By: Claude --- ami/jobs/views.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index fcfb078a1..45055ef2c 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,6 +1,7 @@ import asyncio import logging +import kombu.exceptions import nats.errors from asgiref.sync import async_to_sync from django.db.models import Q @@ -354,12 +355,13 @@ def result(self, request, pk=None): } ) - except Exception as e: + except (OSError, kombu.exceptions.KombuError) as e: logger.error("Failed to queue pipeline results for job %s: %s", job.pk, e) return Response( { "status": "error", "job_id": job.pk, + "detail": "Task queue temporarily unavailable", }, - status=500, + status=503, ) From 5ee0de0b75b0ef460a9c3fb9a57b6f2ddf0104e9 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 3 Apr 2026 15:40:09 -0700 Subject: [PATCH 10/12] refactor(jobs): move MLJob serializers from schemas.py to serializers.py Follows the codebase convention: schemas.py holds OpenApiParameter definitions and Pydantic models, serializers.py holds DRF serializers. Co-Authored-By: Claude --- ami/jobs/schemas.py | 54 ----------------------------------------- ami/jobs/serializers.py | 51 ++++++++++++++++++++++++++++++++++++++ ami/jobs/views.py | 5 ++-- 3 files changed, 53 insertions(+), 57 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index efff1df75..5343dd844 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,8 +1,4 @@ -from django_pydantic_field.rest_framework import SchemaField from drf_spectacular.utils import OpenApiParameter -from rest_framework import serializers - -from ami.ml.schemas import PipelineTaskResult, ProcessingServiceClientInfo ids_only_param = OpenApiParameter( name="ids_only", @@ -17,53 +13,3 @@ required=False, type=bool, ) - - -class MLJobTasksRequestSerializer(serializers.Serializer): - """POST /jobs/{id}/tasks/ — request body sent by a processing service to fetch work. - - The processing service polls this endpoint to get tasks (images) to process. - Each task is a PipelineProcessingTask with an image URL and a NATS reply subject. - """ - - batch_size = serializers.IntegerField(min_value=1, required=True) - client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) - - -class MLJobTasksResponseSerializer(serializers.Serializer): - """POST /jobs/{id}/tasks/ — response body returned to the processing service. - - Contains a list of tasks (PipelineProcessingTask dicts) for the worker to process. - Each task includes an image URL, task ID, and reply_subject for result correlation. - Returns an empty list when no tasks are available or the job is not active. - """ - - tasks = serializers.ListField(child=serializers.DictField(), default=[]) - - -class MLJobResultsRequestSerializer(serializers.Serializer): - """POST /jobs/{id}/result/ — request body sent by a processing service to deliver results. - - "Request" here refers to the HTTP request to Antenna, not a request for work. - The processing service has finished processing tasks and is posting its results - (successes or errors) back. Each PipelineTaskResult contains a reply_subject - (correlating back to the original task) and a result payload that is either a - PipelineResultsResponse (success) or PipelineResultsError (failure). - """ - - results = SchemaField(schema=list[PipelineTaskResult]) - client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) - - -class MLJobResultsResponseSerializer(serializers.Serializer): - """POST /jobs/{id}/result/ — acknowledgment returned to the processing service. - - Confirms receipt and indicates how many results were queued for background - processing via Celery. Individual task entries include their Celery task_id - for traceability. - """ - - status = serializers.CharField() - job_id = serializers.IntegerField() - results_queued = serializers.IntegerField() - tasks = serializers.ListField(child=serializers.DictField(), default=[]) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index d903b0812..2608ce298 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -10,6 +10,7 @@ ) from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline +from ami.ml.schemas import PipelineTaskResult, ProcessingServiceClientInfo from ami.ml.serializers import PipelineNestedSerializer from .models import Job, JobLogs, JobProgress, MLJob @@ -163,3 +164,53 @@ class MinimalJobSerializer(DefaultSerializer): class Meta: model = Job fields = ["id", "pipeline_slug"] + + +class MLJobTasksRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — request body sent by a processing service to fetch work. + + The processing service polls this endpoint to get tasks (images) to process. + Each task is a PipelineProcessingTask with an image URL and a NATS reply subject. + """ + + batch_size = serializers.IntegerField(min_value=1, required=True) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class MLJobTasksResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/tasks/ — response body returned to the processing service. + + Contains a list of tasks (PipelineProcessingTask dicts) for the worker to process. + Each task includes an image URL, task ID, and reply_subject for result correlation. + Returns an empty list when no tasks are available or the job is not active. + """ + + tasks = serializers.ListField(child=serializers.DictField(), default=[]) + + +class MLJobResultsRequestSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — request body sent by a processing service to deliver results. + + "Request" here refers to the HTTP request to Antenna, not a request for work. + The processing service has finished processing tasks and is posting its results + (successes or errors) back. Each PipelineTaskResult contains a reply_subject + (correlating back to the original task) and a result payload that is either a + PipelineResultsResponse (success) or PipelineResultsError (failure). + """ + + results = SchemaField(schema=list[PipelineTaskResult]) + client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) + + +class MLJobResultsResponseSerializer(serializers.Serializer): + """POST /jobs/{id}/result/ — acknowledgment returned to the processing service. + + Confirms receipt and indicates how many results were queued for background + processing via Celery. Individual task entries include their Celery task_id + for traceability. + """ + + status = serializers.CharField() + job_id = serializers.IntegerField() + results_queued = serializers.IntegerField() + tasks = serializers.ListField(child=serializers.DictField(), default=[]) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 45055ef2c..625fb8b47 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -17,13 +17,12 @@ from ami.base.permissions import ObjectPermission from ami.base.views import ProjectMixin -from ami.jobs.schemas import ( +from ami.jobs.schemas import ids_only_param, incomplete_only_param +from ami.jobs.serializers import ( MLJobResultsRequestSerializer, MLJobResultsResponseSerializer, MLJobTasksRequestSerializer, MLJobTasksResponseSerializer, - ids_only_param, - incomplete_only_param, ) from ami.jobs.tasks import process_nats_pipeline_result from ami.main.api.schemas import project_id_doc_param From 797d900713b3e64e3cb404d07c42d22bbd78f003 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 3 Apr 2026 15:42:07 -0700 Subject: [PATCH 11/12] refactor(jobs): use typed SchemaFields instead of generic list/dict MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace ListField(child=DictField()) with SchemaField using real Pydantic models for proper OpenAPI schema generation and validation: - MLJobTasksResponseSerializer.tasks → list[PipelineProcessingTask] - MLJobResultsResponseSerializer.tasks → list[QueuedTaskAcknowledgment] Co-Authored-By: Claude --- ami/jobs/serializers.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 2608ce298..0ada7f41b 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -1,3 +1,4 @@ +import pydantic from django_pydantic_field.rest_framework import SchemaField from rest_framework import serializers @@ -10,7 +11,7 @@ ) from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection from ami.ml.models import Pipeline -from ami.ml.schemas import PipelineTaskResult, ProcessingServiceClientInfo +from ami.ml.schemas import PipelineProcessingTask, PipelineTaskResult, ProcessingServiceClientInfo from ami.ml.serializers import PipelineNestedSerializer from .models import Job, JobLogs, JobProgress, MLJob @@ -185,7 +186,7 @@ class MLJobTasksResponseSerializer(serializers.Serializer): Returns an empty list when no tasks are available or the job is not active. """ - tasks = serializers.ListField(child=serializers.DictField(), default=[]) + tasks = SchemaField(schema=list[PipelineProcessingTask], default=[]) class MLJobResultsRequestSerializer(serializers.Serializer): @@ -202,6 +203,14 @@ class MLJobResultsRequestSerializer(serializers.Serializer): client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) +class QueuedTaskAcknowledgment(pydantic.BaseModel): + """Acknowledgment for a single result that was queued for background processing.""" + + reply_subject: str + status: str + task_id: str + + class MLJobResultsResponseSerializer(serializers.Serializer): """POST /jobs/{id}/result/ — acknowledgment returned to the processing service. @@ -213,4 +222,4 @@ class MLJobResultsResponseSerializer(serializers.Serializer): status = serializers.CharField() job_id = serializers.IntegerField() results_queued = serializers.IntegerField() - tasks = serializers.ListField(child=serializers.DictField(), default=[]) + tasks = SchemaField(schema=list[QueuedTaskAcknowledgment], default=[]) From e79987162cd8122dcef239bdfdfff7a1eb02d301 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 3 Apr 2026 15:56:01 -0700 Subject: [PATCH 12/12] refactor(jobs): move QueuedTaskAcknowledgment to schemas.py Pydantic model belongs in the schema file, not serializers. Co-Authored-By: Claude --- ami/jobs/schemas.py | 10 ++++++++++ ami/jobs/serializers.py | 10 +--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/ami/jobs/schemas.py b/ami/jobs/schemas.py index 5343dd844..74af39ce9 100644 --- a/ami/jobs/schemas.py +++ b/ami/jobs/schemas.py @@ -1,5 +1,15 @@ +import pydantic from drf_spectacular.utils import OpenApiParameter + +class QueuedTaskAcknowledgment(pydantic.BaseModel): + """Acknowledgment for a single result that was queued for background processing.""" + + reply_subject: str + status: str + task_id: str + + ids_only_param = OpenApiParameter( name="ids_only", description="Return only job IDs instead of full objects", diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 0ada7f41b..fc2fcf8be 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -1,4 +1,3 @@ -import pydantic from django_pydantic_field.rest_framework import SchemaField from rest_framework import serializers @@ -15,6 +14,7 @@ from ami.ml.serializers import PipelineNestedSerializer from .models import Job, JobLogs, JobProgress, MLJob +from .schemas import QueuedTaskAcknowledgment class JobProjectNestedSerializer(DefaultSerializer): @@ -203,14 +203,6 @@ class MLJobResultsRequestSerializer(serializers.Serializer): client_info = SchemaField(schema=ProcessingServiceClientInfo, required=False, default=None) -class QueuedTaskAcknowledgment(pydantic.BaseModel): - """Acknowledgment for a single result that was queued for background processing.""" - - reply_subject: str - status: str - task_id: str - - class MLJobResultsResponseSerializer(serializers.Serializer): """POST /jobs/{id}/result/ — acknowledgment returned to the processing service.