Skip to content

Commit

Permalink
Refactor input output annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
amickan committed Jan 23, 2025
1 parent 118fed8 commit 994c43b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
27 changes: 18 additions & 9 deletions app/grandchallenge/algorithms/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def create(
def delete(self):
raise NotImplementedError("Bulk delete is not allowed.")

def with_input_output_counts(self, inputs=None, outputs=None):
return self.annotate(
input_count=Count("inputs", distinct=True),
output_count=Count("outputs", distinct=True),
relevant_input_count=Count(
"inputs",
filter=Q(inputs__in=inputs) if inputs is not None else Q(),
distinct=True,
),
relevant_output_count=Count(
"outputs",
filter=Q(outputs__in=outputs) if outputs is not None else Q(),
distinct=True,
),
)


class AlgorithmInterface(UUIDModel):
inputs = models.ManyToManyField(
Expand Down Expand Up @@ -124,15 +140,8 @@ def get_existing_interface_for_inputs_and_outputs(
*, inputs, outputs, model=AlgorithmInterface
):
try:
return model.objects.annotate(
input_count=Count("inputs", distinct=True),
output_count=Count("outputs", distinct=True),
relevant_input_count=Count(
"inputs", filter=Q(inputs__in=inputs), distinct=True
),
relevant_output_count=Count(
"outputs", filter=Q(outputs__in=outputs), distinct=True
),
return model.objects.with_input_output_counts(
inputs=inputs, outputs=outputs
).get(
relevant_input_count=len(inputs),
relevant_output_count=len(outputs),
Expand Down
10 changes: 2 additions & 8 deletions app/grandchallenge/algorithms/serializers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging

from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.db.models import Count, Q
from rest_framework import serializers
from rest_framework.fields import (
CharField,
Expand Down Expand Up @@ -273,13 +272,8 @@ def validate_inputs_and_return_matching_interface(self, *, inputs):
"""
provided_inputs = {i["interface"] for i in inputs}
try:
interface = self._algorithm.interfaces.annotate(
input_count=Count("inputs", distinct=True),
relevant_input_count=Count(
"inputs",
filter=Q(inputs__in=provided_inputs),
distinct=True,
),
interface = self._algorithm.interfaces.with_input_output_counts(
inputs=provided_inputs
).get(
relevant_input_count=len(provided_inputs),
input_count=len(provided_inputs),
Expand Down

0 comments on commit 994c43b

Please sign in to comment.