Skip to content

Commit

Permalink
Add test for input validation on serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
amickan committed Jan 17, 2025
1 parent d1b829b commit ad3489c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 22 deletions.
53 changes: 31 additions & 22 deletions app/grandchallenge/algorithms/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
HyperlinkedComponentInterfaceValueSerializer,
)
from grandchallenge.core.guardian import filter_by_permission
from grandchallenge.core.templatetags.remove_whitespace import oxford_comma
from grandchallenge.hanging_protocols.serializers import (
HangingProtocolSerializer,
)
Expand Down Expand Up @@ -218,29 +219,10 @@ def validate(self, data):
"You have run out of algorithm credits"
)

# validate that the provided inputs match one of the configured interfaces
# and add the matching interface to the data
provided_inputs = {i["interface"] for i in data["inputs"]}
try:
data["algorithm_interface"] = self._algorithm.interfaces.annotate(
input_count=Count("inputs", distinct=True),
relevant_input_count=Count(
"inputs",
filter=Q(inputs__in=provided_inputs),
distinct=True,
),
).get(
relevant_input_count=len(provided_inputs),
input_count=len(provided_inputs),
)
except ObjectDoesNotExist:
raise serializers.ValidationError(
f"The set of inputs provided does not match "
f"any of the algorithm's interfaces. This algorithm supports the "
f"following sets of inputs: {[interface.inputs for interface in self._algorithm.interfaces.all()]}"
)

inputs = data.pop("inputs")
data["algorithm_interface"] = (
self.validate_inputs_and_return_matching_interface(inputs=inputs)
)
self.inputs = self.reformat_inputs(serialized_civs=inputs)

if Job.objects.get_jobs_with_same_inputs(
Expand Down Expand Up @@ -285,6 +267,33 @@ def create(self, validated_data):

return job

def validate_inputs_and_return_matching_interface(self, *, inputs):
"""
Validates that the provided inputs match one of the configured interfaces of
the algorithm and returns that AlgorithmInterface
"""
provided_inputs = {i["interface"] for i in inputs}
try:
interface = self._algorithm.interfaces.annotate(
input_count=Count("inputs", distinct=True),
relevant_input_count=Count(
"inputs",
filter=Q(inputs__in=provided_inputs),
distinct=True,
),
).get(
relevant_input_count=len(provided_inputs),
input_count=len(provided_inputs),
)
return interface
except ObjectDoesNotExist:
raise serializers.ValidationError(
f"The set of inputs provided does not match "
f"any of the algorithm's interfaces. This algorithm supports the "
f"following input combinations: "
f"{oxford_comma([f'Interface {n}: {oxford_comma(interface.inputs.all())}' for n, interface in enumerate(self._algorithm.interfaces.all(), start=1)])}"
)

@staticmethod
def reformat_inputs(*, serialized_civs):
"""Takes serialized CIV data and returns list of CIVData objects."""
Expand Down
82 changes: 82 additions & 0 deletions app/tests/algorithms_tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,85 @@ def test_algorithm_post_serializer_image_and_time_limit_fixed(rf):
assert job.algorithm_image != different_ai
assert not job.algorithm_model
assert job.time_limit == 10


@pytest.mark.parametrize(
"inputs, interface",
(
([1], 1), # matches interface 1 of algorithm
([1, 2], 2), # matches interface 2 of algorithm
([3, 4, 5], 3), # matches interface 3 of algorithm
([4], None), # matches interface 4, but not configured for algorithm
(
[1, 2, 3],
None,
), # matches interface 5, but not configured for algorithm
([2], None), # matches no interface (implements part of interface 2)
(
[1, 3, 4],
None,
), # matches no interface (implements interface 3 and an additional input)
),
)
@pytest.mark.django_db
def test_validate_inputs_on_job_serializer(inputs, interface, rf):
user = UserFactory()
algorithm = AlgorithmFactory()
algorithm.add_editor(user)
AlgorithmImageFactory(
algorithm=algorithm,
is_desired_version=True,
is_manifest_valid=True,
is_in_registry=True,
)

io1, io2, io3, io4, io5 = AlgorithmInterfaceFactory.create_batch(5)
ci1, ci2, ci3, ci4, ci5, ci6 = ComponentInterfaceFactory.create_batch(
6, kind=ComponentInterface.Kind.STRING
)

interfaces = [io1, io2, io3]
cis = [ci1, ci2, ci3, ci4, ci5, ci6]

io1.inputs.set([ci1])
io2.inputs.set([ci1, ci2])
io3.inputs.set([ci3, ci4, ci5])
io4.inputs.set([ci1, ci2, ci3])
io5.inputs.set([ci4])
io1.outputs.set([ci6])
io2.outputs.set([ci3])
io3.outputs.set([ci1])
io4.outputs.set([ci1])
io5.outputs.set([ci1])

algorithm.interfaces.add(io1, through_defaults={"is_default": True})
algorithm.interfaces.add(io2)
algorithm.interfaces.add(io3)

algorithm_interface = interfaces[interface - 1] if interface else None
inputs = [cis[i - 1] for i in inputs]

job = {
"algorithm": algorithm.api_url,
"inputs": [
{"interface": int.slug, "value": "dummy"} for int in inputs
],
}

request = rf.get("/foo")
request.user = user
serializer = JobPostSerializer(data=job, context={"request": request})

if interface:
assert serializer.is_valid()
assert (
serializer.validated_data["algorithm_interface"]
== algorithm_interface
)
else:
assert not serializer.is_valid()
assert (
"The set of inputs provided does not match any of the algorithm's interfaces."
in str(serializer.errors)
)
assert "algorithm_interface" not in serializer.validated_data

0 comments on commit ad3489c

Please sign in to comment.