Skip to content

Commit

Permalink
switch to embarassingly parallel functions; RAM may be an issue, but …
Browse files Browse the repository at this point in the history
…it should be faster
  • Loading branch information
trislett committed Oct 7, 2019
1 parent 7091fea commit d254f93
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 42 deletions.
27 changes: 24 additions & 3 deletions ants_tbss/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_wildcard(searchstring, printarray = False): # super dirty
print (outstring)
return outstring

def antsLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None):
def antsLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None, use_float = False):
"""
Wrapper for ANTs linear registration with some recommended parameters.
Rigid transfomration: gradient step = 0.1
Expand Down Expand Up @@ -90,9 +90,11 @@ def antsLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None):
out_basename,
outdir,
out_basename))
if use_float:
ants_cmd += ' --float'
return ants_cmd

def antsNonLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None):
def antsNonLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None, use_float = False):
"""
Wrapper for ANTs non-linear registration with some recommended parameters. I recommmend first using antsLinearRegCmd.
SyN transformation: [0.1,3,0]
Expand Down Expand Up @@ -140,6 +142,8 @@ def antsNonLinearRegCmd(numthreads, reference, mov, out_basename, outdir = None)
out_basename,
outdir,
out_basename))
if use_float:
ants_cmd += ' --float'
return ants_cmd

def antsApplyTransformCmd(reference, mov, warps, outname, outdir = None):
Expand Down Expand Up @@ -215,8 +219,25 @@ def antsBetCmd(numthreads, input_image, output_image_brain):
output_image_brain))
return ants_cmd

def round_mask_transform(mask_image):
"""
Binarize a mask using numpy round and overwrites it.
Parameters
----------
mask_image : str
/path/to/mask_image
Returns
-------
None
"""
img = nib.load(mask_image)
img_data = img.get_data()
img_data = np.round(img_data)
nib.save(nib.Nifti1Image(img_data,img.affine), mask_image)


