From cc9642f6cafd2da4fd836d85c177f8004bccbfee Mon Sep 17 00:00:00 2001 From: Ammar Ammar <43293485+ammar257ammar@users.noreply.github.com> Date: Thu, 12 Dec 2024 12:17:09 +0100 Subject: [PATCH] Fix Hanging Protocols json field raised exception (#3727) closes #3724 --- app/grandchallenge/hanging_protocols/forms.py | 18 ++-- .../hanging_protocols_tests/test_forms.py | 82 +++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/app/grandchallenge/hanging_protocols/forms.py b/app/grandchallenge/hanging_protocols/forms.py index 5ed519698..ad032fd3c 100644 --- a/app/grandchallenge/hanging_protocols/forms.py +++ b/app/grandchallenge/hanging_protocols/forms.py @@ -58,13 +58,21 @@ def __init__(self, *args, **kwargs): ) def clean_json(self): - json = self.cleaned_data["json"] - viewport_names = [x["viewport_name"] for x in json] + hanging_protocol_json = self.cleaned_data["json"] + + try: + viewport_names = [ + viewport["viewport_name"] for viewport in hanging_protocol_json + ] + except (KeyError, TypeError): + raise ValidationError( + "Hanging protocol definition is invalid. Have a look at the example in the helptext." + ) self._validate_viewport_uniqueness(viewport_names=viewport_names) - self._validate_dimensions(value=json) + self._validate_dimensions(value=hanging_protocol_json) - for viewport in json: + for viewport in hanging_protocol_json: if "parent_id" in viewport: self._validate_parent_id( viewport=viewport, viewport_names=viewport_names @@ -74,7 +82,7 @@ def clean_json(self): viewport=viewport, viewport_names=viewport_names ) - return json + return hanging_protocol_json def _validate_viewport_uniqueness(self, *, viewport_names): if len(set(viewport_names)) != len(viewport_names): diff --git a/app/tests/hanging_protocols_tests/test_forms.py b/app/tests/hanging_protocols_tests/test_forms.py index 45b018093..81fa7d73c 100644 --- a/app/tests/hanging_protocols_tests/test_forms.py +++ b/app/tests/hanging_protocols_tests/test_forms.py @@ -316,6 +316,88 @@ def test_hanging_protocol_clientside(): assert form.is_valid() +@pytest.mark.django_db +@pytest.mark.parametrize( + "hanging_protocol_json, form_is_valid, expected_json_error", + ( + ("[]", False, "This field is required."), + ("{}", False, "This field is required."), + ( + '[{"viewport_name": "main"}]', + True, + None, + ), + ( + 12345, + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ("main", False, "Enter a valid JSON."), + ( + "[1,2,3,4,5]", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + '["test1", "test2", "test3"]', + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + "[[],[],[]]", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + "[{},{},{}]", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + "true", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + "false", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + '{"viewport_name": "main"}', + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + "[{}]", + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + '[{"test":1}]', + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ( + '[{"test1":"main"},{"test2":"secondary"}]', + False, + "Hanging protocol definition is invalid. Have a look at the example in the helptext.", + ), + ), +) +def test_hanging_protocol_form_json_validation( + hanging_protocol_json, form_is_valid, expected_json_error +): + form = HangingProtocolForm( + { + "title": "main", + "json": hanging_protocol_json, + } + ) + assert form.is_valid() is form_is_valid + assert form.errors.get("json", [None])[0] == expected_json_error + + def make_ci_list( *, number_of_images,