Skip to content

Commit 936085c

Browse files
Fixes regarding the part_id and the weights download
1 parent b11854d commit 936085c

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

DLMUSE/__main__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ def main() -> None:
3939
# Required Arguments
4040
parser.add_argument(
4141
"-i",
42+
"--in_dir",
4243
type=str,
4344
required=True,
4445
help="[REQUIRED] Input folder with LPS oriented T1 sMRI Intra Cranial Volumes (ICV) in Nifti format (nii.gz).",
4546
)
4647
parser.add_argument(
4748
"-o",
49+
"--out_dir",
4850
type=str,
4951
required=True,
5052
help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).",
@@ -217,6 +219,7 @@ def main() -> None:
217219
args.device,
218220
args.clear_cache,
219221
args.d,
222+
args.c,
220223
args.part_id,
221224
args.num_parts,
222225
args.step_size,

DLMUSE/dlmuse_pipeline.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def run_pipeline(
1515
out_dir: str,
1616
device: str,
1717
clear_cache: bool = False,
18-
d: str = "901",
18+
d: str = "903",
19+
c: str = "3d_fullres",
1920
part_id: int = 0,
2021
num_parts: int = 1,
2122
step_size: float = 0.5,
@@ -75,7 +76,7 @@ def run_pipeline(
7576
model_folder = os.path.join(
7677
Path(__file__).parent,
7778
"nnunet_results",
78-
"Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (d, d),
79+
"Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/" % (d, d, c),
7980
)
8081

8182
if clear_cache:
@@ -90,15 +91,16 @@ def run_pipeline(
9091
from huggingface_hub import snapshot_download
9192

9293
local_src = Path(__file__).parent
93-
snapshot_download(repo_id="nichart/DLICV", local_dir=local_src)
94-
print("DLICV model has been successfully downloaded!")
94+
snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src)
95+
print("DLMUSE model has been successfully downloaded!")
9596
else:
9697
print("Loading the model...")
9798

9899
prepare_data_folder(out_dir)
99100

100-
# Check for invalid arguments - advise users to see nnUNetv2 documentation
101-
assert part_id < num_parts, "See nnUNetv2_predict -h."
101+
assert (
102+
part_id < num_parts
103+
), "part_id < num_parts. Please see nnUNetv2_predict -h."
102104

103105
assert device in [
104106
"cpu",

0 commit comments

Comments
 (0)