Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add mrds scripts part 3 #1046

Merged
merged 14 commits into from
Nov 13, 2024
130 changes: 130 additions & 0 deletions scilpy/tractanalysis/mrds_along_streamlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-

import numpy as np

from scilpy.tractanalysis.grid_intersections import grid_intersections


def mrds_metrics_along_streamlines(sft, mrds_pdds,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this method is not mrds-related. It's just dividing data.

It could be, in the main script:

 mrds_sum, weights = \
        mrds_metric_sums_along_streamlines(sft, mrds_pdds,
                                           metrics, max_theta,
                                           length_weighting)
weighte_mrds = weight_values(mrds_sum, weights)

Where weight values is function that divides by the weight where non-zero, somewhere in volume tools or something.

metrics, max_theta,
length_weighting):
"""
Compute mean map for a given fixel-specific metric along streamlines.

Parameters
----------
sft : StatefulTractogram
StatefulTractogram containing the streamlines needed.
mrds_pdds : ndarray (X, Y, Z, 3*N_TENSORS)
MRDS principal diffusion directions of the tensors
metrics : list of ndarray
Array of shape (X, Y, Z, N_TENSORS) containing the fixel-specific
metric of interest.
max_theta : float
Maximum angle in degrees between the fiber direction and the
MRDS principal diffusion direction.
length_weighting : bool
If True, will weigh the metric values according to segment lengths.
"""

mrds_sum, weights = \
mrds_metric_sums_along_streamlines(sft, mrds_pdds,
metrics, max_theta,
length_weighting)

all_metric = mrds_sum[0]
for curr_metric in mrds_sum[1:]:
all_metric += np.abs(curr_metric)

non_zeros = np.nonzero(all_metric)
weights_nz = weights[non_zeros]
for metric_idx in range(len(metrics)):
mrds_sum[metric_idx][non_zeros] /= weights_nz

return mrds_sum


def mrds_metric_sums_along_streamlines(sft, mrds_pdds, metrics,
max_theta, length_weighting):
"""
Compute a sum map along a bundle for a given fixel-specific metric.

Parameters
----------
sft : StatefulTractogram
StatefulTractogram containing the streamlines needed.
mrds_pdds : ndarray (X, Y, Z, 3*N_TENSORS)
MRDS principal diffusion directions (PDDs) of the tensors
metrics : list of ndarray (X, Y, Z, N_TENSORS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think metrics should be metric

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I aggree it's "a" list but reading it I feel like metrics make sense here.

Fixel-specific metrics.
max_theta : float
Maximum angle in degrees between the fiber direction and the
MRDS principal diffusion direction.
length_weighting : bool
If True, will weight the metric values according to segment lengths.

Returns
-------
metric_sum_map : np.array
fixel-specific metrics sum map.
weight_map : np.array
Segment lengths.
"""

sft.to_vox()
sft.to_corner()

X, Y, Z = metrics[0].shape[0:3]
metrics_sum_map = np.zeros((len(metrics), X, Y, Z))
weight_map = np.zeros(metrics[0].shape[:-1])
min_cos_theta = np.cos(np.radians(max_theta))

all_crossed_indices = grid_intersections(sft.streamlines)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am pretty sure that our "uncompress" method should be used somewhere here. @frheault probably knows this better than me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, this version is copy paste (kinda) of afd_along_streamlinespy

for crossed_indices in all_crossed_indices:
segments = crossed_indices[1:] - crossed_indices[:-1]
seg_lengths = np.linalg.norm(segments, axis=1)

# Remove points where the segment is zero.
# This removes numpy warnings of division by zero.
non_zero_lengths = np.nonzero(seg_lengths)[0]
segments = segments[non_zero_lengths]
seg_lengths = seg_lengths[non_zero_lengths]

# Those starting points are used for the segment vox_idx computations
seg_start = crossed_indices[non_zero_lengths]
vox_indices = (seg_start + (0.5 * segments)).astype(int)

normalization_weights = np.ones_like(seg_lengths)
if length_weighting:
normalization_weights = seg_lengths

normalized_seg = np.reshape(segments / seg_lengths[..., None], (-1, 3))

# Reshape MRDS PDDs
mrds_pdds = mrds_pdds.reshape(mrds_pdds.shape[0],
mrds_pdds.shape[1],
mrds_pdds.shape[2], -1, 3)

for vox_idx, seg_dir, norm_weight in zip(vox_indices,
normalized_seg,
normalization_weights):
vox_idx = tuple(vox_idx)

mrds_peak_dir = mrds_pdds[vox_idx]

cos_theta = np.abs(np.dot(seg_dir.reshape((-1, 3)),
mrds_peak_dir.T))

metric_val = [0.0]*len(metrics)
if (cos_theta > min_cos_theta).any():
fixel_idx = np.argmax(np.squeeze(cos_theta),
axis=0) # (n_segs)

