diff --git a/ami/jobs/models.py b/ami/jobs/models.py index b4df41a04..be797dd4f 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -461,9 +461,7 @@ def run(cls, job: "Job"): # End image collection stage job.save() - if job.project.feature_flags.async_pipeline_workers: - job.dispatch_mode = JobDispatchMode.ASYNC_API - job.save(update_fields=["dispatch_mode"]) + if job.dispatch_mode == JobDispatchMode.ASYNC_API: queued = queue_images_to_nats(job, images) if not queued: job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk) @@ -473,8 +471,6 @@ def run(cls, job: "Job"): job.save() return else: - job.dispatch_mode = JobDispatchMode.SYNC_API - job.save(update_fields=["dispatch_mode"]) cls.process_images(job, images) @classmethod @@ -919,6 +915,15 @@ def setup(self, save=True): self.progress.add_stage_param(delay_stage.key, "Mood", "😴") if self.pipeline: + # Set dispatch mode based on project feature flags at creation time + # so the UI can show the correct mode before the job runs. + # Only override if still at the default (INTERNAL), to allow explicit overrides. + if self.dispatch_mode == JobDispatchMode.INTERNAL: + if self.project and self.project.feature_flags.async_pipeline_workers: + self.dispatch_mode = JobDispatchMode.ASYNC_API + else: + self.dispatch_mode = JobDispatchMode.SYNC_API + collect_stage = self.progress.add_stage("Collect") self.progress.add_stage_param(collect_stage.key, "Total Images", "") diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 7902faeb1..033a08b5c 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -384,6 +384,36 @@ def test_filter_by_pipeline_slug(self): self.assertEqual(data["count"], 1) self.assertEqual(data["results"][0]["id"], job_with_pipeline.pk) + def test_filter_by_pipeline_slug_in(self): + """Test filtering jobs by pipeline__slug__in (multiple slugs).""" + pipeline_a = self._create_pipeline("Pipeline A", "pipeline-a") + pipeline_b = Pipeline.objects.create(name="Pipeline B", slug="pipeline-b", description="B") + pipeline_b.projects.add(self.project) + pipeline_c = Pipeline.objects.create(name="Pipeline C", slug="pipeline-c", description="C") + pipeline_c.projects.add(self.project) + + job_a = self._create_ml_job("Job A", pipeline_a) + job_b = self._create_ml_job("Job B", pipeline_b) + job_c = self._create_ml_job("Job C", pipeline_c) + + self.client.force_authenticate(user=self.user) + + # Filter for two of the three pipelines + jobs_list_url = reverse_with_params( + "api:job-list", + params={"project_id": self.project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"}, + ) + resp = self.client.get(jobs_list_url) + + self.assertEqual(resp.status_code, 200) + data = resp.json() + returned_ids = {job["id"] for job in data["results"]} + self.assertIn(job_a.pk, returned_ids) + self.assertIn(job_b.pk, returned_ids) + self.assertNotIn(job_c.pk, returned_ids) + # Original setUp job (no pipeline) should also be excluded + self.assertNotIn(self.job.pk, returned_ids) + def test_search_jobs(self): """Test searching jobs by name and pipeline name.""" pipeline = self._create_pipeline("SearchablePipeline", "searchable-pipeline") @@ -571,13 +601,11 @@ def test_dispatch_mode_filtering(self): dispatch_mode=JobDispatchMode.ASYNC_API, ) - # Create a job with default dispatch_mode (should be "internal") + # Create a non-ML job without a pipeline (dispatch_mode stays "internal") internal_job = Job.objects.create( - job_type_key=MLJob.key, + job_type_key="data_storage_sync", project=self.project, name="Internal Job", - pipeline=self.pipeline, - source_image_collection=self.source_image_collection, ) self.client.force_authenticate(user=self.user) @@ -614,6 +642,39 @@ def test_dispatch_mode_filtering(self): expected_ids = {sync_job.pk, async_job.pk, internal_job.pk} self.assertEqual(returned_ids, expected_ids) + def test_ml_job_dispatch_mode_set_on_creation(self): + """Test that ML jobs get dispatch_mode set based on project feature flags at creation time.""" + # Without async flag, ML job should default to sync_api + sync_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Auto Sync Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + self.assertEqual(sync_job.dispatch_mode, JobDispatchMode.SYNC_API) + + # Enable async flag on project + self.project.feature_flags.async_pipeline_workers = True + self.project.save() + + async_job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="Auto Async Job", + pipeline=self.pipeline, + source_image_collection=self.source_image_collection, + ) + self.assertEqual(async_job.dispatch_mode, JobDispatchMode.ASYNC_API) + + # Non-pipeline job should stay internal regardless of feature flag + internal_job = Job.objects.create( + job_type_key="data_storage_sync", + project=self.project, + name="Internal Job", + ) + self.assertEqual(internal_job.dispatch_mode, JobDispatchMode.INTERNAL) + def test_tasks_endpoint_rejects_non_async_jobs(self): """Test that /tasks endpoint returns 400 for non-async_api jobs.""" from ami.base.serializers import reverse_with_params diff --git a/ami/jobs/views.py b/ami/jobs/views.py index dd8da01b2..ddc1e57a7 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,5 +1,7 @@ +import asyncio import logging +import nats.errors import pydantic from asgiref.sync import async_to_sync from django.db.models import Q @@ -32,6 +34,7 @@ class JobFilterSet(filters.FilterSet): """Custom filterset to enable pipeline name filtering.""" pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact") + pipeline__slug__in = filters.BaseInFilter(field_name="pipeline__slug", lookup_expr="in") class Meta: model = Job @@ -55,11 +58,12 @@ def filter_queryset(self, request, queryset, view): incomplete_only = url_boolean_param(request, "incomplete_only", default=False) # Filter to incomplete jobs if requested (checks "results" stage status) if incomplete_only: - # Create filters for each final state to exclude + # Exclude jobs with a terminal top-level status + queryset = queryset.exclude(status__in=JobState.final_states()) + + # Also exclude jobs where the "results" stage has a final state status final_states = JobState.final_states() exclude_conditions = Q() - - # Exclude jobs where the "results" stage has a final state status for state in final_states: # JSON path query to check if results stage status is in final states # @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere @@ -233,6 +237,10 @@ def tasks(self, request, pk=None): if job.dispatch_mode != JobDispatchMode.ASYNC_API: raise ValidationError("Only async_api jobs have fetchable tasks") + # Don't fetch tasks from completed/failed/revoked jobs + if job.status in JobState.final_states(): + return Response({"tasks": []}) + # Validate that the job has a pipeline if not job.pipeline: raise ValidationError("This job does not have a pipeline configured") @@ -241,16 +249,14 @@ def tasks(self, request, pk=None): from ami.ml.orchestration.nats_queue import TaskQueueManager async def get_tasks(): - tasks = [] async with TaskQueueManager() as manager: - for _ in range(batch): - task = await manager.reserve_task(job.pk, timeout=0.1) - if task: - tasks.append(task.dict()) - return tasks - - # Use async_to_sync to properly handle the async call - tasks = async_to_sync(get_tasks)() + return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)] + + try: + tasks = async_to_sync(get_tasks)() + except (asyncio.TimeoutError, OSError, nats.errors.Error) as e: + logger.warning("NATS unavailable while fetching tasks for job %s: %s", job.pk, e) + return Response({"error": "Task queue temporarily unavailable"}, status=503) return Response({"tasks": tasks}) diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index fa7188627..65b6f6f72 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -10,6 +10,7 @@ support the visibility timeout semantics we want or a disconnected mode of pulling and ACKing tasks. """ +import asyncio import json import logging @@ -22,9 +23,21 @@ logger = logging.getLogger(__name__) - -async def get_connection(nats_url: str): - nc = await nats.connect(nats_url) +# Timeout for individual JetStream metadata operations (create/check stream and consumer). +# These are lightweight NATS server operations that complete in milliseconds under normal +# conditions. stream_info() and add_stream() don't accept a native timeout parameter, so +# we use asyncio.wait_for() uniformly for all operations. Without these timeouts, a hung +# NATS connection blocks the caller's thread indefinitely — and when that caller is a +# Django worker (via async_to_sync), it makes the entire server unresponsive. +NATS_JETSTREAM_TIMEOUT = 10 # seconds + + +async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]: + nc = await nats.connect( + nats_url, + connect_timeout=5, + allow_reconnect=False, + ) js = nc.jetstream() return nc, js @@ -38,9 +51,9 @@ class TaskQueueManager: Use as an async context manager: async with TaskQueueManager() as manager: - await manager.publish_task('job123', {'data': 'value'}) - task = await manager.reserve_task('job123') - await manager.acknowledge_task(task['reply_subject']) + await manager.publish_task(123, {'data': 'value'}) + tasks = await manager.reserve_tasks(123, count=64) + await manager.acknowledge_task(tasks[0].reply_subject) """ def __init__(self, nats_url: str | None = None): @@ -83,15 +96,20 @@ async def _ensure_stream(self, job_id: int): subject = self._get_subject(job_id) try: - await self.js.stream_info(stream_name) + await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT) logger.debug(f"Stream {stream_name} already exists") + except asyncio.TimeoutError: + raise # NATS unreachable — let caller handle it rather than creating a stream blindly except Exception as e: logger.warning(f"Stream {stream_name} does not exist: {e}") # Stream doesn't exist, create it - await self.js.add_stream( - name=stream_name, - subjects=[subject], - max_age=86400, # 24 hours retention + await asyncio.wait_for( + self.js.add_stream( + name=stream_name, + subjects=[subject], + max_age=86400, # 24 hours retention + ), + timeout=NATS_JETSTREAM_TIMEOUT, ) logger.info(f"Created stream {stream_name}") @@ -105,21 +123,29 @@ async def _ensure_consumer(self, job_id: int): subject = self._get_subject(job_id) try: - info = await self.js.consumer_info(stream_name, consumer_name) + info = await asyncio.wait_for( + self.js.consumer_info(stream_name, consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.debug(f"Consumer {consumer_name} already exists: {info}") + except asyncio.TimeoutError: + raise # NATS unreachable — let caller handle it except Exception: # Consumer doesn't exist, create it - await self.js.add_consumer( - stream=stream_name, - config=ConsumerConfig( - durable_name=consumer_name, - ack_policy=AckPolicy.EXPLICIT, - ack_wait=TASK_TTR, # Visibility timeout (TTR) - max_deliver=5, # Max retry attempts - deliver_policy=DeliverPolicy.ALL, - max_ack_pending=100, # Max unacked messages - filter_subject=subject, + await asyncio.wait_for( + self.js.add_consumer( + stream=stream_name, + config=ConsumerConfig( + durable_name=consumer_name, + ack_policy=AckPolicy.EXPLICIT, + ack_wait=TASK_TTR, # Visibility timeout (TTR) + max_deliver=5, # Max retry attempts + deliver_policy=DeliverPolicy.ALL, + max_ack_pending=100, # Max unacked messages + filter_subject=subject, + ), ), + timeout=NATS_JETSTREAM_TIMEOUT, ) logger.info(f"Created consumer {consumer_name}") @@ -147,7 +173,7 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: task_data = json.dumps(data.dict()) # Publish to JetStream - ack = await self.js.publish(subject, task_data.encode()) + ack = await self.js.publish(subject, task_data.encode(), timeout=NATS_JETSTREAM_TIMEOUT) logger.info(f"Published task to stream for job '{job_id}', sequence {ack.seq}") return True @@ -156,62 +182,57 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool: logger.error(f"Failed to publish task to stream for job '{job_id}': {e}") return False - async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None: + async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]: """ - Reserve a task from the specified stream. + Reserve up to `count` tasks from the specified stream in a single NATS fetch. Args: job_id: The job ID (integer primary key) to pull tasks from - timeout: Timeout in seconds for reservation (default: 5 seconds) + count: Maximum number of tasks to reserve + timeout: Timeout in seconds waiting for messages (default: 5 seconds) Returns: - PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available + List of PipelineProcessingTask objects with reply_subject set for acknowledgment. + May return fewer than `count` if the queue has fewer messages available. """ if self.js is None: raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.") - if timeout is None: - timeout = 5 - try: - # Ensure stream and consumer exist await self._ensure_stream(job_id) await self._ensure_consumer(job_id) consumer_name = self._get_consumer_name(job_id) subject = self._get_subject(job_id) - # Create ephemeral subscription for this pull psub = await self.js.pull_subscribe(subject, consumer_name) try: - # Fetch a single message - msgs = await psub.fetch(1, timeout=timeout) - - if msgs: - msg = msgs[0] - task_data = json.loads(msg.data.decode()) - metadata = msg.metadata - - # Parse the task data into PipelineProcessingTask - task = PipelineProcessingTask(**task_data) - # Set the reply_subject for acknowledgment - task.reply_subject = msg.reply - - logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}") - return task - + msgs = await psub.fetch(count, timeout=timeout) except nats.errors.TimeoutError: - # No messages available logger.debug(f"No tasks available in stream for job '{job_id}'") - return None + return [] finally: - # Always unsubscribe await psub.unsubscribe() + tasks = [] + for msg in msgs: + task_data = json.loads(msg.data.decode()) + task = PipelineProcessingTask(**task_data) + task.reply_subject = msg.reply + tasks.append(task) + + if tasks: + logger.info(f"Reserved {len(tasks)} tasks from stream for job '{job_id}'") + else: + logger.debug(f"No tasks reserved from stream for job '{job_id}'") + return tasks + + except asyncio.TimeoutError: + raise # NATS unreachable — propagate so the view can return an appropriate error except Exception as e: - logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}") - return None + logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}") + return [] async def acknowledge_task(self, reply_subject: str) -> bool: """ @@ -251,7 +272,10 @@ async def delete_consumer(self, job_id: int) -> bool: stream_name = self._get_stream_name(job_id) consumer_name = self._get_consumer_name(job_id) - await self.js.delete_consumer(stream_name, consumer_name) + await asyncio.wait_for( + self.js.delete_consumer(stream_name, consumer_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.info(f"Deleted consumer {consumer_name} for job '{job_id}'") return True except Exception as e: @@ -274,7 +298,10 @@ async def delete_stream(self, job_id: int) -> bool: try: stream_name = self._get_stream_name(job_id) - await self.js.delete_stream(stream_name) + await asyncio.wait_for( + self.js.delete_stream(stream_name), + timeout=NATS_JETSTREAM_TIMEOUT, + ) logger.info(f"Deleted stream {stream_name} for job '{job_id}'") return True except Exception as e: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index 0cd2c3bef..a7bd91b68 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -62,47 +62,74 @@ async def test_publish_task_creates_stream_and_consumer(self): self.assertIn("job_456", str(js.add_stream.call_args)) js.add_consumer.assert_called_once() - async def test_reserve_task_success(self): - """Test successful task reservation.""" + async def test_reserve_tasks_success(self): + """Test successful batch task reservation.""" nc, js = self._create_mock_nats_connection() sample_task = self._create_sample_task() - # Mock message with task data - mock_msg = MagicMock() - mock_msg.data = sample_task.json().encode() - mock_msg.reply = "reply.subject.123" - mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1)) + # Mock messages with task data + mock_msg1 = MagicMock() + mock_msg1.data = sample_task.json().encode() + mock_msg1.reply = "reply.subject.1" + + mock_msg2 = MagicMock() + mock_msg2.data = sample_task.json().encode() + mock_msg2.reply = "reply.subject.2" mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.fetch = AsyncMock(return_value=[mock_msg1, mock_msg2]) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + tasks = await manager.reserve_tasks(123, count=5) - self.assertIsNotNone(task) - self.assertEqual(task.id, sample_task.id) - self.assertEqual(task.reply_subject, "reply.subject.123") + self.assertEqual(len(tasks), 2) + self.assertEqual(tasks[0].id, sample_task.id) + self.assertEqual(tasks[0].reply_subject, "reply.subject.1") + self.assertEqual(tasks[1].reply_subject, "reply.subject.2") + mock_psub.fetch.assert_called_once_with(5, timeout=5) mock_psub.unsubscribe.assert_called_once() - async def test_reserve_task_no_messages(self): - """Test reserve_task when no messages are available.""" + async def test_reserve_tasks_no_messages(self): + """Test reserve_tasks when no messages are available (timeout).""" nc, js = self._create_mock_nats_connection() + import nats.errors mock_psub = MagicMock() - mock_psub.fetch = AsyncMock(return_value=[]) + mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError) mock_psub.unsubscribe = AsyncMock() js.pull_subscribe = AsyncMock(return_value=mock_psub) with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): async with TaskQueueManager() as manager: - task = await manager.reserve_task(123) + tasks = await manager.reserve_tasks(123, count=5) - self.assertIsNone(task) + self.assertEqual(tasks, []) mock_psub.unsubscribe.assert_called_once() + async def test_reserve_tasks_single(self): + """Test reserving a single task.""" + nc, js = self._create_mock_nats_connection() + sample_task = self._create_sample_task() + + mock_msg = MagicMock() + mock_msg.data = sample_task.json().encode() + mock_msg.reply = "reply.subject.123" + + mock_psub = MagicMock() + mock_psub.fetch = AsyncMock(return_value=[mock_msg]) + mock_psub.unsubscribe = AsyncMock() + js.pull_subscribe = AsyncMock(return_value=mock_psub) + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager() as manager: + tasks = await manager.reserve_tasks(123, count=1) + + self.assertEqual(len(tasks), 1) + self.assertEqual(tasks[0].reply_subject, "reply.subject.123") + async def test_acknowledge_task_success(self): """Test successful task acknowledgment.""" nc, js = self._create_mock_nats_connection() @@ -144,7 +171,7 @@ async def test_operations_without_connection_raise_error(self): await manager.publish_task(123, sample_task) with self.assertRaisesRegex(RuntimeError, "Connection is not open"): - await manager.reserve_task(123) + await manager.reserve_tasks(123, count=1) with self.assertRaisesRegex(RuntimeError, "Connection is not open"): await manager.delete_stream(123)