Skip to content

Commit

Permalink
Merge pull request #222 from EmmaRenauld/changes_biggest_database
Browse files Browse the repository at this point in the history
Update and fixes in dwiml_create_hdf5 and associated doc
  • Loading branch information
EmmaRenauld authored Mar 4, 2024
2 parents de48e0a + c1ec189 commit bf8d2aa
Show file tree
Hide file tree
Showing 30 changed files with 746 additions and 705 deletions.
4 changes: 2 additions & 2 deletions bash_utilities/scil_score_ismrm_Renauld2023.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ then
fi

echo '------------- FINAL SCORING ------------'
scil_score_bundles.py -v $config_file_tractometry $out_dir \
--gt_dir $scoring_data --reference $ref --no_bbox_check
scil_bundle_score_many_bundles_one_tractogram.py $config_file_tractometry $out_dir \
--gt_dir $scoring_data --reference $ref --no_bbox_check -v

cat $out_dir/results.json
84 changes: 58 additions & 26 deletions dwi_ml/data/dataset/streamline_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group):
return streamlines


def _load_connectivity_info(hdf_group: h5py.Group):
    """
    Read, from a hdf5 group, whether it contains a connectivity matrix and
    which information is stored on how that matrix was created.

    Parameters
    ----------
    hdf_group: h5py.Group
        The hdf5 group to inspect.

    Returns
    -------
    contains_connectivity: bool
        True if the key 'connectivity_matrix' is in the group.
    connectivity_nb_blocs:
        The value of the group's 'connectivity_nb_blocs' attribute if it is
        set. Else, None.
    connectivity_labels: np.ndarray
        The content of 'connectivity_label_volume', loaded as an array of
        ints. Only loaded if 'connectivity_nb_blocs' is not set. Else, None.

    Raises
    ------
    ValueError
        If the group contains a connectivity matrix but neither
        'connectivity_nb_blocs' nor 'connectivity_label_volume'.
    """
    if 'connectivity_matrix' not in hdf_group:
        return False, None, None

    # The number of blocs has priority: if it is set, the label volume is
    # not loaded at all.
    if 'connectivity_nb_blocs' in hdf_group.attrs:
        return True, hdf_group.attrs['connectivity_nb_blocs'], None

    if 'connectivity_label_volume' in hdf_group:
        connectivity_labels = np.asarray(
            hdf_group['connectivity_label_volume'], dtype=int)
        return True, None, connectivity_labels

    # The message names the keys that are actually searched above.
    raise ValueError(
        "Information stored in the hdf5 is that it contains a "
        "connectivity matrix, but we don't know how it was "
        "created. Either 'connectivity_nb_blocs' or "
        "'connectivity_label_volume' should be set.")


class _LazyStreamlinesGetter(object):
def __init__(self, hdf_group):
self.hdf_group = hdf_group
Expand Down Expand Up @@ -141,27 +162,38 @@ class SFTDataAbstract(object):
"""
def __init__(self, space_attributes: Tuple, space: Space, origin: Origin,
contains_connectivity: bool,
connectivity_nb_blocs: List):
connectivity_nb_blocs: List = None,
connectivity_labels: np.ndarray = None):
"""
Params
------
group: str
The current streamlines group id, as loaded in the hdf5 file (it
had type "streamlines"). Probably 'streamlines'.
The lazy/non-lazy versions will have more parameters, such as the
streamlines, the connectivity_matrix. In the case of the lazy version,
through the LazyStreamlinesGetter.
Parameters
----------
space_attributes: Tuple
The space attributes consist of a tuple:
(affine, dimensions, voxel_sizes, voxel_order)
space: Space
The space from dipy's Space format.
subject_id: str:
The subject's name
origin: Origin
The origin from dipy's Origin format.
contains_connectivity: bool
If true, will search for either the connectivity_nb_blocs or the
connectivity_labels information.
connectivity_nb_blocs: List
The information on how to recreate the connectivity matrix.
connectivity_labels: np.ndarray
The 3D volume stating how to recreate the labels.
(toDo: Could be managed to be lazy)
"""
self.space_attributes = space_attributes
self.space = space
self.origin = origin
self.is_lazy = None
self.contains_connectivity = contains_connectivity
self.connectivity_nb_blocs = connectivity_nb_blocs
self.connectivity_labels = connectivity_labels

def __len__(self):
raise NotImplementedError
Expand Down Expand Up @@ -195,7 +227,7 @@ def get_connectivity_matrix_and_info(self, ind=None):
(_, ref_volume_shape, _, _) = self.space_attributes

return (self._access_connectivity_matrix(ind), ref_volume_shape,
self.connectivity_nb_blocs)
self.connectivity_nb_blocs, self.connectivity_labels)

def _access_connectivity_matrix(self, ind):
raise NotImplementedError
Expand Down Expand Up @@ -277,15 +309,14 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
streamlines = _load_all_streamlines_from_hdf(hdf_group)
# Adding non-hidden parameters for nicer later access
lengths_mm = hdf_group['euclidean_lengths']
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'],
dtype=int)
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)
if contains_connectivity:
connectivity_matrix = np.asarray(
hdf_group['connectivity_matrix'], dtype=int) # int or bool?
else:
contains_connectivity = False
connectivity_matrix = None
connectivity_nb_blocs = None

space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)

Expand All @@ -296,7 +327,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs)
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
if streamline_ids is not None:
Expand Down Expand Up @@ -336,22 +368,22 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None):

@classmethod
def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group):
space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group)
if 'connectivity_matrix' in hdf_group:
contains_connectivity = True
connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs']
else:
contains_connectivity = False
connectivity_nb_blocs = None
space_attributes, space, origin = _load_space_attributes_from_hdf(
hdf_group)

contains_connectivity, connectivity_nb_blocs, connectivity_labels = \
_load_connectivity_info(hdf_group)

streamlines = _LazyStreamlinesGetter(hdf_group)

return cls(streamlines_getter=streamlines,
space_attributes=space_attributes,
space=space, origin=origin,
contains_connectivity=contains_connectivity,
connectivity_nb_blocs=connectivity_nb_blocs)
connectivity_nb_blocs=connectivity_nb_blocs,
connectivity_labels=connectivity_labels)

def _get_streamlines_as_list(self, streamline_ids):
streamlines = self.streamlines_getter.get_array_sequence(streamline_ids)
streamlines = self.streamlines_getter.get_array_sequence(
streamline_ids)
return streamlines
Loading

0 comments on commit bf8d2aa

Please sign in to comment.