Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.ml.auth import HasProcessingServiceAPIKey
from ami.utils.fields import url_boolean_param

from .models import Job, JobDispatchMode, JobState
Expand Down Expand Up @@ -146,6 +147,14 @@ class JobViewSet(DefaultViewSet, ProjectMixin):

permission_classes = [ObjectPermission]

def _update_processing_service_heartbeat(self, request):
"""Update heartbeat for the specific PS identified by API key auth."""
from ami.ml.models.processing_service import ProcessingService
from ami.ml.schemas import get_client_info

if isinstance(request.auth, ProcessingService):
request.auth.mark_seen(client_info=get_client_info(request))

def get_serializer_class(self):
"""
Return different serializers for list and detail views.
Expand Down Expand Up @@ -247,7 +256,12 @@ def list(self, request, *args, **kwargs):
responses={200: MLJobTasksResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="tasks")
@action(
detail=True,
methods=["post"],
name="tasks",
permission_classes=[ObjectPermission | HasProcessingServiceAPIKey],
)
def tasks(self, request, pk=None):
"""
Fetch tasks from the job queue (POST).
Expand Down Expand Up @@ -278,6 +292,9 @@ def tasks(self, request, pk=None):
# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Per-PS heartbeat via API key auth
self._update_processing_service_heartbeat(request)

# Get tasks from NATS JetStream
from ami.ml.orchestration.nats_queue import TaskQueueManager

Expand All @@ -298,7 +315,12 @@ async def get_tasks():
responses={200: MLJobResultsResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="result")
@action(
detail=True,
methods=["post"],
name="result",
permission_classes=[ObjectPermission | HasProcessingServiceAPIKey],
)
def result(self, request, pk=None):
"""
Submit pipeline results.
Expand All @@ -313,6 +335,9 @@ def result(self, request, pk=None):
# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Per-PS heartbeat via API key auth
self._update_processing_service_heartbeat(request)

serializer = MLJobResultsRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
validated_results = serializer.validated_data["results"]
Expand Down
26 changes: 25 additions & 1 deletion ami/ml/admin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.contrib import admin
from rest_framework_api_key.admin import APIKeyModelAdmin

from ami.main.admin import AdminBase, ProjectPipelineConfigInline

from .models.algorithm import Algorithm, AlgorithmCategoryMap
from .models.pipeline import Pipeline
from .models.processing_service import ProcessingService
from .models.processing_service import ProcessingService, ProcessingServiceAPIKey


@admin.register(Algorithm)
Expand Down Expand Up @@ -70,8 +71,31 @@ class ProcessingServiceAdmin(AdminBase):
"id",
"name",
"endpoint_url",
"last_seen_live",
"created_at",
]
readonly_fields = ["last_seen_client_info"]

@admin.action(description="Generate API key for selected processing services")
def generate_api_key(self, request, queryset):
for ps in queryset:
api_key_obj, plaintext_key = ProcessingServiceAPIKey.objects.create_key(
name=f"{ps.name} key",
processing_service=ps,
)
self.message_user(
request,
f"{ps.name}: {plaintext_key} (copy now — it won't be shown again)",
)

actions = [generate_api_key]


@admin.register(ProcessingServiceAPIKey)
class ProcessingServiceAPIKeyAdmin(APIKeyModelAdmin):
list_display = [*APIKeyModelAdmin.list_display, "processing_service"]
list_filter = ["processing_service"]
search_fields = [*APIKeyModelAdmin.search_fields, "processing_service__name"]


@admin.register(AlgorithmCategoryMap)
Expand Down
102 changes: 102 additions & 0 deletions ami/ml/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
API key authentication for processing services.

Uses djangorestframework-api-key to provide key-based auth. Each ProcessingService
can have one or more API keys. When a request arrives with `Authorization: Api-Key <key>`,
the authentication class identifies the ProcessingService and sets request.auth to it.

Contains:
- ProcessingServiceAPIKeyAuthentication: DRF auth backend
- HasProcessingServiceAPIKey: DRF permission class

The ProcessingServiceAPIKey model lives in ami.ml.models.processing_service.
"""

import logging

from rest_framework import authentication, exceptions, permissions
from rest_framework_api_key.permissions import KeyParser

from ami.ml.models.processing_service import ProcessingServiceAPIKey

logger = logging.getLogger(__name__)


class ProcessingServiceAPIKeyAuthentication(authentication.BaseAuthentication):
"""
DRF authentication class that identifies a ProcessingService from an API key.

Sets:
request.user = AnonymousUser (required by django-guardian/ObjectPermission)
request.auth = ProcessingService instance

This allows views to check `request.auth` to get the calling service,
and permission classes to verify project access.
"""

key_parser = KeyParser()

def authenticate(self, request):
key = self.key_parser.get(request)
if not key:
return None # No Api-Key header; fall through to next auth class

