Skip to content

Commit

Permalink
Fixes in connectivity_from_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Feb 1, 2024
1 parent c24e984 commit 3d10144
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 30 deletions.
62 changes: 41 additions & 21 deletions dwi_ml/data/processing/streamlines/post_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import numpy as np
Expand Down Expand Up @@ -320,43 +321,62 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,
compressed streamlines.'
Else, uses simple computation from endpoints. Faster. Also, works with
incomplete parcellation.
Returns
-------
matrix: np.ndarray
With use_scilpy: shape (nb_labels + 1, nb_labels + 1)
(last label is "Not Found")
Else, shape (nb_labels, nb_labels)
labels: List
The list of labels
"""
real_labels = np.unique(data_labels)[1:]
real_labels = list(np.sort(np.unique(data_labels)))
nb_labels = len(real_labels)
matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int)
logging.debug("Computing connectivity matrix for {} labels."
.format(nb_labels))

start_blocs = []
end_blocs = []
if use_scilpy:
matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int)
else:
matrix = np.zeros((nb_labels, nb_labels), dtype=int)

start_labels = []
end_labels = []

if use_scilpy:
indices, points_to_idx = uncompress(streamlines, return_mapping=True)

for strl_vox_indices in indices:
segments_info = segmenting_func(strl_vox_indices, data_labels)
if len(segments_info) > 0:
start = segments_info[0]['start_label']
end = segments_info[0]['end_label']
start_blocs.append(start)
end_blocs.append(end)
start = real_labels.index(segments_info[0]['start_label'])
end = real_labels.index(segments_info[0]['end_label'])
else:
start = nb_labels
end = nb_labels

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1
start_labels.append(start)
end_labels.append(end)

matrix[start, end] += 1
if start != end:
matrix[end, start] += 1

real_labels = real_labels + [np.NaN]

else:
# Putting it in 0,0, we will remember that this means 'other'
matrix[0, 0] += 1
start_blocs.append(0)
end_blocs.append(0)
else:
for s in streamlines:
# Vox space, corner origin
# = we can get the nearest neighbor easily.
# Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1.
start = data_labels[tuple(np.floor(s[0, :]).astype(int))]
end = data_labels[tuple(np.floor(s[-1, :]).astype(int))]
start_blocs.append(start)
end_blocs.append(end)
start = real_labels.index(
data_labels[tuple(np.floor(s[0, :]).astype(int))])
end = real_labels.index(
data_labels[tuple(np.floor(s[-1, :]).astype(int))])

start_labels.append(start)
end_labels.append(end)
matrix[start, end] += 1
if start != end:
matrix[end, start] += 1
Expand All @@ -367,7 +387,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels,
if binary:
matrix = matrix.astype(bool)

return matrix, start_blocs, end_blocs
return matrix, real_labels, start_labels, end_labels


def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs,
Expand Down
47 changes: 38 additions & 9 deletions scripts_python/dwiml_compute_connectivity_matrix_from_labels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Computes the connectivity matrix.
Labels associated with each line / row will be printed.
"""

import argparse
import logging
import os.path
Expand Down Expand Up @@ -65,10 +71,16 @@ def main():
args = p.parse_args()

if args.verbose:
# Currenlty, with debug, matplotlib prints a lot of stuff. Why??
logging.getLogger().setLevel(logging.INFO)

tmp, ext = os.path.splitext(args.out_file)

if ext != '.npy':
p.error("--out_file should have a .npy extension.")

out_fig = tmp + '.png'
out_ordered_labels = tmp + '_labels.txt'
assert_inputs_exist(p, [args.in_labels, args.streamlines])
assert_outputs_exist(p, args, [args.out_file, out_fig],
[args.save_biggest, args.save_smallest])
Expand All @@ -80,26 +92,36 @@ def main():
p.error("Streamlines not compatible with chosen volume.")
else:
args.reference = args.in_labels

logging.info("Loading tractogram.")
in_sft = load_tractogram_with_reference(p, args, args.streamlines)
in_img = nib.load(args.in_labels)
data_labels = get_data_as_labels(in_img)

in_sft.to_vox()
in_sft.to_corner()
matrix, start_blocs, end_blocs = compute_triu_connectivity_from_labels(
in_sft.streamlines, data_labels,
use_scilpy=args.use_longest_segment)
matrix, ordered_labels, start_blocs, end_blocs = \
compute_triu_connectivity_from_labels(
in_sft.streamlines, data_labels,
use_scilpy=args.use_longest_segment)

if args.hide_background is not None:
matrix[args.hide_background, :] = 0
matrix[:, args.hide_background] = 0
idx = ordered_labels.idx(args.hide_background)
matrix[idx, :] = 0
matrix[:, idx] = 0
ordered_labels[idx] = ("Hidden background ({})"
.format(args.hide_background))

logging.info("Labels are, in order: {}".format(ordered_labels))

# Options to try to investigate the connectivity matrix:
# masking point (0,0) = streamline ending in wm.
if args.save_biggest is not None:
i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape)
print("Saving biggest bundle: {} streamlines. From label {} to label "
"{}".format(matrix[i, j], i, j))
"{} (line {}, column {} in the matrix)"
.format(matrix[i, j], ordered_labels[i], ordered_labels[j],
i, j))
biggest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
sft = in_sft.from_sft(biggest, in_sft)
Expand All @@ -109,15 +131,22 @@ def main():
tmp_matrix = np.ma.masked_equal(matrix, 0)
i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape)
print("Saving smallest bundle: {} streamlines. From label {} to label "
"{}".format(matrix[i, j], i, j))
biggest = find_streamlines_with_chosen_connectivity(
"{} (line {}, column {} in the matrix)"
.format(matrix[i, j], ordered_labels[i], ordered_labels[j],
i, j))
smallest = find_streamlines_with_chosen_connectivity(
in_sft.streamlines, i, j, start_blocs, end_blocs)
sft = in_sft.from_sft(biggest, in_sft)
sft = in_sft.from_sft(smallest, in_sft)
save_tractogram(sft, args.save_smallest)

ordered_labels = str(ordered_labels)
with open(out_ordered_labels, "w") as text_file:
text_file.write(ordered_labels)

if args.show_now:
plt.imshow(matrix)
plt.colorbar()
plt.title("Raw streamline count")

plt.figure()
plt.imshow(matrix > 0)
Expand Down

0 comments on commit 3d10144

Please sign in to comment.