More supervision of loss #180

Merged · 13 commits · Jun 30, 2023
57 changes: 37 additions & 20 deletions dwi_ml/data/dataset/checks_for_groups.py
@@ -2,6 +2,8 @@
import logging
from typing import Tuple

import numpy as np


def _find_groups_info_for_subj(hdf_file, subj_id: str):
"""
@@ -25,10 +27,14 @@ def _find_groups_info_for_subj(hdf_file, subj_id: str):
volume's last dimension).
streamline_groups: List[str]
The list of streamline groups for this subject.
contains_connectivity: np.ndarray
A boolean array with one entry per streamline group, stating whether
that group contains pre-computed connectivity matrices for this subject.
"""
volume_groups = []
nb_features = []
streamline_groups = []
contains_connectivity = []

hdf_groups = hdf_file[subj_id]
for hdf_group in hdf_groups:
@@ -39,53 +45,64 @@
hdf_file[subj_id][hdf_group].attrs['nb_features'])
elif group_type == 'streamlines':
streamline_groups.append(hdf_group)
found_matrix = 'connectivity_matrix' in hdf_file[subj_id][hdf_group]
contains_connectivity.append(found_matrix)
else:
raise NotImplementedError(
"So far, you can only add 'volume' or 'streamline' groups in "
"the groups_config.json. Please see the doc for a json file "
"example. Your hdf5 contained group of type {} for subj {}"
.format(group_type, subj_id))

return volume_groups, nb_features, streamline_groups
contains_connectivity = np.asarray(contains_connectivity, dtype=bool)
return volume_groups, nb_features, streamline_groups, contains_connectivity
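For illustration, a minimal sketch of how the updated helper could be called; the file name and subject id are hypothetical:

```python
import h5py

# Hypothetical file and subject id, for illustration only.
with h5py.File("dataset.hdf5", "r") as hdf_file:
    volume_groups, nb_features, streamline_groups, contains_connectivity = \
        _find_groups_info_for_subj(hdf_file, "subj_01")
    # contains_connectivity is a boolean ndarray aligned with
    # streamline_groups, e.g. array([ True, False]).
```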


def _compare_groups_info(volume_groups, nb_features, streamline_groups,
group_info: Tuple):
def _compare_groups_info(subject_group_info, ref_group_info: Tuple):
"""
Compares the three lists (volume_groups, nb_features, streamline_groups) of
one subject to the expected list for this database, included in group_info.
Compares the tuple (volume_groups, nb_features, streamline_groups,
contains_connectivity) of one subject to the expected values for this
database, given in ref_group_info.
"""
v, f, s = group_info
if volume_groups != v:
sv, sf, ss, sc = subject_group_info
rv, rf, rs, rc = ref_group_info
if not set(rv).issubset(set(sv)):
logging.warning("Subject's hdf5 groups with attributes 'type' set as "
"'volume' are not the same as expected with this "
"dataset! Expected: {}. Found: {}"
.format(v, volume_groups))
if nb_features != f:
.format(rv, sv))

if not set(rf).issubset(set(sf)): # not a good verification but ok for now.
logging.warning("Among subject's hdf5 groups with attributes 'type' "
"set as 'volume', some data to not have the same "
"number of features as expected for this dataset! "
"Expected: {}. Found: {}".format(f, nb_features))
if streamline_groups != s:
"Expected: {}. Found: {}".format(rf, sf))

if not set(rs).issubset(set(ss)):
logging.warning("Subject's hdf5 groups with attributes 'type' set as "
"'streamlines' are not the same as expected with this "
"dataset! Expected: {}. Found: {}"
.format(s, streamline_groups))
.format(rs, ss))
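Note that these checks only warn when a reference group is missing on the subject's side; extra subject groups pass silently, which is what lets subjects carry unused data. A small illustration of that semantics (group names are made up):

```python
ref_volumes = ["input_dwi"]
subj_volumes = ["input_dwi", "extra_t1"]

# All reference groups exist on the subject: no warning is logged.
assert set(ref_volumes).issubset(set(subj_volumes))

# A reference group missing from the subject would trigger the warning.
assert not set(["input_dwi", "fa"]).issubset(set(subj_volumes))
```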


def prepare_groups_info(subject_id: str, hdf_file, group_info=None):
def prepare_groups_info(subject_id: str, hdf_file, ref_group_info=None):
"""
Read the hdf5 file for this subject and get the groups information
(volume and streamlines groups names, number of features for volumes).

If group_info is given, compare subject's information with database
expected information.
expected information. If the subject has more information than the
reference (ex: unused volume groups), the extra information is ignored.

Returns
-------
subject_group_info = (volume_groups, nb_features,
streamline_groups, contains_connectivity)
"""
volume_groups, nb_features, streamline_groups = \
_find_groups_info_for_subj(hdf_file, subject_id)
subject_group_info = _find_groups_info_for_subj(hdf_file, subject_id)

if group_info is not None:
_compare_groups_info(volume_groups, nb_features, streamline_groups,
group_info)
if ref_group_info is not None:
_compare_groups_info(subject_group_info, ref_group_info)
return ref_group_info

return volume_groups, nb_features, streamline_groups
return subject_group_info
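A hedged usage sketch of the updated flow, assuming an open hdf5 handle and hypothetical subject ids:

```python
import h5py

with h5py.File("dataset.hdf5", "r") as hdf_file:
    # The first subject defines the reference for the database.
    ref_info = prepare_groups_info("subj_01", hdf_file, ref_group_info=None)

    # Later subjects are compared against it; the reference itself is
    # returned, so every subject ends up with the same group lists.
    info = prepare_groups_info("subj_02", hdf_file, ref_group_info=ref_info)
    assert info is ref_info
```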
6 changes: 3 additions & 3 deletions dwi_ml/data/dataset/mri_data_containers.py
@@ -44,7 +44,7 @@ def __init__(self, data: Union[torch.Tensor, h5py.Group],
self._data = data

@classmethod
def init_from_hdf_info(cls, hdf_group: h5py.Group):
def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):
"""
Allows initiating an instance of this class by sending only the
hdf handle. This method will define how to load the data from it
@@ -74,7 +74,7 @@ def __init__(self, data: torch.Tensor, voxres: np.ndarray,
super().__init__(data, voxres, affine)

@classmethod
def init_from_hdf_info(cls, hdf_group: h5py.Group):
def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):
"""
Creating class instance from the hdf in cases where data is not
loaded yet. Non-lazy = loading the data here.
@@ -106,7 +106,7 @@ def __init__(self, data: Union[h5py.Group, None], voxres: np.ndarray,
super().__init__(data, voxres, affine)

@classmethod
def init_from_hdf_info(cls, hdf_group: h5py.Group):
def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group):
"""
Creating class instance from the hdf in cases where data is not
loaded yet. Not loading the data, but loading the voxres.
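The renames above only touch the constructor entry points; a call site would now look like this (handle and group names are hypothetical):

```python
import h5py

with h5py.File("dataset.hdf5", "r") as hdf_file:
    hdf_group = hdf_file["subj_01"]["input_volume"]

    # Non-lazy version: loads the full volume now.
    mri = MRIData.init_mri_data_from_hdf_info(hdf_group)

    # Lazy version: keeps the group handle, loads only voxres/affine.
    lazy_mri = LazyMRIData.init_mri_data_from_hdf_info(hdf_group)
```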
71 changes: 40 additions & 31 deletions dwi_ml/data/dataset/multi_subject_containers.py
@@ -46,6 +46,7 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool,
self.volume_groups = [] # type: List[str]
self.nb_features = [] # type: List[int]
self.streamline_groups = [] # type: List[str]
self.contains_connectivity = [] # type: np.ndarray

# The subjects data list will be either a SubjectsDataList or a
# LazySubjectsDataList depending on MultisubjectDataset.is_lazy.
@@ -90,10 +91,11 @@ def close_all_handles(self):
s.hdf_handle = None

def set_subset_info(self, volume_groups, nb_features, streamline_groups,
step_size, compress):
contains_connectivity, step_size, compress):
self.volume_groups = volume_groups
self.streamline_groups = streamline_groups
self.nb_features = nb_features
self.streamline_groups = streamline_groups
self.contains_connectivity = contains_connectivity
self.step_size = step_size
self.compress = compress

@@ -224,7 +226,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
Load all subjects for this subset (either training, validation or
testing).
"""

# Checking if there are any subjects to load
subject_keys = sorted(hdf_handle.attrs[self.set_name + '_subjs'])
if subj_id is not None:
@@ -256,6 +257,9 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
lengths = [[] for _ in self.streamline_groups]
lengths_mm = [[] for _ in self.streamline_groups]

ref_group_info = (self.volume_groups, self.nb_features,
self.streamline_groups, self.contains_connectivity)

# Using tqdm progress bar, load all subjects from hdf_file
with logging_redirect_tqdm(loggers=[logging.root], tqdm_class=tqdm):
for subj_id in tqdm(subject_keys, ncols=100, total=self.nb_subjects):
@@ -264,8 +268,7 @@
# calling this method.
logger.debug(" Creating subject '{}'.".format(subj_id))
subj_data = self._init_subj_from_hdf(
hdf_handle, subj_id, self.volume_groups, self.nb_features,
self.streamline_groups)
hdf_handle, subj_id, ref_group_info)

# Add subject to the list
logger.debug(" Adding it to the list of subjects.")
@@ -323,16 +326,13 @@ def _build_empty_data_list(self):
else:
return SubjectsDataList(self.hdf5_file, logger)

def _init_subj_from_hdf(self, hdf_handle, subject_id, volume_groups,
nb_features, streamline_groups):
def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info):
if self.is_lazy:
return LazySubjectData.init_from_hdf(
subject_id, hdf_handle,
(volume_groups, nb_features, streamline_groups))
return LazySubjectData.init_single_subject_from_hdf(
subject_id, hdf_handle, ref_group_info)
else:
return SubjectData.init_from_hdf(
subject_id, hdf_handle,
(volume_groups, nb_features, streamline_groups))
return SubjectData.init_single_subject_from_hdf(
subject_id, hdf_handle, ref_group_info)


class MultiSubjectDataset:
@@ -348,7 +348,7 @@ class MultiSubjectDataset:
'streamlines/lengths', 'streamlines/euclidean_lengths'.
"""
def __init__(self, hdf5_file: str, lazy: bool,
cache_size: int = 0, log_level=logging.root.level):
cache_size: int = 0, log_level=None):
"""
Params
------
@@ -367,11 +367,13 @@ def __init__(self, hdf5_file: str, lazy: bool,
# Dataset info
self.hdf5_file = hdf5_file

logger.setLevel(log_level)
if log_level is not None:
logger.setLevel(log_level)

self.volume_groups = [] # type: List[str]
self.nb_features = [] # type: List[int]
self.streamline_groups = [] # type: List[str]
self.streamlines_contain_connectivity = []

self.is_lazy = lazy
self.subset_cache_size = cache_size
@@ -445,15 +447,17 @@ def load_data(self, load_training=True, load_validation=True,
# Loading the first training subject's group information.
# Others should fit.
one_subj = hdf_handle.attrs['training_subjs'][0]
group_info = \
prepare_groups_info(one_subj, hdf_handle, group_info=None)
(poss_volume_groups, nb_features, poss_strea_groups) = group_info
logger.info(" Possible volume groups are: {}"
.format(poss_volume_groups))
logger.info(" Number of features in each of these groups: "
"{}".format(nb_features))
logger.info(" Possible streamline groups are: {}"
.format(poss_strea_groups))
(poss_volume_groups, nb_features, poss_strea_groups,
contains_connectivity) = prepare_groups_info(
one_subj, hdf_handle, ref_group_info=None)
logger.debug("Possible volume groups are: {}"
.format(poss_volume_groups))
logger.debug("Number of features in each of these groups: {}"
.format(nb_features))
logger.debug("Possible streamline groups are: {}"
.format(poss_strea_groups))
logger.debug("Streamline groups containing a connectivity matrix: "
"{}".format(contains_connectivity))

# Verifying groups of interest
if volume_groups is not None:
@@ -464,12 +468,12 @@
.format(missing_vol))
vol, indv, indposs = np.intersect1d(
volume_groups, poss_volume_groups, return_indices=True)
self.volume_groups = vol
self.volume_groups = list(vol)
self.nb_features = [nb_features[i] for i in indposs]
logger.info("Chosen volume groups are: {}"
logger.info("--> Chosen volume groups are: {}"
.format(self.volume_groups))
else:
logger.info("Using all volume groups.")
logger.info("--> Using all volume groups.")
self.volume_groups = poss_volume_groups
self.nb_features = nb_features

@@ -479,14 +483,19 @@
raise ValueError("Streamlines {} were not found in the "
"first subject of your hdf5 file."
.format(missing_str))
self.streamline_groups = np.intersect1d(streamline_groups,
poss_strea_groups)
logger.info("Chosen streamline groups are: {}"
self.streamline_groups, _, ind = np.intersect1d(
streamline_groups, poss_strea_groups, return_indices=True)
logger.info("--> Chosen streamline groups are: {}"
.format(self.streamline_groups))
self.streamlines_contain_connectivity = contains_connectivity[ind]
else:
logger.info("Using all streamline groups.")
logger.info("--> Using all streamline groups.")
self.streamline_groups = poss_strea_groups
self.streamlines_contain_connectivity = contains_connectivity

group_info = (self.volume_groups, self.nb_features,
self.streamline_groups,
self.streamlines_contain_connectivity)
self.training_set.set_subset_info(*group_info, step_size, compress)
self.validation_set.set_subset_info(*group_info, step_size, compress)
self.testing_set.set_subset_info(*group_info, step_size, compress)
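Taken together, the connectivity flags now propagate from the first training subject through `set_subset_info` into every subset. A sketch of the intended flow, with a placeholder path:

```python
dataset = MultiSubjectDataset("dataset.hdf5", lazy=False)
dataset.load_data()

# One boolean per chosen streamline group:
for group, has_conn in zip(dataset.streamline_groups,
                           dataset.streamlines_contain_connectivity):
    print(group, "has a connectivity matrix:", bool(has_conn))
```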
33 changes: 21 additions & 12 deletions dwi_ml/data/dataset/single_subject_containers.py
@@ -48,7 +48,8 @@ def sft_data_list(self):
raise NotImplementedError

@classmethod
def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None):
def init_single_subject_from_hdf(
cls, subject_id: str, hdf_file, group_info=None):
"""Returns an instance of this class, initiated by sending only the
hdf handle. The child class's method will define how to load the
data based on the child data management."""
@@ -88,12 +89,13 @@ def sft_data_list(self):
return self._sft_data_list

@classmethod
def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None):
def init_single_subject_from_hdf(
cls, subject_id: str, hdf_file, group_info=None):
"""
Instantiating a single subject data: load info and use __init__
"""
volume_groups, nb_features, streamline_groups = prepare_groups_info(
subject_id, hdf_file, group_info)
(volume_groups, nb_features, streamline_groups, _) = \
prepare_groups_info(subject_id, hdf_file, group_info)

subject_mri_data_list = []
subject_sft_data_list = []
@@ -102,13 +104,14 @@ def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None):
logger.debug(' Loading volume group "{}": '.format(group))
# Creating a SubjectMRIData or a LazySubjectMRIData based on
# lazy or non-lazy version.
subject_mri_group_data = MRIData.init_from_hdf_info(
subject_mri_group_data = MRIData.init_mri_data_from_hdf_info(
hdf_file[subject_id][group])
subject_mri_data_list.append(subject_mri_group_data)

for group in streamline_groups:
logger.debug(" Loading subject's streamlines")
sft_data = SFTData.init_from_hdf_info(hdf_file[subject_id][group])
sft_data = SFTData.init_sft_data_from_hdf_info(
hdf_file[subject_id][group])
subject_sft_data_list.append(sft_data)

subj_data = cls(subject_id,
@@ -140,7 +143,8 @@ def __init__(self, volume_groups: List[str], nb_features: List[int],
self.is_lazy = True

@classmethod
def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None):
def init_single_subject_from_hdf(
cls, subject_id: str, hdf_file, group_info=None):
"""
Instantiating a single subject data: NOT LOADING info and use __init__
(so in short: this does basically nothing, the lazy data is kept
@@ -156,8 +160,8 @@ def init_single_subject_from_hdf(
Tuple containing (volume_groups, nb_features, streamline_groups,
contains_connectivity) for this subject.
"""
volume_groups, nb_features, streamline_groups = prepare_groups_info(
subject_id, hdf_file, group_info)
volume_groups, nb_features, streamline_groups, _ = \
prepare_groups_info(subject_id, hdf_file, group_info)

logger.debug(' Lazy: not loading data.')

@@ -168,15 +172,15 @@ def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None):
def mri_data_list(self) -> Union[List[LazyMRIData], None]:
"""As a property, this is only computed if called by the user.
Returns a List[LazyMRIData]"""

if self.hdf_handle is not None:
if not self.hdf_handle.id.valid:
logger.warning("Tried to access subject's volumes but its "
"hdf handle is not valid (closed file?)")
mri_data_list = []
for group in self.volume_groups:
hdf_group = self.hdf_handle[self.subject_id][group]
mri_data_list.append(LazyMRIData.init_from_hdf_info(hdf_group))
mri_data_list.append(
LazyMRIData.init_mri_data_from_hdf_info(hdf_group))

return mri_data_list
else:
@@ -187,11 +191,16 @@ def mri_data_list(self) -> Union[List[LazyMRIData], None]:
def sft_data_list(self) -> Union[List[LazySFTData], None]:
"""As a property, this is only computed if called by the user.
Returns a List[LazySFTData]"""
# toDo. Reloads the basic information (ex: origin, corner, etc.)
# every time we access a subject. They are lazy subjects! Why can't
# we keep this list of LazySFTData in memory?

if self.hdf_handle is not None:
sft_data_list = []
for group in self.streamline_groups:
hdf_group = self.hdf_handle[self.subject_id][group]
sft_data_list.append(LazySFTData.init_from_hdf_info(hdf_group))
sft_data_list.append(
LazySFTData.init_sft_data_from_hdf_info(hdf_group))

return sft_data_list
else:
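As the new toDo points out, each access to these properties rebuilds the list from the hdf handle; nothing is cached on the subject. A sketch of that behaviour, assuming `subj` is a LazySubjectData with an open handle:

```python
# Two accesses build two distinct lists from the same hdf groups.
sfts_a = subj.sft_data_list
sfts_b = subj.sft_data_list
assert sfts_a is not sfts_b  # re-created on every access
```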