Skip to content

Commit 320834d

Browse files
Reformat code
1 parent 9d6e9ad commit 320834d

File tree

4 files changed

+29
-35
lines changed

4 files changed

+29
-35
lines changed

DLMUSE/__main__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import argparse
2-
import pkg_resources
2+
3+
import pkg_resources # type: ignore
34

45
from DLMUSE.dlmuse_pipeline import run_dlmuse_pipeline
56

67
VERSION = pkg_resources.require("DLMUSE")[0].version
78

9+
810
def main() -> None:
9-
prog="DLMUSE"
11+
prog = "DLMUSE"
1012
parser = argparse.ArgumentParser(
1113
prog=prog,
1214
description="DLMUSE - MUlti-atlas region Segmentation utilizing Ensembles of registration algorithms and parameters.",
@@ -26,8 +28,10 @@ def main() -> None:
2628
-o /path/to/output \
2729
-device cpu|cuda|mps
2830
29-
""".format(VERSION=VERSION),
30-
add_help=False
31+
""".format(
32+
VERSION=VERSION
33+
),
34+
add_help=False,
3135
)
3236

3337
# Required Arguments
@@ -91,7 +95,7 @@ def main() -> None:
9195
action="store_true",
9296
required=False,
9397
default=False,
94-
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed."
98+
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed.",
9599
)
96100
parser.add_argument(
97101
"--disable_tta",
@@ -101,7 +105,7 @@ def main() -> None:
101105
help="[nnUnet Arg] Set this flag to disable test time data augmentation in the form of mirroring. "
102106
"Faster, but less accurate inference. Not recommended.",
103107
)
104-
### DEPRECIATED ####
108+
# DEPRECIATED ##
105109
# parser.add_argument(
106110
# "-m",
107111
# type=str,
@@ -227,8 +231,9 @@ def main() -> None:
227231
args.nps,
228232
args.prev_stage_predictions,
229233
args.num_parts,
230-
args.part_id
234+
args.part_id,
231235
)
232236

237+
233238
if __name__ == "__main__":
234239
main()

DLMUSE/dlmuse_pipeline.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Optional
2-
import logging
31
import json
2+
import logging
43
import os
5-
from pathlib import Path
64
import shutil
75
import sys
86
import warnings
7+
from pathlib import Path
8+
from typing import Optional
99

1010
import torch
1111

@@ -14,6 +14,7 @@
1414
warnings.simplefilter(action="ignore", category=FutureWarning)
1515
warnings.simplefilter(action="ignore", category=UserWarning)
1616

17+
1718
def run_dlmuse_pipeline(
1819
in_dir: str,
1920
out_dir: str,
@@ -36,17 +37,11 @@ def run_dlmuse_pipeline(
3637
prev_stage_predictions: Optional[str] = None,
3738
num_parts: int = 1,
3839
part_id: int = 0,
39-
):
40+
) -> None:
4041

4142
if clear_cache:
42-
shutil.rmtree(os.path.join(
43-
Path(__file__).parent,
44-
"nnunet_results"
45-
))
46-
shutil.rmtree(os.path.join(
47-
Path(__file__).parent,
48-
".cache"
49-
))
43+
shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results"))
44+
shutil.rmtree(os.path.join(Path(__file__).parent, ".cache"))
5045
if not in_dir or not out_dir:
5146
logging.error("Cache cleared and missing either -i / -o. Exiting.")
5247
sys.exit(0)
@@ -80,18 +75,17 @@ def run_dlmuse_pipeline(
8075
model_folder = os.path.join(
8176
Path(__file__).parent,
8277
"nnunet_results",
83-
"Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/"
84-
% (d, d, c),
78+
"Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/" % (d, d, c),
8579
)
8680

87-
8881
# Check if model exists. If not exist, download using HuggingFace
8982
logging.info(f"Using model folder: {model_folder}")
9083
if not os.path.exists(model_folder):
9184
# HF download model
9285
logging.info("DLMUSE model not found, downloading...")
9386

9487
from huggingface_hub import snapshot_download
88+
9589
local_src = Path(__file__).parent
9690
snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src)
9791

@@ -101,9 +95,7 @@ def run_dlmuse_pipeline(
10195

10296
prepare_data_folder(des_folder)
10397

104-
assert (
105-
part_id < num_parts
106-
), "part_id < num_parts. Please see nnUNetv2_predict -h."
98+
assert part_id < num_parts, "part_id < num_parts. Please see nnUNetv2_predict -h."
10799

108100
assert device in [
109101
"cpu",
@@ -113,6 +105,7 @@ def run_dlmuse_pipeline(
113105

114106
if device == "cpu":
115107
import multiprocessing
108+
116109
# use half of the available threads in the system.
117110
torch.set_num_threads(multiprocessing.cpu_count() // 2)
118111
device = torch.device("cpu")
@@ -149,9 +142,7 @@ def run_dlmuse_pipeline(
149142
)
150143

151144
# Retrieve the model and it's weight
152-
predictor.initialize_from_trained_model_folder(
153-
model_folder, f, checkpoint_name=chk
154-
)
145+
predictor.initialize_from_trained_model_folder(model_folder, f, checkpoint_name=chk)
155146

156147
# Final prediction
157148
predictor.predict_from_files(

DLMUSE/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def rename_and_copy_files(src_folder: str, des_folder: str) -> Tuple[dict, dict]
2626
2727
"""
2828
if not os.path.exists(src_folder):
29-
raise FileNotFoundError(f"Source folder '{src_folder}' does not exist.")
29+
raise FileNotFoundError(f"Source folder '{src_folder}' does not exist.")
3030
if not os.path.exists(des_folder):
3131
raise FileNotFoundError(f"Source folder '{des_folder}' does not exist.")
3232

@@ -36,7 +36,7 @@ def rename_and_copy_files(src_folder: str, des_folder: str) -> Tuple[dict, dict]
3636

3737
for idx, filename in enumerate(files):
3838
old_name = os.path.join(src_folder, filename)
39-
if not os.path.isfile(old_name): # We only want files!
39+
if not os.path.isfile(old_name): # We only want files!
4040
continue
4141
rename_file = f"case_{idx: 04d}_0000.nii.gz"
4242
rename_back = f"case_{idx: 04d}.nii.gz"
@@ -46,6 +46,6 @@ def rename_and_copy_files(src_folder: str, des_folder: str) -> Tuple[dict, dict]
4646
rename_dict[filename] = rename_file
4747
rename_back_dict[rename_back] = "DLMUSE_mask_" + filename
4848
except Exception as e:
49-
print(F"Error copying file '{filename}' to '{new_name}': {e}")
49+
print(f"Error copying file '{filename}' to '{new_name}': {e}")
5050

5151
return rename_dict, rename_back_dict

setup.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
this_directory = Path(__file__).parent
88
long_description = (this_directory / "README.md").read_text()
99

10-
with open('requirements.txt') as f:
10+
with open("requirements.txt") as f:
1111
required = f.read().splitlines()
1212

1313

@@ -46,7 +46,5 @@
4646
"nnU-Net",
4747
"nnunet",
4848
],
49-
package_data={
50-
"DLMUSE": ["VERSION"]
51-
},
49+
package_data={"DLMUSE": ["VERSION"]},
5250
)

0 commit comments

Comments
 (0)