From ad3489c62d200986a60bd7a46822ace79c538062 Mon Sep 17 00:00:00 2001 From: amickan Date: Fri, 17 Jan 2025 13:24:38 +0100 Subject: [PATCH] Add test for input validation on serializer --- app/grandchallenge/algorithms/serializers.py | 53 +++++++----- .../algorithms_tests/test_serializers.py | 82 +++++++++++++++++++ 2 files changed, 113 insertions(+), 22 deletions(-) diff --git a/app/grandchallenge/algorithms/serializers.py b/app/grandchallenge/algorithms/serializers.py index b5861a821..8c21448e0 100644 --- a/app/grandchallenge/algorithms/serializers.py +++ b/app/grandchallenge/algorithms/serializers.py @@ -31,6 +31,7 @@ HyperlinkedComponentInterfaceValueSerializer, ) from grandchallenge.core.guardian import filter_by_permission +from grandchallenge.core.templatetags.remove_whitespace import oxford_comma from grandchallenge.hanging_protocols.serializers import ( HangingProtocolSerializer, ) @@ -218,29 +219,10 @@ def validate(self, data): "You have run out of algorithm credits" ) - # validate that the provided inputs match one of the configured interfaces - # and add the matching interface to the data - provided_inputs = {i["interface"] for i in data["inputs"]} - try: - data["algorithm_interface"] = self._algorithm.interfaces.annotate( - input_count=Count("inputs", distinct=True), - relevant_input_count=Count( - "inputs", - filter=Q(inputs__in=provided_inputs), - distinct=True, - ), - ).get( - relevant_input_count=len(provided_inputs), - input_count=len(provided_inputs), - ) - except ObjectDoesNotExist: - raise serializers.ValidationError( - f"The set of inputs provided does not match " - f"any of the algorithm's interfaces. This algorithm supports the " - f"following sets of inputs: {[interface.inputs for interface in self._algorithm.interfaces.all()]}" - ) - inputs = data.pop("inputs") + data["algorithm_interface"] = ( + self.validate_inputs_and_return_matching_interface(inputs=inputs) + ) self.inputs = self.reformat_inputs(serialized_civs=inputs) if Job.objects.get_jobs_with_same_inputs( @@ -285,6 +267,33 @@ def create(self, validated_data): return job + def validate_inputs_and_return_matching_interface(self, *, inputs): + """ + Validates that the provided inputs match one of the configured interfaces of + the algorithm and returns that AlgorithmInterface + """ + provided_inputs = {i["interface"] for i in inputs} + try: + interface = self._algorithm.interfaces.annotate( + input_count=Count("inputs", distinct=True), + relevant_input_count=Count( + "inputs", + filter=Q(inputs__in=provided_inputs), + distinct=True, + ), + ).get( + relevant_input_count=len(provided_inputs), + input_count=len(provided_inputs), + ) + return interface + except ObjectDoesNotExist: + raise serializers.ValidationError( + f"The set of inputs provided does not match " + f"any of the algorithm's interfaces. This algorithm supports the " + f"following input combinations: " + f"{oxford_comma([f'Interface {n}: {oxford_comma(interface.inputs.all())}' for n, interface in enumerate(self._algorithm.interfaces.all(), start=1)])}" + ) + @staticmethod def reformat_inputs(*, serialized_civs): """Takes serialized CIV data and returns list of CIVData objects.""" diff --git a/app/tests/algorithms_tests/test_serializers.py b/app/tests/algorithms_tests/test_serializers.py index 56ce7cfdb..4e9c90d7a 100644 --- a/app/tests/algorithms_tests/test_serializers.py +++ b/app/tests/algorithms_tests/test_serializers.py @@ -451,3 +451,85 @@ def test_algorithm_post_serializer_image_and_time_limit_fixed(rf): assert job.algorithm_image != different_ai assert not job.algorithm_model assert job.time_limit == 10 + + +@pytest.mark.parametrize( + "inputs, interface", + ( + ([1], 1), # matches interface 1 of algorithm + ([1, 2], 2), # matches interface 2 of algorithm + ([3, 4, 5], 3), # matches interface 3 of algorithm + ([4], None), # matches interface 4, but not configured for algorithm + ( + [1, 2, 3], + None, + ), # matches interface 5, but not configured for algorithm + ([2], None), # matches no interface (implements part of interface 2) + ( + [1, 3, 4], + None, + ), # matches no interface (implements interface 3 and an additional input) + ), +) +@pytest.mark.django_db +def test_validate_inputs_on_job_serializer(inputs, interface, rf): + user = UserFactory() + algorithm = AlgorithmFactory() + algorithm.add_editor(user) + AlgorithmImageFactory( + algorithm=algorithm, + is_desired_version=True, + is_manifest_valid=True, + is_in_registry=True, + ) + + io1, io2, io3, io4, io5 = AlgorithmInterfaceFactory.create_batch(5) + ci1, ci2, ci3, ci4, ci5, ci6 = ComponentInterfaceFactory.create_batch( + 6, kind=ComponentInterface.Kind.STRING + ) + + interfaces = [io1, io2, io3] + cis = [ci1, ci2, ci3, ci4, ci5, ci6] + + io1.inputs.set([ci1]) + io2.inputs.set([ci1, ci2]) + io3.inputs.set([ci3, ci4, ci5]) + io4.inputs.set([ci1, ci2, ci3]) + io5.inputs.set([ci4]) + io1.outputs.set([ci6]) + io2.outputs.set([ci3]) + io3.outputs.set([ci1]) + io4.outputs.set([ci1]) + io5.outputs.set([ci1]) + + algorithm.interfaces.add(io1, through_defaults={"is_default": True}) + algorithm.interfaces.add(io2) + algorithm.interfaces.add(io3) + + algorithm_interface = interfaces[interface - 1] if interface else None + inputs = [cis[i - 1] for i in inputs] + + job = { + "algorithm": algorithm.api_url, + "inputs": [ + {"interface": int.slug, "value": "dummy"} for int in inputs + ], + } + + request = rf.get("/foo") + request.user = user + serializer = JobPostSerializer(data=job, context={"request": request}) + + if interface: + assert serializer.is_valid() + assert ( + serializer.validated_data["algorithm_interface"] + == algorithm_interface + ) + else: + assert not serializer.is_valid() + assert ( + "The set of inputs provided does not match any of the algorithm's interfaces." + in str(serializer.errors) + ) + assert "algorithm_interface" not in serializer.validated_data