Skip to content

Commit 9d35f9a

Browse files
heyufan1995pre-commit-ci[bot]yiheng-wang-nvbhashemian
authored
Add undefined label prompt check (Project-MONAI#693)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Status **Ready/Work in progress/Hold** ### Please ensure all the checkboxes: <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: heyufan1995 <heyufan1995@gmail.com> Signed-off-by: Yiheng Wang <vennw@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Dr. Bruce Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Yiheng Wang <vennw@nvidia.com>
1 parent c094cda commit 9d35f9a

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

ci/unit_tests/test_vista3d.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
TEST_CASE_INFER_MULTI_NEW_STR_PROMPT = [
6262
{
6363
"bundle_root": "models/vista3d",
64-
"input_dict": {"label_prompt": ["new class 1", "new class 2", "new class 3"]},
64+
"input_dict": {"label_prompt": ["new class 1"], "points": [[123, 212, 151]], "point_labels": [1]},
6565
"patch_size": [32, 32, 32],
6666
"checkpointloader#_disabled_": True, # do not load weights"
6767
"initialize": ["$monai.utils.set_determinism(seed=123)"],
@@ -223,6 +223,26 @@
223223
"error": "Label prompt can only be a single object if provided with point prompts.",
224224
}
225225
],
226+
[
227+
{
228+
"bundle_root": "models/vista3d",
229+
"input_dict": {"label_prompt": [16, 25, 26]},
230+
"patch_size": [32, 32, 32],
231+
"checkpointloader#_disabled_": True, # do not load weights"
232+
"initialize": ["$monai.utils.set_determinism(seed=123)"],
233+
"error": "Undefined label prompt detected. Provide point prompts for zero-shot.",
234+
}
235+
],
236+
[
237+
{
238+
"bundle_root": "models/vista3d",
239+
"input_dict": {"label_prompt": [136]},
240+
"patch_size": [32, 32, 32],
241+
"checkpointloader#_disabled_": True, # do not load weights"
242+
"initialize": ["$monai.utils.set_determinism(seed=123)"],
243+
"error": "Undefined label prompt detected. Provide point prompts for zero-shot.",
244+
}
245+
],
226246
]
227247

228248

models/vista3d/configs/metadata.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{
22
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
3-
"version": "0.5.3",
3+
"version": "0.5.4",
44
"changelog": {
5+
"0.5.4": "add undefined label prompt check",
56
"0.5.3": "update readme",
67
"0.5.2": "fix eval issue",
78
"0.5.1": "add description for zero-shot and upate eval",

models/vista3d/scripts/evaluator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ def check_prompts_format(self, label_prompt, points, point_labels):
166166
raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].")
167167
if not np.all([(x < 255).item() for x in label_prompt]):
168168
raise ValueError("Current bundle only supports label prompt smaller than 255.")
169+
if points is None:
170+
supported_list = list({i + 1 for i in range(132)} - {16, 18, 129, 130, 131})
171+
if not np.all([x in supported_list for x in label_prompt]):
172+
raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.")
169173
else:
170174
raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
171175
# check points

0 commit comments

Comments
 (0)