# various methods for choosing thresholds automatically
def autothreshold(data, threshold_type = 'yen', z = 2.3264):
"""
Autothresholds data.
Expand Down
147 changes: 108 additions & 39 deletions bin/ants_tbss
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import numpy as np
import argparse
import nibabel as nib
import json
from time import time

from ants_tbss.functions import get_wildcard, antsLinearRegCmd, antsNonLinearRegCmd, antsApplyTransformCmd, antsBetCmd
from ants_tbss.functions import get_wildcard, antsLinearRegCmd, antsNonLinearRegCmd, antsApplyTransformCmd, antsBetCmd, round_mask_transform

DESCRIPTION = "TBSS (FSL) implementation with ANTs and T1w registration to template."

Expand All @@ -28,6 +29,12 @@ def getArgumentParser(parser = argparse.ArgumentParser(description = DESCRIPTION
parser.add_argument("-ab","--runantsbet",
action = 'store_true',
help="Run ANTs bet on the T1w images. Use this if your T1 inputs are not already brain extracted. Arguably, ANTs does a must better brain extraction than FSL's bet or freesurfer's watershed.")
parser.add_argument("-nlws","--nonlinearwithinsubject",
action = 'store_true',
help="Run an additional non-linear transformation of the linear transformed B0 image to the native space T1w image. This is useful if EPI distortions are present.")
parser.add_argument("-f","--usefloat",
action = 'store_true',
help="Run ANTs bet on the T1w images. Use this if your T1 inputs are not already brain extracted. Arguably, ANTs does a must better brain extraction than FSL's bet or freesurfer's watershed.")

# settings
parser.add_argument("-t","--threshold",
Expand All @@ -53,11 +60,16 @@ def getArgumentParser(parser = argparse.ArgumentParser(description = DESCRIPTION
return parser

def run(opts):
# get time stamp
currentTime=int(time())

thresh = float(opts.threshold[0])
num_threads = int(opts.numthreads[0])
std_brain = opts.templateimage[0]
current_dir = os.getcwd()
float_flag = False
if opts.usefloat:
float_flag = True

assert "ANTSPATH" in os.environ, "The environment variable ANTSPATH must be declared."
ANTSPATH = os.environ['ANTSPATH']
Expand All @@ -79,15 +91,21 @@ def run(opts):
for T1w in T1_native_list:
t1_name = os.path.basename(T1w)[:-7]
print("Running antsBrainExtraction.sh on:\t%s" % t1_name)
os.system(antsBetCmd(num_threads, T1w, 'T1w_Brain/%s_' % t1_name))
with open("cmd_ants_bet_%d" % currentTime, "a") as cmd_ants_bet:
cmd_ants_bet.write("%s\n" % antsBetCmd(numthreads = 1, input_image = T1w, output_image_brain = 'T1w_Brain/%s_' % t1_name))
Betted_T1.append('T1w_Brain/%s_BrainExtractionBrain.nii.gz' % t1_name)
os.system("cat cmd_ants_bet_%d | parallel -j %d; rm cmd_ants_bet_%d" % (currentTime, num_threads, currentTime))
T1_native_list = np.array(Betted_T1, dtype=str)


print("Running registration of B0 images -> T1w images -> template T1 image")
os.system("mkdir -p reg")

assert len(B0_list) == len(T1_native_list), "The image lists are not of equal length."


## this need to be separated better registration and transformation! Two loops needed.

for i, b0 in enumerate(B0_list):
temp_transformation = {}
b0_name = os.path.basename(b0)[:-7]
Expand All @@ -99,60 +117,45 @@ def run(opts):
t1_mask_name = "reg/%s_mask.nii.gz" % t1_name
t1_mask_name_std = "reg/%s_mask_to_stdT1.nii.gz" % t1_name


ref = T1_native_list[i]
temp_transformation['nativeT1'] = T1_native_list[i]

mov = b0
log = "reg/lin%s_to_natT1.log" % b0_name
out = "reg/%s_to_natT1.nii.gz" % b0_name
out = "reg/lin%s_to_natT1.nii.gz" % b0_name
os.system("fslmaths %s -bin %s" % (mov, mask_name))

# linear reg of B0 to T1w native image
os.system("%s > %s" % (antsLinearRegCmd(int(num_threads), ref, mov, out), log))
temp_transformation['linB0_to_natT1'] = ['%s/%s_0GenericAffine.mat' % (current_dir, out)]

# move the mask
os.system(antsApplyTransformCmd(reference = ref, mov = mask_name, warps = ['%s_0GenericAffine.mat' % out], outname = mask_name_t1, outdir = None))
img = nib.load(mask_name_t1)
img_data = img.get_data()
img_data = np.round(img_data)
nib.save(nib.Nifti1Image(img_data,img.affine),mask_name_t1)
with open("cmd_linB0_to_natT1_%d" % currentTime, "a") as cmd_linB0_to_natT1:
cmd_linB0_to_natT1.write("%s > %s\n" % (antsLinearRegCmd(int(1), ref, mov, out, use_float = float_flag), log))
if opts.nonlinearwithinsubject:
mov = "reg/lin%s_to_natT1.nii.gz" % b0_name
log = "reg/%s_to_natT1.log" % b0_name
nl_out = "reg/%s_to_natT1.nii.gz" % b0_name
with open("cmd_B0_to_natT1_%d" % currentTime, "a") as cmd_B0_to_natT1:
cmd_B0_to_natT1.write("%s > %s\n" % (antsNonLinearRegCmd(int(1), ref, mov, nl_out, use_float = float_flag), log))
temp_transformation['B0_to_natT1'] = ['%s_0Warp.nii.gz' % (nl_out), '%s_0GenericAffine.mat' % (out)]
else:
temp_transformation['B0_to_natT1'] = ['%s/%s_0GenericAffine.mat' % (current_dir, out)]

# linear/non-linear reg T1 to std
# linear reg T1 to std
ref = std_brain
mov = T1_native_list[i]
log = "reg/lin%s_to_stdT1.log" % t1_name
out = "reg/lin%s_to_stdT1.nii.gz" % t1_name
os.system("%s > %s" % (antsLinearRegCmd(int(num_threads), ref, mov, out), log))

# get mask of T1
os.system("fslmaths %s -bin %s" % (mov, t1_mask_name))
with open("cmd_linT1_to_stdT1_%d" % currentTime, "a") as cmd_linT1_to_stdT1:
cmd_linT1_to_stdT1.write("%s > %s\n" % (antsLinearRegCmd(int(1), ref, mov, out, use_float = float_flag), log))

# non-linear
# non-linear reg T1 to std
mov = out
log = "reg/%s_to_stdT1.log" % t1_name
out = "reg/%s_to_stdT1.nii.gz" % t1_name
os.system("%s > %s" % (antsNonLinearRegCmd(int(num_threads), ref, mov, out), log))
with open("cmd_T1_to_stdT1_%d" % currentTime, "a") as cmd_T1_to_stdT1:
cmd_T1_to_stdT1.write("%s > %s\n" % (antsNonLinearRegCmd(int(1), ref, mov, out, use_float = float_flag), log))

# move FA mask
warps = ['%s_0Warp.nii.gz' % (out), 'reg/lin%s_to_stdT1.nii.gz_0GenericAffine.mat' % (t1_name)]
temp_transformation['T1to_stdT1'] = ['%s/%s_0Warp.nii.gz' % (current_dir, out), '%s/reg/lin%s_to_stdT1.nii.gz_0GenericAffine.mat' % (current_dir, t1_name)]



os.system(antsApplyTransformCmd(reference = ref, mov = mask_name_t1, warps = warps, outname = mask_name_std, outdir = None))
img = nib.load(mask_name_std)
img_data = img.get_data()
img_data = np.round(img_data)
nib.save(nib.Nifti1Image(img_data,img.affine),mask_name_std)
# move T1 mask
os.system(antsApplyTransformCmd(reference = ref, mov = t1_mask_name, warps = warps, outname = t1_mask_name_std, outdir = None))
img = nib.load(t1_mask_name_std)
img_data = img.get_data()
img_data = np.round(img_data)
nib.save(nib.Nifti1Image(img_data,img.affine),t1_mask_name_std)

mask_list.append(mask_name_std)
mask_list.append(t1_mask_name_std)

Expand All @@ -165,27 +168,93 @@ def run(opts):
warp_list.append('reg/%s_warps.json' % b0_name)

warp_list = np.array(warp_list, dtype = str)
os.system("cat cmd_linB0_to_natT1_%d | parallel -j %d; rm cmd_linB0_to_natT1_%d" % (currentTime, num_threads, currentTime))
if opts.nonlinearwithinsubject:
os.system("cat cmd_B0_to_natT1_%d | parallel -j %d; rm cmd_B0_to_natT1_%d" % (currentTime, num_threads, currentTime))
os.system("cat cmd_linT1_to_stdT1_%d | parallel -j %d; rm cmd_linT1_to_stdT1_%d" % (currentTime, num_threads, currentTime))
os.system("cat cmd_T1_to_stdT1_%d | parallel -j %d; rm cmd_T1_to_stdT1_%d" % (currentTime, num_threads, currentTime))

# read the information for the warps
for i, b0 in enumerate(B0_list):
B0_to_natT1 = []
T1to_stdT1 = []
std_masks = []
T1_ref = []
for json_warp_file in warp_list:
with open(json_warp_file) as json_file:
transform_files = json.load(json_file)
B0_to_natT1.append(transform_files['B0_to_natT1'])
T1to_stdT1.append(transform_files['T1to_stdT1'])
std_masks.append(transform_files['mask_name_std'])
std_masks.append(transform_files['t1_mask_name_std'])
T1_ref.append(transform_files['nativeT1'])

# build the masks
for i, b0 in enumerate(B0_list):
# set names again ... fix this laziness later
b0_name = os.path.basename(b0)[:-7]
t1_name = os.path.basename(T1_native_list[i])[:-7]
mask_name = "reg/%s_mask.nii.gz" % b0_name
mask_name_t1 = "reg/%s_mask_to_natT1.nii.gz" % b0_name
mask_name_std = "reg/%s_mask_to_stdT1.nii.gz" % b0_name
t1_mask_name = "reg/%s_mask.nii.gz" % t1_name
t1_mask_name_std = "reg/%s_mask_to_stdT1.nii.gz" % t1_name

# move the B0 mask to ref
ref = T1_ref[i]
mov = mask_name
warps = B0_to_natT1[i]
outname = mask_name_t1
os.system(antsApplyTransformCmd(reference = ref,
mov = mov,
warps = warps,
outname = outname))
# round image
round_mask_transform(mask_name_t1)

# ref B0 mask to template
os.system("fslmaths %s -bin %s" % (mov, t1_mask_name))
ref = std_brain
mov = mask_name_t1
warps = T1to_stdT1[i]
outname = mask_name_std
os.system(antsApplyTransformCmd(reference = ref,
mov = mov,
warps = warps,
outname = outname))
# round image
round_mask_transform(mask_name_std)

# move T1 mask
ref = std_brain
mov = t1_mask_name
warps = T1to_stdT1[i]
outname = t1_mask_name_std
os.system(antsApplyTransformCmd(reference = ref,
mov = mov,
warps = warps,
outname = outname))
round_mask_transform(t1_mask_name_std)

if opts.jsontransformlist:
warp_list = np.genfromtxt(opts.jsontransformlist[0], dtype=str)

if opts.runtbss:

# Check runtbss inputs
assert len(opts.runtbss) % 2 == 0, "--runtbss must have an even number of inputs. e.g., -r FA_list FA MD_list MD."

# Check warp list length.
assert len(warp_list) == len(np.genfromtxt(opts.runtbss[0], dtype=str)), "The length the --runtbss inputs must match the number of warps."

linB0_to_natT1 = []
B0_to_natT1 = []
T1to_stdT1 = []
std_masks = []
T1_ref = []
# read the json files containing the warps and mask files
for json_warp_file in warp_list:
with open(json_warp_file) as json_file:
transform_files = json.load(json_file)
linB0_to_natT1.append(transform_files['linB0_to_natT1'])
B0_to_natT1.append(transform_files['B0_to_natT1'])
T1to_stdT1.append(transform_files['T1to_stdT1'])
std_masks.append(transform_files['mask_name_std'])
std_masks.append(transform_files['t1_mask_name_std'])
Expand All @@ -211,7 +280,7 @@ def run(opts):
# metric to native
ref = T1_ref[i]
mov = metric_img
warps = linB0_to_natT1[i]
warps = B0_to_natT1[i]
os.system(antsApplyTransformCmd(reference = ref, mov = mov, warps = warps, outname = metric_nat))

# metric to template
Expand Down

0 comments on commit d254f93

Please sign in to comment.