for metric_idx, curr_metric in enumerate(metrics):
metric_val[metric_idx] = curr_metric[vox_idx][fixel_idx]

for metric_idx, curr_metric in enumerate(metrics):
metrics_sum_map[metric_idx][vox_idx] += metric_val[metric_idx] * norm_weight
weight_map[vox_idx] += norm_weight

return metrics_sum_map, weight_map
127 changes: 127 additions & 0 deletions scripts/scil_bundle_mean_fixel_mrds_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Given a bundle and MRDS metrics, compute the fixel-specific
metrics at each voxel intersected by the bundle. Intersected voxels are
found by computing the intersection between the voxel grid and each streamline
in the input tractogram.

This script behaves like scil_bundle_mean_fixel_afd.py for fODFs,
but here for MRDS metrics. These latest distributions add the unique
possibility to capture fixel-based fractional anisotropy (fixel-FA), mean
diffusivity (fixel-MD), radial diffusivity (fixel-RD) and
axial diffusivity (fixel-AD).

Fixel-specific metrics are metrics extracted from
Multi-Resolution Discrete-Search (MRDS) solutions.
There are as many values per voxel as there are fixels extracted. The
values chosen for a given voxel is the one belonging to the lobe better aligned
with the current streamline segment.

Input files come from scil_mrds_metrics.py command.

Output metrics will be named: [prefix]_mrds_[metric_name].nii.gz

arnaudbore marked this conversation as resolved.
Show resolved Hide resolved
Please use a bundle file rather than a whole tractogram.
"""

import argparse

import nibabel as nib
import numpy as np

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
add_reference_arg,
assert_headers_compatible,
assert_inputs_exist, assert_outputs_exist)
from scilpy.tractanalysis.mrds_along_streamlines \
import mrds_metrics_along_streamlines


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)
p.add_argument('in_bundle',
help='Path of the bundle file.')
p.add_argument('in_pdds',
help='Path of the MRDS PDDs volume.')

g = p.add_argument_group(title='MRDS metrics input')
g.add_argument('--fa',
arnaudbore marked this conversation as resolved.
Show resolved Hide resolved
help='Path of the fixel-specific metric FA volume.')
g.add_argument('--md',
help='Path of the fixel-specific metric MD volume.')
g.add_argument('--rd',
help='Path of the fixel-specific metric RD volume.')
g.add_argument('--ad',
help='Path of the fixel-specific metric AD volume.')

p.add_argument('--prefix', default='result',
help='Prefix of the MRDS fixel results.')

p.add_argument('--length_weighting', action='store_true',
help='If set, will weight the values according to '
'segment lengths. [%(default)s]')

p.add_argument('--max_theta', default=60, type=float,
help='Maximum angle (in degrees) condition on fixel '
'alignment. [%(default)s]')

add_reference_arg(p)
add_overwrite_arg(p)
return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()

in_metrics = []
out_metrics = []
if args.fa is not None:
in_metrics.append(args.fa)
out_metrics.append('{}_mrds_fFA.nii.gz'.format(args.prefix))
if args.ad is not None:
in_metrics.append(args.ad)
out_metrics.append('{}_mrds_fAD.nii.gz'.format(args.prefix))
if args.rd is not None:
in_metrics.append(args.rd)
out_metrics.append('{}_mrds_fRD.nii.gz'.format(args.prefix))
if args.md is not None:
in_metrics.append(args.md)
out_metrics.append('{}_mrds_fMD.nii.gz'.format(args.prefix))

if in_metrics == []:
parser.error('At least one metric is required.')

assert_inputs_exist(parser, [args.in_bundle,
args.in_pdds], in_metrics)
assert_headers_compatible(parser, [args.in_bundle, args.in_pdds],
in_metrics)

assert_outputs_exist(parser, args, out_metrics)

sft = load_tractogram_with_reference(parser, args, args.in_bundle)
pdds_img = nib.load(args.in_pdds)
affine = pdds_img.affine
header = pdds_img.header

in_metrics_data = [nib.load(metric).get_fdata(dtype=np.float32) for metric in in_metrics]
fixel_metrics =\
mrds_metrics_along_streamlines(sft,
pdds_img.get_fdata(dtype=np.float32),
in_metrics_data,
args.max_theta,
args.length_weighting)

for metric_id, curr_metric in enumerate(fixel_metrics):
nib.Nifti1Image(curr_metric.astype(np.float32),
affine=affine,
header=header,
dtype=np.float32).to_filename(out_metrics[metric_id])


if __name__ == '__main__':
main()
8 changes: 8 additions & 0 deletions scripts/tests/test_bundle_mean_fixel_mrds_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

def test_help_option(script_runner):
ret = script_runner.run(
'scil_bundle_mean_fixel_mrds_metric.py', '--help')

assert ret.success
Empty file modified scripts/tests/test_gradients_apply_transform.py
100755 → 100644
Empty file.
Empty file modified scripts/tests/test_sh_to_sf.py
100755 → 100644
Empty file.