From c750348c569f9dc1c8321048c1ab374b2daca61d Mon Sep 17 00:00:00 2001 From: mrunibe Date: Tue, 15 Oct 2024 22:33:50 +0200 Subject: [PATCH] Support for Apple silicon using MPS #6 --- batch-dl+direct.sh | 4 +++- dl+direct.sh | 3 ++- pyproject.toml | 4 ++-- src/DeepSCAN_Anatomy_Newnet_apply.py | 4 ++-- src/bet.py | 4 +++- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/batch-dl+direct.sh b/batch-dl+direct.sh index 9abf030..319b76e 100755 --- a/batch-dl+direct.sh +++ b/batch-dl+direct.sh @@ -72,6 +72,8 @@ export MODEL_ARGS export JOB_QUEUE=job_queue.txt [[ -f ${JOB_QUEUE} ]] && rm ${JOB_QUEUE} +TAIL=tail +[[ "`uname -s`" == "Darwin" ]] && TAIL=gtail run_dl() { SUBJ=$1 @@ -127,5 +129,5 @@ true > ${JOB_QUEUE} # create first N_PARALLEL_CPU dummy entries. Otherwise jobs will only start once N_PARALLEL_CPU jobs are queued for i in `seq 1 ${N_PARALLEL_CPU}` ; do echo dummy >> ${JOB_QUEUE} ; done -tail -n+0 -f ${JOB_QUEUE} --pid ${PID_DL} | parallel -j ${N_PARALLEL_CPU} run_direct {} +${TAIL} -n+0 -f ${JOB_QUEUE} --pid ${PID_DL} | parallel -j ${N_PARALLEL_CPU} run_direct {} diff --git a/dl+direct.sh b/dl+direct.sh index 8f228fd..30edb83 100755 --- a/dl+direct.sh +++ b/dl+direct.sh @@ -91,7 +91,7 @@ echo # convert into freesurfer space (resample to 1mm voxel, orient to LIA) python ${SCRIPT_DIR}/conform.py "${T1}" "${DST}/T1w_norm.nii.gz" -HAS_GPU=`python -c 'import torch; print(torch.cuda.is_available())'` +HAS_GPU=`python -c 'import torch; print(torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()))'` if [ ${HAS_GPU} != 'True' ] ; then echo "WARNING: No GPU/CUDA device found. Running on CPU might take some time..." fi @@ -110,6 +110,7 @@ if [ ${DO_SKULLSTRIP} -gt 0 ] ; then IN_VOLUME=${DST}/T1w_norm_noskull.nii.gz BET_INPUT_VOLUME=${DST}/T1w_norm.nii.gz MASK_VOLUME=${DST}/T1w_norm_noskull_mask.nii.gz + export PYTORCH_ENABLE_MPS_FALLBACK=1 python ${SCRIPT_DIR}/bet.py ${BET_OPTS} "${BET_INPUT_VOLUME}" "${IN_VOLUME}" || die "hd-bet failed" else diff --git a/pyproject.toml b/pyproject.toml index 77e9640..b5faefc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "DL-DiReCT" -version = "1.0.1" +version = "1.0.2" description = "DL+DiReCT - Direct Cortical Thickness Estimation using Deep Learning-based Anatomy Segmentation and Cortex Parcellation" readme = "README.md" authors = [ {name = "Michael Rebsamen"} ] @@ -20,7 +20,7 @@ classifiers = [ dependencies = [ "antspyx>=0.3.5", - "HD_BET @ https://github.com/MIC-DKFZ/HD-BET/archive/refs/heads/master.zip", + "HD_BET @ https://github.com/mrunibe/HD-BET/archive/refs/heads/master.zip", "nibabel>=3.2.1", "numpy<2.0.0", "pandas>=0.25.3", diff --git a/src/DeepSCAN_Anatomy_Newnet_apply.py b/src/DeepSCAN_Anatomy_Newnet_apply.py index 656e089..e954e06 100644 --- a/src/DeepSCAN_Anatomy_Newnet_apply.py +++ b/src/DeepSCAN_Anatomy_Newnet_apply.py @@ -399,7 +399,7 @@ def load_checkpoint(checkpoint_file, device): sys.exit(1) print('loading checkpoint {}'.format(checkpoint_file)) if VERBOSE else False - return torch.load(checkpoint_file, map_location=device) + return torch.load(checkpoint_file, weights_only=True, map_location=device) def validate_input(t1, t1_data): @@ -435,7 +435,7 @@ def validate_input(t1, t1_data): if not os.path.exists(output_dir): os.makedirs(output_dir) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else 'cpu')) checkpoint = load_checkpoint(model_file, device) target_label_names = checkpoint['label_names'] # number of last labels to ignore for hard segmentation (argmax), e.g. left-hemi, right-hemi, brain diff --git a/src/bet.py b/src/bet.py index c32b4d6..f062d1c 100644 --- a/src/bet.py +++ b/src/bet.py @@ -46,7 +46,9 @@ print('Brain extraction using HD-BET [https://doi.org/10.1002/hbm.24750] ...') - if not torch.cuda.is_available(): + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + hdbet_device = 'mps' + elif not torch.cuda.is_available(): print('No GPU found. Running hd-bet in fast mode, check results! Make sure you have enough memory.') hdbet_mode = 'fast' hdbet_device = 'cpu'