Skip to content

Commit

Permalink
Support for Apple silicon using MPS #6
Browse files Browse the repository at this point in the history
  • Loading branch information
mrunibe committed Oct 15, 2024
1 parent 9d991bd commit c750348
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 7 deletions.
4 changes: 3 additions & 1 deletion batch-dl+direct.sh
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ export MODEL_ARGS
export JOB_QUEUE=job_queue.txt
[[ -f ${JOB_QUEUE} ]] && rm ${JOB_QUEUE}

TAIL=tail
[[ "`uname -s`" == "Darwin" ]] && TAIL=gtail

run_dl() {
SUBJ=$1
Expand Down Expand Up @@ -127,5 +129,5 @@ true > ${JOB_QUEUE}
# create first N_PARALLEL_CPU dummy entries. Otherwise jobs will only start once N_PARALLEL_CPU jobs are queued
for i in `seq 1 ${N_PARALLEL_CPU}` ; do echo dummy >> ${JOB_QUEUE} ; done

tail -n+0 -f ${JOB_QUEUE} --pid ${PID_DL} | parallel -j ${N_PARALLEL_CPU} run_direct {}
${TAIL} -n+0 -f ${JOB_QUEUE} --pid ${PID_DL} | parallel -j ${N_PARALLEL_CPU} run_direct {}

3 changes: 2 additions & 1 deletion dl+direct.sh
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ echo
# convert into freesurfer space (resample to 1mm voxel, orient to LIA)
python ${SCRIPT_DIR}/conform.py "${T1}" "${DST}/T1w_norm.nii.gz"

HAS_GPU=`python -c 'import torch; print(torch.cuda.is_available())'`
HAS_GPU=`python -c 'import torch; print(torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()))'`
if [ ${HAS_GPU} != 'True' ] ; then
echo "WARNING: No GPU/CUDA device found. Running on CPU might take some time..."
fi
Expand All @@ -110,6 +110,7 @@ if [ ${DO_SKULLSTRIP} -gt 0 ] ; then
IN_VOLUME=${DST}/T1w_norm_noskull.nii.gz
BET_INPUT_VOLUME=${DST}/T1w_norm.nii.gz
MASK_VOLUME=${DST}/T1w_norm_noskull_mask.nii.gz
export PYTORCH_ENABLE_MPS_FALLBACK=1

python ${SCRIPT_DIR}/bet.py ${BET_OPTS} "${BET_INPUT_VOLUME}" "${IN_VOLUME}" || die "hd-bet failed"
else
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "DL-DiReCT"
version = "1.0.1"
version = "1.0.2"
description = "DL+DiReCT - Direct Cortical Thickness Estimation using Deep Learning-based Anatomy Segmentation and Cortex Parcellation"
readme = "README.md"
authors = [ {name = "Michael Rebsamen"} ]
Expand All @@ -20,7 +20,7 @@ classifiers = [

dependencies = [
"antspyx>=0.3.5",
"HD_BET @ https://github.com/MIC-DKFZ/HD-BET/archive/refs/heads/master.zip",
"HD_BET @ https://github.com/mrunibe/HD-BET/archive/refs/heads/master.zip",
"nibabel>=3.2.1",
"numpy<2.0.0",
"pandas>=0.25.3",
Expand Down
4 changes: 2 additions & 2 deletions src/DeepSCAN_Anatomy_Newnet_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def load_checkpoint(checkpoint_file, device):
sys.exit(1)

print('loading checkpoint {}'.format(checkpoint_file)) if VERBOSE else False
return torch.load(checkpoint_file, map_location=device)
return torch.load(checkpoint_file, weights_only=True, map_location=device)


def validate_input(t1, t1_data):
Expand Down Expand Up @@ -435,7 +435,7 @@ def validate_input(t1, t1_data):
if not os.path.exists(output_dir):
os.makedirs(output_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() else 'cpu'))
checkpoint = load_checkpoint(model_file, device)
target_label_names = checkpoint['label_names']
# number of last labels to ignore for hard segmentation (argmax), e.g. left-hemi, right-hemi, brain
Expand Down
4 changes: 3 additions & 1 deletion src/bet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@

print('Brain extraction using HD-BET [https://doi.org/10.1002/hbm.24750] ...')

if not torch.cuda.is_available():
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
hdbet_device = 'mps'
elif not torch.cuda.is_available():
print('No GPU found. Running hd-bet in fast mode, check results! Make sure you have enough memory.')
hdbet_mode = 'fast'
hdbet_device = 'cpu'
Expand Down

0 comments on commit c750348

Please sign in to comment.