Merge pull request #210 from EmmaRenauld/manage_connectivity
Manage connectivity
EmmaRenauld committed Oct 27, 2023
2 parents aabb039 + ff5f594 commit 7120d67
Showing 16 changed files with 703 additions and 221 deletions.
20 changes: 10 additions & 10 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -165,7 +165,7 @@ def __init__(self,
streamlines: Union[ArraySequence, _LazyStreamlinesGetter],
space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
downsampled_size_for_connectivity: List):
connectivity_nb_blocs: List):
"""
Params
------
@@ -190,7 +190,7 @@ def __init__(self,
self.streamlines = streamlines
self.is_lazy = None
self.contains_connectivity = contains_connectivity
self.downsampled_size_for_connectivity = downsampled_size_for_connectivity
self.connectivity_nb_blocs = connectivity_nb_blocs

@property
def lengths(self):
@@ -212,15 +212,15 @@ def connectivity_matrix_and_info(self, ind=None):
"""New method compared to SFTs: access pre-computed connectivity
matrix. Returns the subject's connectivity matrix associated with
current tractogram, together with information required to recompute
a similar matrix: reference volume's shape and downsampled shape."""
a similar matrix: reference volume's shape and number of blocs."""
if not self.contains_connectivity:
raise ValueError("No pre-computed connectivity matrix found for "
"this subject.")

(_, ref_volume_shape, _, _) = self.space_attributes

return (self._access_connectivity_matrix(ind), ref_volume_shape,
self.downsampled_size_for_connectivity)
self.connectivity_nb_blocs)

def _access_connectivity_matrix(self, ind):
raise NotImplementedError
@@ -289,11 +289,11 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
contains_connectivity = True
connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'],
dtype=int)
downsampled_size = hdf_group.attrs['downsampled_size']
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
else:
contains_connectivity = False
connectivity_matrix = None
downsampled_size = None
connectivity_nb_blocs = None

space_attributes, space, origin = _load_space_from_hdf(hdf_group)

@@ -303,7 +303,7 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
streamlines=streamlines, space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
downsampled_size_for_connectivity=downsampled_size)
connectivity_nb_blocs=connectivity_nb_blocs)

def _subset_streamlines(self, streamline_ids):
if streamline_ids is not None:
@@ -334,17 +334,17 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_from_hdf(hdf_group)
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
downsampled_size = hdf_group.attrs['downsampled_size']
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
else:
contains_connectivity = False
downsampled_size = None
connectivity_nb_blocs = None

streamlines = _LazyStreamlinesGetter(hdf_group)

return cls(streamlines=streamlines, space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
downsampled_size_for_connectivity=downsampled_size)
connectivity_nb_blocs=connectivity_nb_blocs)

def _subset_streamlines(self, streamline_ids):
streamlines = self.streamlines.get_array_sequence(streamline_ids)
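
A minimal usage sketch of the renamed connectivity information (not part of this diff; the variable names are hypothetical, and `sft_data` stands for any SFTData or LazySFTData instance built by init_sft_data_from_hdf_info):

# Raises ValueError if no pre-computed matrix was stored for this subject.
matrix, ref_volume_shape, nb_blocs = sft_data.connectivity_matrix_and_info()
# nb_blocs is e.g. [20, 20, 20]: the reference volume is divided into
# 20 x 20 x 20 blocs, and the matrix has one row/column per bloc.
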
96 changes: 54 additions & 42 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -3,14 +3,15 @@
import logging
import os
from pathlib import Path
from typing import List, Union
from typing import List

from dipy.io.stateful_tractogram import set_sft_logger_level, Space
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.io.utils import is_header_compatible
from dipy.tracking.utils import length
import h5py

from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity
from dwi_ml.data.processing.streamlines.data_augmentation import \
resample_or_compress
from nested_lookup import nested_lookup
@@ -21,8 +21,6 @@

from dwi_ml.data.io import load_file_to4d
from dwi_ml.data.processing.dwi.dwi import standardize_data
from dwi_ml.data.processing.streamlines.post_processing import \
compute_triu_connectivity


def _load_and_verify_file(filename: str, subj_input_path, group_name: str,
@@ -110,8 +109,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
testing_subjs: List[str], groups_config: dict,
std_mask: str, step_size: float = None,
compress: float = None,
compute_connectivity_matrix: bool = False,
downsampled_size_for_connectivity: Union[int, list] = 20,
enforce_files_presence: bool = True,
save_intermediate: bool = False,
intermediate_folder: Path = None):
@@ -135,11 +132,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
Step size to resample streamlines. Default: None.
compress: float
Compress streamlines. Default: None.
compute_connectivity_matrix: bool
Compute connectivity matrix for each streamline group.
Default: False.
downsampled_size_for_connectivity: int or List
See compute_connectivity_matrix's doc.
enforce_files_presence: bool
If true, will stop if some files are not available for a subject.
Default: True.
@@ -158,20 +150,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.groups_config = groups_config
self.step_size = step_size
self.compress = compress
self.compute_connectivity = compute_connectivity_matrix
if self.compute_connectivity:
if isinstance(downsampled_size_for_connectivity, List):
assert len(downsampled_size_for_connectivity) == 3, \
"Expecting to work with 3D volumes. Expecting " \
"connectivity downsample size to be a list of 3 values, " \
"but got {}.".format(downsampled_size_for_connectivity)
self.connectivity_downsample_size = downsampled_size_for_connectivity
else:
assert isinstance(downsampled_size_for_connectivity, int), \
"Expecting the connectivity matrix size to be either " \
"a 3D list or an integer, but got {}" \
.format(downsampled_size_for_connectivity)
self.connectivity_downsample_size = [downsampled_size_for_connectivity] * 3

# Optional
self.std_mask = std_mask # (could be None)
@@ -313,7 +291,10 @@ def _check_files_presence(self):
logging.debug("Verifying files presence")

# concatenating files from all groups files:
# sum: concatenates list of sub-lists
config_file_list = sum(nested_lookup('files', self.groups_config), [])
config_file_list += nested_lookup(
'connectivity_matrix', self.groups_config)

for subj_id in self.all_subjs:
subj_input_dir = Path(self.root_folder).joinpath(subj_id)
@@ -564,8 +545,9 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
"in the config_file. If all files are .trk, we can use "
"ref 'same' but if some files were .tck, we need a ref!"
"Hint: Create a volume group 'ref' in the config file.")
sft, lengths = self._process_one_streamline_group(
subj_input_dir, group, subj_id, ref)
sft, lengths, connectivity_matrix, conn_info = (
self._process_one_streamline_group(
subj_input_dir, group, subj_id, ref))

streamlines_group = subj_hdf_group.create_group(group)
streamlines_group.attrs['type'] = 'streamlines'
@@ -581,9 +563,17 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
streamlines_group.attrs['dimensions'] = d
streamlines_group.attrs['voxel_sizes'] = vs
streamlines_group.attrs['voxel_order'] = vo
if self.compute_connectivity:
streamlines_group.attrs['downsampled_size'] = \
self.connectivity_downsample_size
if connectivity_matrix is not None:
streamlines_group.attrs[
'connectivity_matrix_type'] = conn_info[0]
streamlines_group.create_dataset(
'connectivity_matrix', data=connectivity_matrix)
if conn_info[0] == 'from_labels':
streamlines_group.attrs['connectivity_labels_volume'] = \
conn_info[1]
else:
streamlines_group.attrs['connectivity_nb_blocs'] = \
conn_info[1]

if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')
@@ -601,18 +591,6 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
data=sft.streamlines._lengths)
streamlines_group.create_dataset('euclidean_lengths', data=lengths)

if self.compute_connectivity:
# Can be reduced using sparse tensors notation... always
# minimum 50% of zeros! Then we could save separately
# the indices, values, size of the tensor. But unclear how
# much sparse they need to be to actually save memory.
# Skipping for now.
streamlines_group.create_dataset(
'connectivity_matrix',
data=compute_triu_connectivity(
sft.streamlines, d, self.connectivity_downsample_size,
binary=True, to_sparse_tensor=False))

def _process_one_streamline_group(
self, subj_dir: Path, group: str, subj_id: str,
header: nib.Nifti1Header):
@@ -710,7 +688,41 @@ def _process_one_streamline_group(
logging.debug(" *Remaining: {:,.0f} streamlines."
"".format(len(final_sft)))

return final_sft, output_lengths
conn_matrix = None
conn_info = None
if 'connectivity_matrix' in self.groups_config[group]:
logging.info(" Now preparing connectivity matrix")
if not ("connectivty_nb_blocs" in self.groups_config[group] or
"connectivty_labels" in self.groups_config[group]):
raise ValueError(
"The config file must provide either the "
"connectivty_nb_blocs or the connectivty_labels information "
"associated with the streamline group '{}'"
.format(group))
elif ("connectivty_nb_blocs" in self.groups_config[group] and
"connectivty_labels" in self.groups_config[group]):
raise ValueError(
"The config file must only provide ONE of the "
"connectivty_nb_blocs or the connectivty_labels information "
"associated with the streamline group '{}'"
.format(group))
elif "connectivty_nb_blocs" in self.groups_config[group]:
nb_blocs = format_nb_blocs_connectivity(
self.groups_config[group]['connectivty_nb_blocs'])
conn_info = ['from_blocs', nb_blocs]
else:
labels = self.groups_config[group]['connectivty_labels']
if labels not in self.volume_groups:
raise ValueError("connectivity_labels_volume must be "
"an existing volume group.")
conn_info = ['from_labels', labels]

conn_file = subj_dir.joinpath(
self.groups_config[group]['connectivity_matrix'])
conn_matrix = np.load(conn_file)
conn_matrix = conn_matrix > 0

return final_sft, output_lengths, conn_matrix, conn_info

def _load_and_process_sft(self, tractogram_file, tractogram_name, header):
if not tractogram_file.is_file():
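
A sketch of the streamline-group entry this new code expects in the JSON config file, written here as the equivalent Python dict (file names, and any key other than 'connectivity_matrix', 'connectivity_nb_blocs' and 'connectivity_labels', are assumptions for illustration):

groups_config = {
    "streamlines": {
        "type": "streamlines",
        "files": ["tractograms/whole_brain.trk"],
        # Pre-computed matrix, one file per subject (path relative to the
        # subject's folder); loaded with np.load() and binarized (> 0):
        "connectivity_matrix": "connectivity/matrix.npy",
        # Exactly one of the two options below:
        "connectivity_nb_blocs": 20,
        # "connectivity_labels": "wm_labels",  # name of an existing volume group
    }
}
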
104 changes: 25 additions & 79 deletions dwi_ml/data/hdf5/utils.py
@@ -1,16 +1,34 @@
# -*- coding: utf-8 -*-
import datetime
import shutil
from argparse import ArgumentParser
import json
import logging
import os
from pathlib import Path
from typing import List

from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator
from dwi_ml.io_utils import add_resample_or_compress_arg


def format_nb_blocs_connectivity(connectivity_nb_blocs) -> List:
if connectivity_nb_blocs is None:
# Default/const value with argparser '+' not possible.
# Setting it manually.
connectivity_nb_blocs = 20
elif (isinstance(connectivity_nb_blocs, list) and
len(connectivity_nb_blocs) == 1):
connectivity_nb_blocs = connectivity_nb_blocs[0]

if isinstance(connectivity_nb_blocs, List):
assert len(connectivity_nb_blocs) == 3, \
"Expecting to work with 3D volumes. Expecting " \
"connectivity_nb_blocs to be a list of 3 values, " \
"but got {}.".format(connectivity_nb_blocs)
else:
assert isinstance(connectivity_nb_blocs, int), \
"Expecting the connectivity_nb_blocs to be either " \
"a 3D list or an integer, but got {}" \
.format(connectivity_nb_blocs)
connectivity_nb_blocs = [connectivity_nb_blocs] * 3

return connectivity_nb_blocs


def add_hdf5_creation_args(p: ArgumentParser):

# Positional arguments
@@ -73,75 +91,3 @@ def add_mri_processing_args(p: ArgumentParser):
def add_streamline_processing_args(p: ArgumentParser):
g = p.add_argument_group('Streamlines processing options:')
add_resample_or_compress_arg(g)
g.add_argument(
'--compute_connectivity_matrix', action='store_true',
help="If set, computes the 3D connectivity matrix for each streamline "
"group. \nDefined from downsampled image, not from anatomy! \n"
"Ex: can be used at validation time with our trainer's "
"'generation-validation' step.")
g.add_argument(
'--connectivity_downsample_size', metavar='m', type=int, nargs='+',
help="Number of 3D blocks (m x m x m) for the connectivity matrix. \n"
"(The matrix will be m^3 x m^3). If more than one values are "
"provided, expected to be one per dimension. \n"
"Default: 20x20x20.")


def _initialize_intermediate_subdir(hdf5_file, save_intermediate):
# Create hdf5 dir or clean existing one
hdf5_folder = os.path.dirname(hdf5_file)

# Preparing intermediate folder.
if save_intermediate:
now = datetime.datetime.now().strftime("%Y_%m_%d_%H%M%S")
intermediate_subdir = Path(hdf5_folder, "intermediate_" + now)
logging.debug(" Creating intermediate files directory")
intermediate_subdir.mkdir()

return intermediate_subdir
return None


def prepare_hdf5_creator(args):
"""
Reads the config file and subjects lists files and instantiate a class of
the HDF5Creator.
"""
# Read subjects lists
with open(args.training_subjs, 'r') as file:
training_subjs = file.read().split()
logging.debug(' Training subjs: {}'.format(training_subjs))
with open(args.validation_subjs, 'r') as file:
validation_subjs = file.read().split()
logging.debug(' Validation subjs: {}'.format(validation_subjs))
with open(args.testing_subjs, 'r') as file:
testing_subjs = file.read().split()
logging.debug(' Testing subjs: {}'.format(testing_subjs))

# Read group information from the json file (config file)
with open(args.config_file, 'r') as json_file:
groups_config = json.load(json_file)

# Delete existing hdf5, if -f
if args.overwrite and os.path.exists(args.out_hdf5_file):
os.remove(args.out_hdf5_file)

# Initialize intermediate subdir
intermediate_subdir = _initialize_intermediate_subdir(
args.out_hdf5_file, args.save_intermediate)

# Copy config file locally
config_copy_name = os.path.splitext(args.out_hdf5_file)[0] + '.json'
logging.info("Copying json config file to {}".format(config_copy_name))
shutil.copyfile(args.config_file, config_copy_name)

# Instantiate a creator and perform checks
creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
training_subjs, validation_subjs, testing_subjs,
groups_config, args.std_mask, args.step_size,
args.compress, args.compute_connectivity_matrix,
args.connectivity_downsample_size,
args.enforce_files_presence,
args.save_intermediate, intermediate_subdir)

return creator
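
As a quick illustration of the new helper's behaviour (a sketch derived from the code above, not part of the diff):

from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity

format_nb_blocs_connectivity(None)          # -> [20, 20, 20] (default)
format_nb_blocs_connectivity(15)            # -> [15, 15, 15]
format_nb_blocs_connectivity([15])          # -> [15, 15, 15]
format_nb_blocs_connectivity([10, 12, 14])  # -> [10, 12, 14]
format_nb_blocs_connectivity([10, 12])      # raises AssertionError (3 values expected)
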