try:
api_key = ProcessingServiceAPIKey.objects.get_from_key(key)
except ProcessingServiceAPIKey.DoesNotExist:
raise exceptions.AuthenticationFailed("Invalid API key.")

if not api_key.is_valid:
raise exceptions.AuthenticationFailed("API key has been revoked or expired.")

from django.contrib.auth.models import AnonymousUser

return (AnonymousUser(), api_key.processing_service)

def authenticate_header(self, request):
return "Api-Key"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this class and any methods already part of the DRF api_key package?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude says: BaseHasAPIKey from the library handles key validation (is_valid() check), but our HasProcessingServiceAPIKey serves a different purpose — it checks authorization (project membership) after the auth backend has already validated the key. The auth backend (ProcessingServiceAPIKeyAuthentication) handles authentication and puts the PS on request.auth. The permission class then checks that the PS belongs to the right project. So they complement each other rather than overlap — we can't inherit from BaseHasAPIKey because we need the isinstance(request.auth, ProcessingService) check rather than re-validating the key.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude says: The library provides BaseHasAPIKey (permission class) and KeyParser (header extraction), but not a DRF authentication backend. Our ProcessingServiceAPIKeyAuthentication serves a different role — it's an auth backend that sets request.auth = ProcessingService, which DRF's auth pipeline requires. The library's BaseHasAPIKey is a permission class that checks key validity but doesn't identify the caller.

We use KeyParser from the library (for header parsing) and its get_from_key() manager method (for hashed lookup). The auth backend and permission class are ours because they need to: (1) place the PS on request.auth, and (2) check project membership — both are app-specific concerns.



class HasProcessingServiceAPIKey(permissions.BasePermission):
"""
Allow access for requests authenticated with a ProcessingService API key.

The auth backend places the ProcessingService on request.auth.
This permission verifies project membership.

Compose with ObjectPermission for endpoints used by both users and services:
permission_classes = [ObjectPermission | HasProcessingServiceAPIKey]
"""

def has_permission(self, request, view):
from ami.ml.models.processing_service import ProcessingService

if not isinstance(request.auth, ProcessingService):
return False

# For detail views (e.g. /jobs/{pk}/tasks/), defer project scoping
# to has_object_permission where we can derive it from the object.
if view.kwargs.get("pk"):
return True

get_active_project = getattr(view, "get_active_project", None)
if not callable(get_active_project):
return False

project = get_active_project()
if not project:
return False

return request.auth.projects.filter(pk=project.pk).exists()

def has_object_permission(self, request, view, obj):
from ami.ml.models.processing_service import ProcessingService

if not isinstance(request.auth, ProcessingService):
return False

ps = request.auth
project = obj.get_project() if hasattr(obj, "get_project") else None
if not project:
return False
return ps.projects.filter(pk=project.pk).exists()
67 changes: 67 additions & 0 deletions ami/ml/migrations/0029_api_key_and_client_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Generated by Django 4.2.10 on 2026-03-29 05:36

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):
dependencies = [
("ml", "0028_normalize_empty_endpoint_url_to_null"),
]

operations = [
migrations.AddField(
model_name="processingservice",
name="last_seen_client_info",
field=models.JSONField(blank=True, null=True),
),
migrations.CreateModel(
name="ProcessingServiceAPIKey",
fields=[
(
"id",
models.CharField(editable=False, max_length=150, primary_key=True, serialize=False, unique=True),
),
("prefix", models.CharField(editable=False, max_length=8, unique=True)),
("hashed_key", models.CharField(editable=False, max_length=150)),
("created", models.DateTimeField(auto_now_add=True, db_index=True)),
(
"name",
models.CharField(
default="",
help_text="A free-form name for the API key. Need not be unique. 50 characters max.",
max_length=50,
),
),
(
"revoked",
models.BooleanField(
blank=True,
default=False,
help_text="If the API key is revoked, clients cannot use it anymore. (This cannot be undone.)",
),
),
(
"expiry_date",
models.DateTimeField(
blank=True,
help_text="Once API key expires, clients cannot use it anymore.",
null=True,
verbose_name="Expires",
),
),
(
"processing_service",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, related_name="api_keys", to="ml.processingservice"
),
),
],
options={
"verbose_name": "Processing Service API Key",
"verbose_name_plural": "Processing Service API Keys",
"ordering": ("-created",),
"abstract": False,
},
),
]
3 changes: 2 additions & 1 deletion ami/ml/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap
from ami.ml.models.pipeline import Pipeline
from ami.ml.models.processing_service import ProcessingService
from ami.ml.models.processing_service import ProcessingService, ProcessingServiceAPIKey
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig

__all__ = [
"Algorithm",
"AlgorithmCategoryMap",
"Pipeline",
"ProcessingService",
"ProcessingServiceAPIKey",
"ProjectPipelineConfig",
]
Loading
Loading