diff --git a/ci/unit_tests/test_vista3d.py b/ci/unit_tests/test_vista3d.py index 640acbaf..695a616d 100644 --- a/ci/unit_tests/test_vista3d.py +++ b/ci/unit_tests/test_vista3d.py @@ -61,7 +61,7 @@ TEST_CASE_INFER_MULTI_NEW_STR_PROMPT = [ { "bundle_root": "models/vista3d", - "input_dict": {"label_prompt": ["new class 1", "new class 2", "new class 3"]}, + "input_dict": {"label_prompt": ["new class 1"], "points": [[123, 212, 151]], "point_labels": [1]}, "patch_size": [32, 32, 32], "checkpointloader#_disabled_": True, # do not load weights" "initialize": ["$monai.utils.set_determinism(seed=123)"], @@ -223,6 +223,26 @@ "error": "Label prompt can only be a single object if provided with point prompts.", } ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [16, 25, 26]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Undefined label prompt detected. Provide point prompts for zero-shot.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [136]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Undefined label prompt detected. Provide point prompts for zero-shot.", + } + ], ] diff --git a/models/vista3d/configs/metadata.json b/models/vista3d/configs/metadata.json index 285db4f2..923cc380 100644 --- a/models/vista3d/configs/metadata.json +++ b/models/vista3d/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json", - "version": "0.5.3", + "version": "0.5.4", "changelog": { + "0.5.4": "add undefined label prompt check", "0.5.3": "update readme", "0.5.2": "fix eval issue", "0.5.1": "add description for zero-shot and upate eval", diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py index 8e26451b..f20261b3 100644 --- a/models/vista3d/scripts/evaluator.py +++ b/models/vista3d/scripts/evaluator.py @@ -166,6 +166,10 @@ def check_prompts_format(self, label_prompt, points, point_labels): raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].") if not np.all([(x < 255).item() for x in label_prompt]): raise ValueError("Current bundle only supports label prompt smaller than 255.") + if points is None: + supported_list = list({i + 1 for i in range(132)} - {16, 18, 129, 130, 131}) + if not np.all([x in supported_list for x in label_prompt]): + raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.") else: raise ValueError("Label prompt must be a list, [1,2,3,4,...,].") # check points