diff --git a/dwi_ml/data/dataset/checks_for_groups.py b/dwi_ml/data/dataset/checks_for_groups.py index 20adc1f8..d63fa430 100644 --- a/dwi_ml/data/dataset/checks_for_groups.py +++ b/dwi_ml/data/dataset/checks_for_groups.py @@ -2,6 +2,8 @@ import logging from typing import Tuple +import numpy as np + def _find_groups_info_for_subj(hdf_file, subj_id: str): """ @@ -25,10 +27,14 @@ def _find_groups_info_for_subj(hdf_file, subj_id: str): volume's last dimension). streamline_groups: List[str] The list of streamline groups for this subject. + contains_connectivity: np.ndarray + A list of boolean for each streamline_group stating if it contains the + pre-computed connectivity matrices for that subject. """ volume_groups = [] nb_features = [] streamline_groups = [] + contains_connectivity = [] hdf_groups = hdf_file[subj_id] for hdf_group in hdf_groups: @@ -39,6 +45,8 @@ def _find_groups_info_for_subj(hdf_file, subj_id: str): hdf_file[subj_id][hdf_group].attrs['nb_features']) elif group_type == 'streamlines': streamline_groups.append(hdf_group) + found_matrix = 'connectivity_matrix' in hdf_file[subj_id][hdf_group] + contains_connectivity.append(found_matrix) else: raise NotImplementedError( "So far, you can only add 'volume' or 'streamline' groups in " @@ -46,46 +54,55 @@ def _find_groups_info_for_subj(hdf_file, subj_id: str): "example. Your hdf5 contained group of type {} for subj {}" .format(group_type, subj_id)) - return volume_groups, nb_features, streamline_groups + contains_connectivity = np.asarray(contains_connectivity, dtype=bool) + return volume_groups, nb_features, streamline_groups, contains_connectivity -def _compare_groups_info(volume_groups, nb_features, streamline_groups, - group_info: Tuple): +def _compare_groups_info(subject_group_info, ref_group_info: Tuple): """ - Compares the three lists (volume_groups, nb_features, streamline_groups) of - one subject to the expected list for this database, included in group_info. + Compares the two tuple (volume_groups, nb_features, streamline_groups, + contains_connectivity) between one subject to the expected list for this + database, included in group_info. """ - v, f, s = group_info - if volume_groups != v: + sv, sf, ss, sc = subject_group_info + rv, rf, rs, rc = ref_group_info + if not set(rv).issubset(set(sv)): logging.warning("Subject's hdf5 groups with attributes 'type' set as " "'volume' are not the same as expected with this " "dataset! Expected: {}. Found: {}" - .format(v, volume_groups)) - if nb_features != f: + .format(rv, sv)) + + if not set(rf).issubset(set(sf)): # not a good verification but ok for now. logging.warning("Among subject's hdf5 groups with attributes 'type' " "set as 'volume', some data to not have the same " "number of features as expected for this dataset! " - "Expected: {}. Found: {}".format(f, nb_features)) - if streamline_groups != s: + "Expected: {}. Found: {}".format(rf, sf)) + + if not set(rs).issubset(set(ss)): logging.warning("Subject's hdf5 groups with attributes 'type' set as " "'streamlines' are not the same as expected with this " "dataset! Expected: {}. Found: {}" - .format(s, streamline_groups)) + .format(rs, ss)) -def prepare_groups_info(subject_id: str, hdf_file, group_info=None): +def prepare_groups_info(subject_id: str, hdf_file, ref_group_info=None): """ Read the hdf5 file for this subject and get the groups information (volume and streamlines groups names, number of features for volumes). If group_info is given, compare subject's information with database - expected information. + expected information. If subject has more information than the reference, + (ex, non-useful volume groups), they will be ignored. + + Returns + ------- + subject_group_info = (volume_groups, nb_features, + streamline_groups, contains_connectivity) """ - volume_groups, nb_features, streamline_groups = \ - _find_groups_info_for_subj(hdf_file, subject_id) + subject_group_info = _find_groups_info_for_subj(hdf_file, subject_id) - if group_info is not None: - _compare_groups_info(volume_groups, nb_features, streamline_groups, - group_info) + if ref_group_info is not None: + _compare_groups_info(subject_group_info, ref_group_info) + return ref_group_info - return volume_groups, nb_features, streamline_groups + return subject_group_info diff --git a/dwi_ml/data/dataset/mri_data_containers.py b/dwi_ml/data/dataset/mri_data_containers.py index 378bd7e8..43b292a3 100644 --- a/dwi_ml/data/dataset/mri_data_containers.py +++ b/dwi_ml/data/dataset/mri_data_containers.py @@ -44,7 +44,7 @@ def __init__(self, data: Union[torch.Tensor, h5py.Group], self._data = data @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group): """ Allows initiating an instance of this class by sending only the hdf handle. This method will define how to load the data from it @@ -74,7 +74,7 @@ def __init__(self, data: torch.Tensor, voxres: np.ndarray, super().__init__(data, voxres, affine) @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group): """ Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. @@ -106,7 +106,7 @@ def __init__(self, data: Union[h5py.Group, None], voxres: np.ndarray, super().__init__(data, voxres, affine) @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_mri_data_from_hdf_info(cls, hdf_group: h5py.Group): """ Creating class instance from the hdf in cases where data is not loaded yet. Not loading the data, but loading the voxres. diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index 860bb7a4..ccd236d3 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -46,6 +46,7 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool, self.volume_groups = [] # type: List[str] self.nb_features = [] # type: List[int] self.streamline_groups = [] # type: List[str] + self.contains_connectivity = [] # type: np.ndarray # The subjects data list will be either a SubjectsDataList or a # LazySubjectsDataList depending on MultisubjectDataset.is_lazy. @@ -90,10 +91,11 @@ def close_all_handles(self): s.hdf_handle = None def set_subset_info(self, volume_groups, nb_features, streamline_groups, - step_size, compress): + contains_connectivity, step_size, compress): self.volume_groups = volume_groups - self.streamline_groups = streamline_groups self.nb_features = nb_features + self.streamline_groups = streamline_groups + self.contains_connectivity = contains_connectivity self.step_size = step_size self.compress = compress @@ -224,7 +226,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None): Load all subjects for this subjset (either training, validation or testing). """ - # Checking if there are any subjects to load subject_keys = sorted(hdf_handle.attrs[self.set_name + '_subjs']) if subj_id is not None: @@ -256,6 +257,9 @@ def load(self, hdf_handle: h5py.File, subj_id=None): lengths = [[] for _ in self.streamline_groups] lengths_mm = [[] for _ in self.streamline_groups] + ref_group_info = (self.volume_groups, self.nb_features, + self.streamline_groups, self.contains_connectivity) + # Using tqdm progress bar, load all subjects from hdf_file with logging_redirect_tqdm(loggers=[logging.root], tqdm_class=tqdm): for subj_id in tqdm(subject_keys, ncols=100, total=self.nb_subjects): @@ -264,8 +268,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None): # calling this method. logger.debug(" Creating subject '{}'.".format(subj_id)) subj_data = self._init_subj_from_hdf( - hdf_handle, subj_id, self.volume_groups, self.nb_features, - self.streamline_groups) + hdf_handle, subj_id, ref_group_info) # Add subject to the list logger.debug(" Adding it to the list of subjects.") @@ -323,16 +326,13 @@ def _build_empty_data_list(self): else: return SubjectsDataList(self.hdf5_file, logger) - def _init_subj_from_hdf(self, hdf_handle, subject_id, volume_groups, - nb_features, streamline_groups): + def _init_subj_from_hdf(self, hdf_handle, subject_id, ref_group_info): if self.is_lazy: - return LazySubjectData.init_from_hdf( - subject_id, hdf_handle, - (volume_groups, nb_features, streamline_groups)) + return LazySubjectData.init_single_subject_from_hdf( + subject_id, hdf_handle, ref_group_info) else: - return SubjectData.init_from_hdf( - subject_id, hdf_handle, - (volume_groups, nb_features, streamline_groups)) + return SubjectData.init_single_subject_from_hdf( + subject_id, hdf_handle, ref_group_info) class MultiSubjectDataset: @@ -348,7 +348,7 @@ class MultiSubjectDataset: 'streamlines/lengths', 'streamlines/euclidean_lengths'. """ def __init__(self, hdf5_file: str, lazy: bool, - cache_size: int = 0, log_level=logging.root.level): + cache_size: int = 0, log_level=None): """ Params ------ @@ -367,11 +367,13 @@ def __init__(self, hdf5_file: str, lazy: bool, # Dataset info self.hdf5_file = hdf5_file - logger.setLevel(log_level) + if log_level is not None: + logger.setLevel(log_level) self.volume_groups = [] # type: List[str] self.nb_features = [] # type: List[int] self.streamline_groups = [] # type: List[str] + self.streamlines_contain_connectivity = [] self.is_lazy = lazy self.subset_cache_size = cache_size @@ -445,15 +447,17 @@ def load_data(self, load_training=True, load_validation=True, # Loading the first training subject's group information. # Others should fit. one_subj = hdf_handle.attrs['training_subjs'][0] - group_info = \ - prepare_groups_info(one_subj, hdf_handle, group_info=None) - (poss_volume_groups, nb_features, poss_strea_groups) = group_info - logger.info(" Possible volume groups are: {}" - .format(poss_volume_groups)) - logger.info(" Number of features in each of these groups: " - "{}".format(nb_features)) - logger.info(" Possible streamline groups are: {}" - .format(poss_strea_groups)) + (poss_volume_groups, nb_features, poss_strea_groups, + contains_connectivity) = prepare_groups_info( + one_subj, hdf_handle, ref_group_info=None) + logger.debug("Possible volume groups are: {}" + .format(poss_volume_groups)) + logger.debug("Number of features in each of these groups: {}" + .format(nb_features)) + logger.debug("Possible streamline groups are: {}" + .format(poss_strea_groups)) + logger.debug("Streamline groups containing a connectivity matrix: " + "{}".format(contains_connectivity)) # Verifying groups of interest if volume_groups is not None: @@ -464,12 +468,12 @@ def load_data(self, load_training=True, load_validation=True, .format(missing_vol)) vol, indv, indposs = np.intersect1d( volume_groups, poss_volume_groups, return_indices=True) - self.volume_groups = vol + self.volume_groups = list(vol) self.nb_features = [nb_features[i] for i in indposs] - logger.info("Chosen volume groups are: {}" + logger.info("--> Chosen volume groups are: {}" .format(self.volume_groups)) else: - logger.info("Using all volume groups.") + logger.info("--> Using all volume groups.") self.volume_groups = poss_volume_groups self.nb_features = nb_features @@ -479,14 +483,19 @@ def load_data(self, load_training=True, load_validation=True, raise ValueError("Streamlines {} were not found in the " "first subject of your hdf5 file." .format(missing_str)) - self.streamline_groups = np.intersect1d(streamline_groups, - poss_strea_groups) - logger.info("Chosen streamline groups are: {}" + self.streamline_groups, _, ind = np.intersect1d( + streamline_groups, poss_strea_groups, return_indices=True) + logger.info("--> Chosen streamline groups are: {}" .format(self.streamline_groups)) + self.streamlines_contain_connectivity = contains_connectivity[ind] else: - logger.info("Using all streamline groups.") + logger.info("--> Using all streamline groups.") self.streamline_groups = poss_strea_groups + self.streamlines_contain_connectivity = contains_connectivity + group_info = (self.volume_groups, self.nb_features, + self.streamline_groups, + self.streamlines_contain_connectivity) self.training_set.set_subset_info(*group_info, step_size, compress) self.validation_set.set_subset_info(*group_info, step_size, compress) self.testing_set.set_subset_info(*group_info, step_size, compress) diff --git a/dwi_ml/data/dataset/single_subject_containers.py b/dwi_ml/data/dataset/single_subject_containers.py index e0439b6d..b4d17096 100644 --- a/dwi_ml/data/dataset/single_subject_containers.py +++ b/dwi_ml/data/dataset/single_subject_containers.py @@ -48,7 +48,8 @@ def sft_data_list(self): raise NotImplementedError @classmethod - def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): + def init_single_subject_from_hdf( + cls, subject_id: str, hdf_file, group_info=None): """Returns an instance of this class, initiated by sending only the hdf handle. The child class's method will define how to load the data based on the child data management.""" @@ -88,12 +89,13 @@ def sft_data_list(self): return self._sft_data_list @classmethod - def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): + def init_single_subject_from_hdf( + cls, subject_id: str, hdf_file, group_info=None): """ Instantiating a single subject data: load info and use __init__ """ - volume_groups, nb_features, streamline_groups = prepare_groups_info( - subject_id, hdf_file, group_info) + (volume_groups, nb_features, streamline_groups, _) = \ + prepare_groups_info(subject_id, hdf_file, group_info) subject_mri_data_list = [] subject_sft_data_list = [] @@ -102,13 +104,14 @@ def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): logger.debug(' Loading volume group "{}": '.format(group)) # Creating a SubjectMRIData or a LazySubjectMRIData based on # lazy or non-lazy version. - subject_mri_group_data = MRIData.init_from_hdf_info( + subject_mri_group_data = MRIData.init_mri_data_from_hdf_info( hdf_file[subject_id][group]) subject_mri_data_list.append(subject_mri_group_data) for group in streamline_groups: logger.debug(" Loading subject's streamlines") - sft_data = SFTData.init_from_hdf_info(hdf_file[subject_id][group]) + sft_data = SFTData.init_sft_data_from_hdf_info( + hdf_file[subject_id][group]) subject_sft_data_list.append(sft_data) subj_data = cls(subject_id, @@ -140,7 +143,8 @@ def __init__(self, volume_groups: List[str], nb_features: List[int], self.is_lazy = True @classmethod - def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): + def init_single_subject_from_hdf( + cls, subject_id: str, hdf_file, group_info=None): """ Instantiating a single subject data: NOT LOADING info and use __init__ (so in short: this does basically nothing, the lazy data is kept @@ -156,8 +160,8 @@ def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): Tuple containing (volume_groups, nb_features, streamline_groups) for this subject. """ - volume_groups, nb_features, streamline_groups = prepare_groups_info( - subject_id, hdf_file, group_info) + volume_groups, nb_features, streamline_groups, _ = \ + prepare_groups_info(subject_id, hdf_file, group_info) logger.debug(' Lazy: not loading data.') @@ -168,7 +172,6 @@ def init_from_hdf(cls, subject_id: str, hdf_file, group_info=None): def mri_data_list(self) -> Union[List[LazyMRIData], None]: """As a property, this is only computed if called by the user. Returns a List[LazyMRIData]""" - if self.hdf_handle is not None: if not self.hdf_handle.id.valid: logger.warning("Tried to access subject's volumes but its " @@ -176,7 +179,8 @@ def mri_data_list(self) -> Union[List[LazyMRIData], None]: mri_data_list = [] for group in self.volume_groups: hdf_group = self.hdf_handle[self.subject_id][group] - mri_data_list.append(LazyMRIData.init_from_hdf_info(hdf_group)) + mri_data_list.append( + LazyMRIData.init_mri_data_from_hdf_info(hdf_group)) return mri_data_list else: @@ -187,11 +191,16 @@ def mri_data_list(self) -> Union[List[LazyMRIData], None]: def sft_data_list(self) -> Union[List[LazySFTData], None]: """As a property, this is only computed if called by the user. Returns a List[LazyMRIData]""" + # toDo. Reloads the basic information (ex: origin, corner, etc) + # everytime we acces a subject. They are lazy subjects! Why can't + # we keep this list of lazysftdata in memory? + if self.hdf_handle is not None: sft_data_list = [] for group in self.streamline_groups: hdf_group = self.hdf_handle[self.subject_id][group] - sft_data_list.append(LazySFTData.init_from_hdf_info(hdf_group)) + sft_data_list.append( + LazySFTData.init_sft_data_from_hdf_info(hdf_group)) return sft_data_list else: diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 262755d2..9f66bfe7 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- """ We expect the classes here to be used in single_subject_containers - - """ import logging from typing import Tuple, Union, List @@ -135,6 +133,14 @@ def lengths_mm(self): return lengths + def connectivity_matrix(self, indxyz: Tuple = None): + if indxyz: + indx, indy, indz = indxyz + return np.asarray( + self.hdf_group['connectivity_matrix'][indx, indy, indy], + dtype=int) + return np.asarray(self.hdf_group['connectivity_matrix'], dtype=int) + def __len__(self): return len(self.hdf_group['offsets']) @@ -157,7 +163,9 @@ class SFTDataAbstract(object): """ def __init__(self, streamlines: Union[ArraySequence, _LazyStreamlinesGetter], - space_attributes: Tuple, space: Space, origin: Origin): + space_attributes: Tuple, space: Space, origin: Origin, + contains_connectivity: bool, + downsampled_size_for_connectivity: List): """ Params ------ @@ -181,6 +189,8 @@ def __init__(self, self.origin = origin self.streamlines = streamlines self.is_lazy = None + self.contains_connectivity = contains_connectivity + self.downsampled_size_for_connectivity = downsampled_size_for_connectivity @property def lengths(self): @@ -198,8 +208,25 @@ def lengths_mm(self): streamlines.""" raise NotImplementedError + def connectivity_matrix_and_info(self, ind=None): + """New method compared to SFTs: access pre-computed connectivity + matrix. Returns the subject's connectivity matrix associated with + current tractogram, together with information required to recompute + a similar matrix: reference volume's shape and downsampled shape.""" + if not self.contains_connectivity: + raise ValueError("No pre-computed connectivity matrix found for " + "this subject.") + + (_, ref_volume_shape, _, _) = self.space_attributes + + return (self._access_connectivity_matrix(ind), ref_volume_shape, + self.downsampled_size_for_connectivity) + + def _access_connectivity_matrix(self, ind): + raise NotImplementedError + @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): """Create an instance of this class by sending directly the hdf5 file. The child class's method will define how to load the data according to the class's data management.""" @@ -230,18 +257,27 @@ def as_sft(self, streamline_ids: List = None): class SFTData(SFTDataAbstract): - def __init__(self, streamlines: ArraySequence, space_attributes: Tuple, - space: Space, origin: Origin, lengths_mm: List): - super().__init__(streamlines, space_attributes, space, origin) + streamlines: ArraySequence + + def __init__(self, lengths_mm: List, connectivity_matrix: np.ndarray, + **kwargs): + super().__init__(**kwargs) self._lengths_mm = lengths_mm + self._connectivity_matrix = connectivity_matrix self.is_lazy = False @property def lengths_mm(self): return np.array(self._lengths_mm) + def _access_connectivity_matrix(self, indxyz: Tuple = None): + if indxyz: + indx, indy, indz = indxyz + return self._connectivity_matrix[indx, indy, indy] + return self._connectivity_matrix + @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): """ Creating class instance from the hdf in cases where data is not loaded yet. Non-lazy = loading the data here. @@ -249,12 +285,25 @@ def init_from_hdf_info(cls, hdf_group: h5py.Group): streamlines = _load_streamlines_from_hdf(hdf_group) # Adding non-hidden parameters for nicer later access lengths_mm = hdf_group['euclidean_lengths'] + if 'connectivity_matrix' in hdf_group: + contains_connectivity = True + connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'], + dtype=int) + downsampled_size = hdf_group.attrs['downsampled_size'] + else: + contains_connectivity = False + connectivity_matrix = None + downsampled_size = None space_attributes, space, origin = _load_space_from_hdf(hdf_group) # Return an instance of SubjectMRIData instantiated through __init__ # with this loaded data: - return cls(streamlines, space_attributes, space, origin, lengths_mm) + return cls(lengths_mm, connectivity_matrix, + streamlines=streamlines, space_attributes=space_attributes, + space=space, origin=origin, + contains_connectivity=contains_connectivity, + downsampled_size_for_connectivity=downsampled_size) def _subset_streamlines(self, streamline_ids): if streamline_ids is not None: @@ -265,9 +314,10 @@ def _subset_streamlines(self, streamline_ids): class LazySFTData(SFTDataAbstract): - def __init__(self, streamlines: _LazyStreamlinesGetter, - space_attributes: Tuple, space: Space, origin: Origin): - super().__init__(streamlines, space_attributes, space, origin) + streamlines: _LazyStreamlinesGetter + + def __init__(self, **kwargs): + super().__init__(**kwargs) self.is_lazy = True @property @@ -275,13 +325,26 @@ def lengths_mm(self): # Fetching from the lazy streamline getter return np.array(self.streamlines.lengths_mm) + def _access_connectivity_matrix(self, indxyz: Tuple = None): + # Fetching in a lazy way + return self.streamlines.connectivity_matrix(indxyz) + @classmethod - def init_from_hdf_info(cls, hdf_group: h5py.Group): + def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space_attributes, space, origin = _load_space_from_hdf(hdf_group) + if 'connectivity_matrix' in hdf_group: + contains_connectivity = True + downsampled_size = hdf_group.attrs['downsampled_size'] + else: + contains_connectivity = False + downsampled_size = None streamlines = _LazyStreamlinesGetter(hdf_group) - return cls(streamlines, space_attributes, space, origin) + return cls(streamlines=streamlines, space_attributes=space_attributes, + space=space, origin=origin, + contains_connectivity=contains_connectivity, + downsampled_size_for_connectivity=downsampled_size) def _subset_streamlines(self, streamline_ids): streamlines = self.streamlines.get_array_sequence(streamline_ids) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index f255ef6a..728bada9 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -3,13 +3,14 @@ import logging import os from pathlib import Path -from typing import List +from typing import List, Union from dipy.io.stateful_tractogram import set_sft_logger_level, Space from dipy.io.streamline import load_tractogram, save_tractogram from dipy.io.utils import is_header_compatible from dipy.tracking.utils import length import h5py + from dwi_ml.data.processing.streamlines.data_augmentation import \ resample_or_compress from nested_lookup import nested_lookup @@ -20,6 +21,8 @@ from dwi_ml.data.io import load_file_to4d from dwi_ml.data.processing.dwi.dwi import standardize_data +from dwi_ml.data.processing.streamlines.post_processing import \ + compute_triu_connectivity def _load_and_verify_file(filename: str, subj_input_path, group_name: str, @@ -106,7 +109,10 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, training_subjs: List[str], validation_subjs: List[str], testing_subjs: List[str], groups_config: dict, std_mask: str, step_size: float = None, - compress: float = None, enforce_files_presence: bool = True, + compress: float = None, + compute_connectivity_matrix: bool = False, + downsampled_size_for_connectivity: Union[int, list] = 20, + enforce_files_presence: bool = True, save_intermediate: bool = False, intermediate_folder: Path = None): """ @@ -129,6 +135,11 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, Step size to resample streamlines. Default: None. compress: float Compress streamlines. Default: None. + compute_connectivity_matrix: bool + Compute connectivity matrix for each streamline group. + Default: False. + downsampled_size_for_connectivity: int or List + See compute_connectivity_matrix's doc. enforce_files_presence: bool If true, will stop if some files are not available for a subject. Default: True. @@ -147,6 +158,20 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, self.groups_config = groups_config self.step_size = step_size self.compress = compress + self.compute_connectivity = compute_connectivity_matrix + if self.compute_connectivity: + if isinstance(downsampled_size_for_connectivity, List): + assert len(downsampled_size_for_connectivity) == 3, \ + "Expecting to work with 3D volumes. Expecting " \ + "connectivity downsample size to be a list of 3 values, " \ + "but got {}.".format(downsampled_size_for_connectivity) + self.connectivity_downsample_size = downsampled_size_for_connectivity + else: + assert isinstance(downsampled_size_for_connectivity, int), \ + "Expecting the connectivity matrix size to be either " \ + "a 3D list or an integer, but got {}" \ + .format(downsampled_size_for_connectivity) + self.connectivity_downsample_size = [downsampled_size_for_connectivity] * 3 # Optional self.std_mask = std_mask # (could be None) @@ -556,12 +581,14 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, streamlines_group.attrs['dimensions'] = d streamlines_group.attrs['voxel_sizes'] = vs streamlines_group.attrs['voxel_order'] = vo + if self.compute_connectivity: + streamlines_group.attrs['downsampled_size'] = \ + self.connectivity_downsample_size if len(sft.data_per_point) > 0: logging.debug('sft contained data_per_point. Data not kept.') if len(sft.data_per_streamline) > 0: - logging.debug('sft contained data_per_streamlines. Data not ' - 'kept.') + logging.debug('sft contained data_per_streamlines. Data not kept.') # Accessing private Dipy values, but necessary. # We need to deconstruct the streamlines into arrays with @@ -574,6 +601,18 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, data=sft.streamlines._lengths) streamlines_group.create_dataset('euclidean_lengths', data=lengths) + if self.compute_connectivity: + # Can be reduced using sparse tensors notation... always + # minimum 50% of zeros! Then we could save separately + # the indices, values, size of the tensor. But unclear how + # much sparse they need to be to actually save memory. + # Skipping for now. + streamlines_group.create_dataset( + 'connectivity_matrix', + data=compute_triu_connectivity( + sft.streamlines, d, self.connectivity_downsample_size, + binary=True, to_sparse_tensor=False)) + def _process_one_streamline_group( self, subj_dir: Path, group: str, subj_id: str, header: nib.Nifti1Header): @@ -602,7 +641,7 @@ def _process_one_streamline_group( final_tractogram : StatefulTractogram All streamlines in voxel space. output_lengths : List[float] - The euclidean length of each streamline + The Euclidean length of each streamline """ tractograms = self.groups_config[group]['files'] @@ -621,7 +660,7 @@ def _process_one_streamline_group( for instructions in tractograms: if instructions.endswith('/ALL'): - # instructions is to get all tractograms in given folder. + # instructions are to get all tractograms in given folder. tractograms_dir = instructions.split('/ALL') tractograms_dir = ''.join(tractograms_dir[:-1]) tractograms_sublist = [ diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py index 488b5cc4..16d795ff 100644 --- a/dwi_ml/data/hdf5/utils.py +++ b/dwi_ml/data/hdf5/utils.py @@ -11,7 +11,7 @@ from dwi_ml.io_utils import add_resample_or_compress_arg -def add_basic_args(p: ArgumentParser): +def add_hdf5_creation_args(p: ArgumentParser): # Positional arguments p.add_argument('dwi_ml_ready_folder', @@ -73,6 +73,18 @@ def add_mri_processing_args(p: ArgumentParser): def add_streamline_processing_args(p: ArgumentParser): g = p.add_argument_group('Streamlines processing options:') add_resample_or_compress_arg(g) + g.add_argument( + '--compute_connectivity_matrix', action='store_true', + help="If set, computes the 3D connectivity matrix for each streamline " + "group. \nDefined from downsampled image, not from anatomy! \n" + "Ex: can be used at validation time with our trainer's " + "'generation-validation' step.") + g.add_argument( + '--connectivity_downsample_size', metavar='m', type=int, nargs='+', + help="Number of 3D blocks (m x m x m) for the connectivity matrix. \n" + "(The matrix will be m^3 x m^3). If more than one values are " + "provided, expected to be one per dimension. \n" + "Default: 20x20x20.") def _initialize_intermediate_subdir(hdf5_file, save_intermediate): @@ -127,7 +139,9 @@ def prepare_hdf5_creator(args): creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file, training_subjs, validation_subjs, testing_subjs, groups_config, args.std_mask, args.step_size, - args.compress, args.enforce_files_presence, + args.compress, args.compute_connectivity_matrix, + args.connectivity_downsample_size, + args.enforce_files_presence, args.save_intermediate, intermediate_subdir) return creator diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index f750222e..c31f6e14 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from typing import List import numpy as np @@ -253,13 +254,92 @@ def weight_value_with_angle(values: List, streamlines: List = None, # Mult choice: # We don't want to multiply by 0. Multiplying by angles + 1. - # values[i] = values[i] * (angles + 1.0) - values[i] = values[i] * (angles + 1.0)**2 - - # Pow choice: - # loss^0 = 1. loss^1 = loss. Also adding 1. - # But if values are < 1, pow becomes smaller. - # Our losses tend toward 0. Adding 1 before. - #values[i] = torch.pow(1.0 + values[i], angles + 1.0) - 1.0 + values[i] = values[i] * (angles + 1.0) return values + + +def compute_triu_connectivity( + streamlines, volume_size, downsampled_volume_size, + binary: bool = False, to_sparse_tensor: bool = False, device=None): + """ + Compute a connectivity matrix. + + Parameters + ---------- + streamlines: list of np arrays or list of tensors. + Streamlines, in vox space, corner origin. + volume_size: list + The 3D dimension of the reference volume. + downsampled_volume_size: + The m1 x m2 x m3 = mm downsampled volume size for the connectivity matrix. + This means that the matrix will be a mm x mm triangular matrix. + In 3D, with 20x20x20, this is an 8000 x 8000 matrix (triangular). It + probably contains a lot of zeros with the background being included. + Can be saved as sparse. + binary: bool + If true, return a binary matrix. + to_sparse_tensor: + If true, return the sparse matrix. + device: + If true and to_sparse_tensor, the matrix will be hosted on device. + """ + # Getting endpoint coordinates + # + Fix types + volume_size = np.asarray(volume_size) + downsampled_volume_size = np.asarray(downsampled_volume_size) + if isinstance(streamlines[0], list): + start_values = [s[0] for s in streamlines] + end_values = [s[-1] for s in streamlines] + elif isinstance(streamlines[0], torch.Tensor): + start_values = [s[0, :].cpu().numpy() for s in streamlines] + end_values = [s[-1, :].cpu().numpy() for s in streamlines] + else: # expecting numpy arrays + start_values = [s[0, :] for s in streamlines] + end_values = [s[-1, :] for s in streamlines] + + assert len(downsampled_volume_size) == len(volume_size) + nb_dims = len(downsampled_volume_size) + nb_voxels_pre = np.prod(volume_size) + nb_voxels_post = np.prod(downsampled_volume_size) + logging.debug("Preparing connectivity matrix of downsampled volume: from " + "{} to {}. Gives a matrix of size {} x {} rather than {} " + "voxels)." + .format(volume_size, downsampled_volume_size, + nb_voxels_post, nb_voxels_post, nb_voxels_pre)) + + # Downsampling + mult_factor = downsampled_volume_size / volume_size + start_values = np.clip((start_values * mult_factor).astype(int), + a_min=0, a_max=downsampled_volume_size - 1) + end_values = np.clip((end_values * mult_factor).astype(int), + a_min=0, a_max=downsampled_volume_size - 1) + + # Blocs go from 0 to m1*m2*m3. + start_block = np.ravel_multi_index( + [start_values[:, d] for d in range(nb_dims)], downsampled_volume_size) + end_block = np.ravel_multi_index( + [end_values[:, d] for d in range(nb_dims)], downsampled_volume_size) + + total_size = np.prod(downsampled_volume_size) + matrix = np.zeros((total_size, total_size), dtype=int) + for s_start, s_end in zip(start_block, end_block): + matrix[s_start, s_end] += 1 + + # Either, at the end, sum lower triangular + upper triangular (except + # diagonal), or: + if s_end != s_start: + matrix[s_end, s_start] += 1 + + matrix = np.triu(matrix) + assert matrix.sum() == len(streamlines) + + if binary: + matrix = matrix.astype(bool) + + if to_sparse_tensor: + logging.debug("Converting matrix to sparse. Contained {}% of zeros." + .format((1 - np.count_nonzero(matrix) / total_size) * 100)) + matrix = torch.as_tensor(matrix, device=device).to_sparse() + + return matrix diff --git a/dwi_ml/models/direction_getter_models.py b/dwi_ml/models/direction_getter_models.py index 41865982..61c3d648 100644 --- a/dwi_ml/models/direction_getter_models.py +++ b/dwi_ml/models/direction_getter_models.py @@ -291,9 +291,8 @@ def get_tracking_directions(self, outputs, algo: str, Returns ------- - next_dirs: list - A list of numpy arrays (one per streamline), each of size (1, 3): - the three coordinates of the next direction's vector. + next_dirs: torch.Tensor + A tensor of shape [n, 3] with the next direction for each output. """ if algo == 'det': next_dirs = self._get_tracking_direction_det( @@ -301,7 +300,7 @@ def get_tracking_directions(self, outputs, algo: str, else: next_dirs = self._sample_tracking_direction_prob( outputs, eos_stopping_thresh) - return next_dirs.detach() + return next_dirs class AbstractRegressionDG(AbstractDirectionGetterModel): diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 9cc864d8..0d6e6f7b 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -10,6 +10,7 @@ import torch from torch import Tensor +from dwi_ml.data.dataset.multi_subject_containers import MultisubjectSubset from dwi_ml.data.processing.volume.interpolation import \ interpolate_volume_in_neighborhood from dwi_ml.data.processing.space.neighborhood import prepare_neighborhood_vectors @@ -98,6 +99,10 @@ def set_context(self, context): assert context in ['training', 'tracking'] self._context = context + @property + def context(self): + return self._context + def move_to(self, device): """ Careful. Calling model.to(a_device) does not influence the self.device. @@ -373,8 +378,8 @@ def forward(self, inputs, target_streamlines: List[torch.tensor], **kw): class MainModelOneInput(MainModelAbstract): - def prepare_batch_one_input(self, streamlines, subset, subj, - input_group_idx, prepare_mask=False): + def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset, + subj, input_group_idx, prepare_mask=False): """ These params are passed by either the batch loader or the propagator, which manage the data. @@ -514,11 +519,12 @@ def get_tracking_directions(self, model_outputs: Tensor, algo: str, Returns ------- - next_dir: list[array(3,)] - Numpy arrays with x,y,z value, one per streamline data point. + next_dir: torch.Tensor + A tensor of shape [n, 3] with the next direction for each output. """ - return self.direction_getter.get_tracking_directions( + dirs = self.direction_getter.get_tracking_directions( model_outputs, algo, eos_stopping_thresh) + return dirs def compute_loss(self, model_outputs: List[Tensor], target_streamlines, average_results=True, **kw): diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index d7c1e8c4..1062a366 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -234,6 +234,13 @@ def forward(self, inputs: List[torch.tensor], if self._context is None: raise ValueError("Please set context before usage.") + # Verifying the first input + assert inputs[0].shape[-1] == self.input_size, \ + "Not the expected input size! Should be {} (i.e. {} features for " \ + "each {} neighbor), but got {} (input shape {})." \ + .format(self.input_size, self.nb_features, self.nb_neighbors + 1, + inputs[0].shape[-1], inputs[0].shape) + # Making sure we can use default 'enforce_sorted=True' with packed # sequences. unsorted_indices = None @@ -386,7 +393,12 @@ def copy_prev_dir(self, dirs, n_prev_dirs): return copy_prev_dir - def update_hidden_state(self, hidden_recurrent_states, lines_to_keep): + def remove_lines_in_hidden_state( + self, hidden_recurrent_states, lines_to_keep): + """ + Utilitary method to remove a few streamlines from the hidden + state. + """ if self.rnn_model.rnn_torch_key == 'lstm': # LSTM: For each layer, states are tuples; (h_t, C_t) # Size of tensors are each [1, nb_streamlines, nb_neurons] diff --git a/dwi_ml/models/projects/transforming_tractography.py b/dwi_ml/models/projects/transforming_tractography.py index 62ed8250..dfff723c 100644 --- a/dwi_ml/models/projects/transforming_tractography.py +++ b/dwi_ml/models/projects/transforming_tractography.py @@ -469,7 +469,8 @@ def forward(self, inputs: List[torch.tensor], # restack when computing loss. [Chosen here. See if we can improve] # b) loop on direction getter. Stack when computing loss. if self._context == 'tracking': - outputs = outputs.detach() + # If needs to detach: error? Should be using witch torch.no_grad. + outputs = outputs # No need to actually unpad, we only take the last (unpadded) # point, newly created. (-1 for python indexing) if use_padding: # Not all the same length (backward tracking) @@ -488,7 +489,8 @@ def forward(self, inputs: List[torch.tensor], # Outputs will be all streamlines merged. # To compute loss = ok. During tracking, we will need to split back. outputs = self.direction_getter(outputs) - outputs = copy_prev_dir + outputs + if self.start_from_copy_prev: + outputs = copy_prev_dir + outputs if self._context != 'tracking': outputs = list(torch.split(outputs, list(unpad_lengths))) diff --git a/dwi_ml/testing/testers.py b/dwi_ml/testing/testers.py index 47604d07..5d3d37a7 100644 --- a/dwi_ml/testing/testers.py +++ b/dwi_ml/testing/testers.py @@ -81,8 +81,9 @@ def load_and_format_data(self, subj_id, hdf5_file, subset_name): # we don't verify streamline ids (loading all), and we don't split / # reverse streamlines. But we resample / compress. logging.info("Loading its streamlines as SFT.") - streamline_group_idx = self.subset.streamline_groups.index( - self.streamlines_group) + streamline_group_idx, = np.where(self.subset.streamline_groups == + self.streamlines_group) + streamline_group_idx = streamline_group_idx[0] subj_data = self.subset.subjs_data_list.get_subj_with_handle(self.subj_idx) subj_sft_data = subj_data.sft_data_list[streamline_group_idx] sft = subj_sft_data.as_sft() diff --git a/dwi_ml/testing/visu_loss.py b/dwi_ml/testing/visu_loss.py index eef186b2..18b420ff 100644 --- a/dwi_ml/testing/visu_loss.py +++ b/dwi_ml/testing/visu_loss.py @@ -263,9 +263,10 @@ def run_visu_save_colored_displacement( # Either concat, run, split or (chosen:) loop # Use eos_thresh of 1 to be sure we don't output a NaN - out_dirs = [model.get_tracking_directions( - s_output, algo='det', eos_stopping_thresh=1.0).numpy() - for s_output in outputs] + with torch.no_grad(): + out_dirs = [model.get_tracking_directions( + s_output, algo='det', eos_stopping_thresh=1.0).numpy() + for s_output in outputs] # Save error together with ref sft = combine_displacement_with_ref(out_dirs, sft, model.step_size) diff --git a/dwi_ml/tracking/projects/learn2track_tracker.py b/dwi_ml/tracking/projects/learn2track_tracker.py index ee877362..06ec3216 100644 --- a/dwi_ml/tracking/projects/learn2track_tracker.py +++ b/dwi_ml/tracking/projects/learn2track_tracker.py @@ -89,5 +89,5 @@ def update_memory_after_removing_lines(self, can_continue: np.ndarray, _): Indexes of lines that are kept. """ # Hidden states: list[states] (One value per layer). - self.hidden_recurrent_states = self.model.update_hidden_state( + self.hidden_recurrent_states = self.model.remove_lines_in_hidden_state( self.hidden_recurrent_states, can_continue) diff --git a/dwi_ml/tracking/propagation.py b/dwi_ml/tracking/propagation.py index ecf95786..edd2ee8e 100644 --- a/dwi_ml/tracking/propagation.py +++ b/dwi_ml/tracking/propagation.py @@ -147,7 +147,7 @@ def _take_one_step_or_go_straight( next_dirs = torch.vstack(next_dirs) if normalize_directions: - next_dirs /= torch.linalg.norm(next_dirs, dim=-1)[:, None] + next_dirs = next_dirs / torch.linalg.norm(next_dirs, dim=-1)[:, None] if previous_dirs is not None: # Verify angle diff --git a/dwi_ml/tracking/tracker.py b/dwi_ml/tracking/tracker.py index 68d7b958..9f0bd582 100644 --- a/dwi_ml/tracking/tracker.py +++ b/dwi_ml/tracking/tracker.py @@ -449,12 +449,13 @@ def _get_multiple_lines_both_directions(self, seeds: List[np.ndarray]): return clean_lines, clean_seeds def _propagate_multiple_lines(self, lines: List[Tensor]): - return propagate_multiple_lines( - lines, self.update_memory_after_removing_lines, - self.get_next_dirs, self.theta, self.step_size, - self.verify_opposite_direction, self.mask, self.max_nbr_pts, - append_last_point=self.append_last_point, - normalize_directions=self.normalize_directions) + with torch.no_grad(): + return propagate_multiple_lines( + lines, self.update_memory_after_removing_lines, + self.get_next_dirs, self.theta, self.step_size, + self.verify_opposite_direction, self.mask, self.max_nbr_pts, + append_last_point=self.append_last_point, + normalize_directions=self.normalize_directions) def get_next_dirs(self, lines: List[Tensor], n_last_pos: List[Tensor]): """ diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index 9a00343a..80ea704a 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -5,9 +5,9 @@ These classes define how to sample the streamlines available in the MultiSubjectData. -AbstractBatchSampler: +AbstractBatchLoader: -- Define the load_batch method: +- Defines the load_batch method: - Loads the streamlines associated to sampled ids. Can resample them. - Performs data augmentation (on-the-fly to avoid having to multiply data @@ -21,9 +21,9 @@ ---------- Implemented child classes -BatchStreamlinesSamplerOneInput: +BatchLoaderOneInput: -- Redefines the load_batch method: +- Defines the load_batch_inputs method: - Now also loads the input data under each point of the streamline (and possibly its neighborhood), for one input volume. diff --git a/dwi_ml/training/projects/learn2track_trainer.py b/dwi_ml/training/projects/learn2track_trainer.py index c37dedcb..2ef066d9 100644 --- a/dwi_ml/training/projects/learn2track_trainer.py +++ b/dwi_ml/training/projects/learn2track_trainer.py @@ -2,12 +2,14 @@ import logging from typing import List +import h5py import numpy as np import torch from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.tracking.projects.utils import prepare_tracking_mask from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.training.projects.trainers_for_generation import \ +from dwi_ml.training.with_generation.trainer import \ DWIMLTrainerForTrackingOneInput logger = logging.getLogger('trainer_logger') @@ -72,8 +74,8 @@ def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): def update_memory_after_removing_lines(can_continue: np.ndarray, _): nonlocal hidden_states - hidden_states = self.model.update_hidden_state(hidden_states, - can_continue) + hidden_states = self.model.remove_lines_in_hidden_state( + hidden_states, can_continue) def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): nonlocal hidden_states @@ -82,21 +84,34 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): batch_inputs = self.batch_loader.load_batch_inputs(n_last_pos, ids_per_subj) - _model_outputs, hidden_states = self.model( - batch_inputs, _lines, return_hidden=True, point_idx=-1) + model_outputs, hidden_states = self.model( + batch_inputs, _lines, hidden_recurrent_states=hidden_states, + return_hidden=True, point_idx=-1) next_dirs = self.model.get_tracking_directions( - _model_outputs, algo='det', eos_stopping_thresh=0.5) + model_outputs, algo='det', eos_stopping_thresh=0.5) return next_dirs - self.model.set_context('tracking') theta = 2 * np.pi # theta = 360 degrees max_nbr_pts = int(200 / self.model.step_size) - results = propagate_multiple_lines( - lines, update_memory_after_removing_lines, get_dirs_at_last_pos, - theta=theta, step_size=self.model.step_size, - verify_opposite_direction=False, mask=self.tracking_mask, - max_nbr_pts=max_nbr_pts, append_last_point=False, - normalize_directions=True) - self.model.set_context('training') - return results + + final_lines = [] + for subj_idx, line_idx in ids_per_subj.items(): + + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + subj_id = self.batch_loader.context_subset.subjects[subj_idx] + logging.debug("Loading subj {} ({})'s tracking mask." + .format(subj_idx, subj_id)) + tracking_mask, _ = prepare_tracking_mask( + hdf_handle, self.tracking_mask_group, subj_id=subj_id, + mask_interp='nearest') + tracking_mask.move_to(self.device) + + final_lines.extend(propagate_multiple_lines( + lines[line_idx], update_memory_after_removing_lines, + get_dirs_at_last_pos, theta=theta, + step_size=self.model.step_size, verify_opposite_direction=False, + mask=tracking_mask, max_nbr_pts=max_nbr_pts, + append_last_point=False, normalize_directions=True)) + + return final_lines diff --git a/dwi_ml/training/projects/trainers_for_generation.py b/dwi_ml/training/projects/trainers_for_generation.py deleted file mode 100644 index ae6c3936..00000000 --- a/dwi_ml/training/projects/trainers_for_generation.py +++ /dev/null @@ -1,234 +0,0 @@ -# -*- coding: utf-8 -*- -import logging -from typing import List - -import h5py -import numpy as np -import torch -from torch.nn import PairwiseDistance -from tqdm import tqdm - -from dwi_ml.experiment_utils.tqdm_logging import tqdm_logging_redirect -from dwi_ml.models.main_models import ModelWithDirectionGetter -from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.tracking.projects.utils import prepare_tracking_mask -from dwi_ml.training.trainers import DWIMLTrainerOneInput -from dwi_ml.training.utils.monitoring import BatchHistoryMonitor, TimeMonitor - -logger = logging.getLogger('train_logger') - -# Emma tests in ISMRM: a box of 30x30x30 mm sounds good. -# So half of it, max distance = sqrt( 3 * 15^2) = -IS_THRESHOLD = 25.98 - - -class DWIMLTrainerForTrackingOneInput(DWIMLTrainerOneInput): - model: ModelWithDirectionGetter - - def __init__(self, add_a_tracking_validation_phase: bool = False, - tracking_phase_frequency: int = 5, - tracking_phase_nb_steps_init: int = 5, - tracking_phase_mask_group: str = None, - *args, **kw): - super().__init__(*args, **kw) - - self.add_a_tracking_validation_phase = add_a_tracking_validation_phase - self.tracking_phase_frequency = tracking_phase_frequency - self.tracking_phase_nb_steps_init = tracking_phase_nb_steps_init - self.tracking_valid_time_monitor = TimeMonitor() - self.tracking_valid_IS_monitor = BatchHistoryMonitor(weighted=True) - self.tracking_valid_loss_monitor = BatchHistoryMonitor(weighted=True) - self.tracking_mask_group = tracking_phase_mask_group - - self.tracking_mask = None - if add_a_tracking_validation_phase: - # Right now, using any subject's, and supposing that they are all - # in the same space. Else, code would need refactoring to allow - # tracking on multiple subjects. Or we can loop on each subject. - logging.warning("***************\n" - "CODE NEEDS REFACTORING. USING THE SAME TRACKING " - "MASK FOR ALL SUBJECTS.\n" - "***************\n") - any_subj = self.batch_loader.dataset.training_set.subjects[0] - if tracking_phase_mask_group is not None: - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') \ - as hdf_handle: - logging.info("Loading tracking mask.") - self.tracking_mask, _ = prepare_tracking_mask( - hdf_handle, tracking_phase_mask_group, subj_id=any_subj, - mask_interp='nearest') - self.tracking_mask.move_to(self.device) - - @property - def params_for_checkpoint(self): - p = super().params_for_checkpoint - p.update({ - 'add_a_tracking_validation_phase': self.add_a_tracking_validation_phase, - 'tracking_phase_frequency': self.tracking_phase_frequency, - 'tracking_phase_nb_steps_init': self.tracking_phase_nb_steps_init, - 'tracking_phase_mask_group': self.tracking_mask_group - }) - - return p - - def _update_states_from_checkpoint(self, current_states): - super()._update_states_from_checkpoint(current_states) - self.tracking_valid_loss_monitor.set_state( - current_states['tracking_valid_loss_monitor_state']) - self.tracking_valid_IS_monitor.set_state( - current_states['tracking_valid_IS_monitor_state']) - - def _prepare_checkpoint_info(self) -> dict: - checkpoint_info = super()._prepare_checkpoint_info() - checkpoint_info['current_states'].update({ - 'tracking_valid_loss_monitor_state': - self.tracking_valid_loss_monitor.get_state(), - 'tracking_valid_IS_monitor_state': - self.tracking_valid_IS_monitor.get_state(), - }) - return checkpoint_info - - def save_local_logs(self): - super().save_local_logs() - self._save_log_locally( - self.tracking_valid_loss_monitor.average_per_epoch, - "tracking_validation_loss_per_epoch.npy") - self._save_log_locally( - self.tracking_valid_IS_monitor.average_per_epoch, - "tracking_validation_IS_per_epoch.npy") - - def validate_one_epoch(self, epoch): - if self.add_a_tracking_validation_phase: - self.tracking_valid_loss_monitor.start_new_epoch() - self.tracking_valid_IS_monitor.start_new_epoch() - self.tracking_valid_time_monitor.start_new_epoch() - - super().validate_one_epoch(epoch) - - if self.add_a_tracking_validation_phase: - self.tracking_valid_loss_monitor.end_epoch() - self.tracking_valid_IS_monitor.end_epoch() - self.tracking_valid_time_monitor.end_epoch() - - # Save info - if self.comet_exp: - self._update_comet_after_epoch(self.comet_exp.validate, epoch, - tracking_phase=True) - - def _get_latest_loss_to_supervise_best(self): - if self.use_validation: - if self.add_a_tracking_validation_phase: - # Compared to super, replacing by tracking_valid loss. - mean_epoch_loss = self.tracking_valid_loss_monitor.average_per_epoch[-1] - - # Could use IS instead. Not implemented. - else: - mean_epoch_loss = self.valid_loss_monitor.average_per_epoch[-1] - else: - mean_epoch_loss = self.train_loss_monitor.average_per_epoch[-1] - - return mean_epoch_loss - - def validate_one_batch(self, data, epoch): - mean_loss, n = super().validate_one_batch(data, epoch) - - if (epoch + 1) % self.tracking_phase_frequency == 0: - logger.info("Additional tracking-like generation validation " - "from batch.") - gen_mean_loss, gen_n, percent_inv = self.generate_from_one_batch(data) - gen_mean_loss = gen_mean_loss.cpu().item() - self.tracking_valid_loss_monitor.update(gen_mean_loss, weight=n) - self.tracking_valid_IS_monitor.update(percent_inv, weight=n) - else: - self.tracking_valid_loss_monitor.update( - self.tracking_valid_loss_monitor.average_per_epoch[-1]) - self.tracking_valid_IS_monitor.update( - self.tracking_valid_IS_monitor.average_per_epoch[-1]) - - return mean_loss, n - - def _update_comet_after_epoch(self, context: str, epoch: int, - tracking_phase=False): - if tracking_phase: - loss = self.tracking_valid_loss_monitor.average_per_epoch[-1] - logger.info(" Mean tracking loss for this epoch: {}".format(loss)) - - percent_inv = self.tracking_valid_IS_monitor.average_per_epoch[-1] - logger.info(" Mean simili-IS ratio for this epoch: {}" - " (threshold {})".format(percent_inv, IS_THRESHOLD)) - - if self.comet_exp: - comet_context = self.comet_exp.validate - with comet_context(): - self.comet_exp.log_metric( - "generation_loss_per_epoch", loss, step=epoch) - self.comet_exp.log_metric( - "generation_IS_ratio_per_epoch", percent_inv, step=epoch) - - else: - super()._update_comet_after_epoch(context, epoch) - - def generate_from_one_batch(self, data): - # Data interpolation has not been done yet. GPU computations are done - # here in the main thread. - torch.set_printoptions(precision=4) - np.set_printoptions(precision=4) - - lines, ids_per_subj = data - lines = [line.to(self.device, non_blocking=True, dtype=torch.float) - for line in lines] - last_pos = torch.vstack([line[-1, :] for line in lines]) - mean_length = np.mean([len(s) for s in lines]) - - # Dataloader always works on CPU. Sending to right device. - # (model is already moved). Using only the n first points - lines = [s[0:min(len(s), self.tracking_phase_nb_steps_init), :] - for s in lines] - lines = self.propagate_multiple_lines(lines, ids_per_subj) - - # Verify "loss", i.e. the differences in coordinates - computed_last_pos = torch.vstack([line[-1, :] for line in lines]) - compute_mean_length = np.mean([len(s) for s in lines]) - - logging.debug(" Average streamline length (nb pts) in this batch: {} \n" - " Average recovered streamline length: {}" - .format(mean_length, compute_mean_length)) - l2_loss = PairwiseDistance(p=2) - loss = l2_loss(computed_last_pos, last_pos) - - logging.info(" Best / Worst loss: {} / {}" - .format(torch.max(loss), torch.min(loss))) - - IS_ratio = torch.sum(loss > IS_THRESHOLD).cpu() / len(lines) * 100 - - return torch.mean(loss), len(lines), IS_ratio - - def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): - assert self.model.step_size is not None, \ - "We can't propagate compressed streamlines." - - def update_memory_after_removing_lines(_, __): - pass - - def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): - n_last_pos = [pos[None, :] for pos in n_last_pos] - batch_inputs = self.batch_loader.load_batch_inputs( - n_last_pos, ids_per_subj) - - if self.model.forward_uses_streamlines: - model_outputs = self.model(batch_inputs, n_last_pos) - else: - model_outputs = self.model(batch_inputs) - - next_dirs = self.model.get_tracking_directions( - model_outputs, algo='det', eos_stopping_thresh=0.5) - return next_dirs - - theta = 2 * np.pi # theta = 360 degrees - max_nbr_pts = int(200 / self.model.step_size) - return propagate_multiple_lines( - lines, update_memory_after_removing_lines, get_dirs_at_last_pos, - theta=theta, step_size=self.model.step_size, - verify_opposite_direction=False, mask=self.tracking_mask, - max_nbr_pts=max_nbr_pts, append_last_point=False, - normalize_directions=True) diff --git a/dwi_ml/training/projects/transformer_trainer.py b/dwi_ml/training/projects/transformer_trainer.py index 8ae99c0d..4bb0918e 100644 --- a/dwi_ml/training/projects/transformer_trainer.py +++ b/dwi_ml/training/projects/transformer_trainer.py @@ -2,69 +2,71 @@ import logging from typing import List +import h5py +import numpy as np import torch -from dwi_ml.models.projects.transforming_tractography import AbstractTransformerModel -from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput -from dwi_ml.training.trainers import DWIMLTrainerOneInput +from dwi_ml.tracking.projects.utils import prepare_tracking_mask +from dwi_ml.tracking.propagation import propagate_multiple_lines +from dwi_ml.training.with_generation.trainer import \ + DWIMLTrainerForTrackingOneInput -class TransformerTrainer(DWIMLTrainerOneInput): - def __init__(self, - model: AbstractTransformerModel, experiments_path: str, - experiment_name: str, - batch_sampler: DWIMLBatchIDSampler, - batch_loader: DWIMLBatchLoaderOneInput, - learning_rates: List = None, weight_decay: float = 0.01, - optimizer='Adam', max_epochs: int = 10, - max_batches_per_epoch_training: int = 1000, - max_batches_per_epoch_validation: int = 1000, - patience: int = None, patience_delta: float = 1e-6, - nb_cpu_processes: int = 0, use_gpu: bool = False, - comet_workspace: str = None, comet_project: str = None, - from_checkpoint: bool = False, log_level=logging.root.level): + +class TransformerTrainer(DWIMLTrainerForTrackingOneInput): + def __init__(self, **kwargs): """ See Super for parameter description. No additional parameters here. """ - super().__init__(model, experiments_path, experiment_name, - batch_sampler, batch_loader, - learning_rates, weight_decay, - optimizer, max_epochs, - max_batches_per_epoch_training, - max_batches_per_epoch_validation, - patience, patience_delta, nb_cpu_processes, use_gpu, - comet_workspace, comet_project, - from_checkpoint, log_level) + super().__init__(**kwargs) + + def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): + assert self.model.step_size is not None, \ + "We can't propagate compressed streamlines." + + # Getting the first inputs + tmp_lines = [line[:-1, :] for line in lines] + batch_inputs = self.batch_loader.load_batch_inputs(tmp_lines, ids_per_subj) + del tmp_lines + + def update_memory_after_removing_lines(can_continue: np.ndarray, __): + nonlocal batch_inputs + batch_inputs = [inp for i, inp in enumerate(batch_inputs) if + can_continue[i]] + + def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): + nonlocal batch_inputs + n_last_pos = [pos[None, :] for pos in n_last_pos] + latest_inputs = self.batch_loader.load_batch_inputs( + n_last_pos, ids_per_subj) + batch_inputs = [torch.vstack((first, last)) for first, last in + zip(batch_inputs, latest_inputs)] + + model_outputs = self.model(batch_inputs, _lines) + next_dirs = self.model.get_tracking_directions( + model_outputs, algo='det', eos_stopping_thresh=0.5) + return next_dirs - def run_model(self, batch_inputs, batch_streamlines): - dirs = self.model.format_directions(batch_streamlines) + theta = 2 * np.pi # theta = 360 degrees + max_nbr_pts = int(200 / self.model.step_size) - # Formatting the previous dirs for all points. - n_prev_dirs = self.model.format_previous_dirs(dirs, self.device) + final_lines = [] + for subj_idx, line_idx in ids_per_subj.items(): - # Not keeping the last point: only useful to get the last direction - # (last target), but won't be used as an input. - if n_prev_dirs is not None: - n_prev_dirs = [s[:-1] for s in n_prev_dirs] + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + subj_id = self.batch_loader.context_subset.subjects[subj_idx] + logging.debug("Loading subj {} ({})'s tracking mask." + .format(subj_idx, subj_id)) + tracking_mask, _ = prepare_tracking_mask( + hdf_handle, self.tracking_mask_group, subj_id=subj_id, + mask_interp='nearest') + tracking_mask.move_to(self.device) - try: - # Apply model. This calls our model's forward function - # (the hidden states are not used here, neither as input nor - # outputs. We need them only during tracking). - model_outputs, _ = self.model(batch_inputs, n_prev_dirs, - self.device) - except RuntimeError: - # Training RNNs with variable-length sequences on the GPU can - # cause memory fragmentation in the pytorch-managed cache, - # possibly leading to "random" OOM RuntimeError during - # training. Emptying the GPU cache seems to fix the problem for - # now. We don't do it every update because it can be time - # consuming. - torch.cuda.empty_cache() - model_outputs, _ = self.model(batch_inputs, n_prev_dirs, - self.device) + final_lines.extend(propagate_multiple_lines( + lines[line_idx], update_memory_after_removing_lines, + get_dirs_at_last_pos, theta=theta, + step_size=self.model.step_size, verify_opposite_direction=False, + mask=tracking_mask, max_nbr_pts=max_nbr_pts, + append_last_point=False, normalize_directions=True)) - # Returning the directions too, to be re-used in compute_loss - # later instead of computing them twice. - return model_outputs, dirs + return final_lines diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index e4f9c59e..a4010608 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -40,19 +40,15 @@ class DWIMLAbstractTrainer: """ This Trainer class's train_and_validate() method: - Creates DataLoaders from the data_loaders. Collate_fn will be the - loader.load_batch() method, and the dataset will be - sampler.source_data. - - Trains each epoch by using compute_batch_loss, which should be - implemented in each project's child class. + found in the batch loader, and the dataset will be found in the data + sampler. + - Trains each epoch by using the model's loss computation method. Comet is used to save training information, but some logs will also be saved locally in the saving_path. - NOTE: TRAINER USES STREAMLINES COORDINATES IN VOXEL SPACE, TO CORNER. + NOTE: TRAINER USES STREAMLINES COORDINATES IN VOXEL SPACE, CORNER ORIGIN. """ - # For now, this is ugly... But the option is there if you want. - save_logs_per_batch = False - def __init__(self, model: MainModelAbstract, experiments_path: str, experiment_name: str, batch_sampler: DWIMLBatchIDSampler, @@ -267,17 +263,31 @@ def __init__(self, # D. Monitors # grad_norm = The total norm (sqrt(sum(params**2))) of parameters # before gradient clipping, if any. - self.train_loss_monitor = BatchHistoryMonitor(weighted=True) - self.valid_loss_monitor = BatchHistoryMonitor(weighted=True) - self.grad_norm_monitor = BatchHistoryMonitor(weighted=False) - self.training_time_monitor = TimeMonitor() - self.validation_time_monitor = TimeMonitor() - if patience: - self.best_epoch_monitor = BestEpochMonitor(patience, patience_delta) - else: - # We won't use early stopping to stop the epoch, but we will use - # it as monitor of the best epochs. - self.best_epoch_monitor = BestEpochMonitor(patience=np.inf) + + # Training: only one monitor. + self.train_loss_monitor = BatchHistoryMonitor( + 'train_loss_monitor', weighted=True) + + # Validation: As many supervision losses as we want. + self.valid_local_loss_monitor = BatchHistoryMonitor( + 'valid_local_loss_monitor', weighted=True) + self.grad_norm_monitor = BatchHistoryMonitor( + 'grad_norm_monitor', weighted=False) + self.training_time_monitor = TimeMonitor('training_time_monitor') + self.validation_time_monitor = TimeMonitor('validation_time_monitor') + if not patience: + patience = np.inf + self.best_epoch_monitor = BestEpochMonitor( + 'best_epoch_monitor', patience, patience_delta) + self.monitors = [self.train_loss_monitor, + self.valid_local_loss_monitor, + self.grad_norm_monitor, self.training_time_monitor, + self.validation_time_monitor, self.best_epoch_monitor] + self.training_monitors = [self.train_loss_monitor, + self.grad_norm_monitor, + self.training_time_monitor] + self.validation_monitors = [self.valid_local_loss_monitor, + self.validation_time_monitor] # E. Comet Experiment # Values will be instantiated in train(). @@ -309,14 +319,16 @@ def __init__(self, @property def params_for_checkpoint(self): - # These are the parameters necessary to use _init_, together with - # instantiated classes (model, batch loader, batch sampler). - + """ + Returns the parameters necessary to initialize an identical Trainer. + However, the trainer's state could need to be updated (see checkpoint + management). + """ # Not saving experiment_path and experiment_name. Allowing user to # move the experiment on his computer between training sessions. - # Patience and patience delta will be taken from the best epoch monitor. - + # Patience is not saved here: we manage it separately to allow the + # user to increase the patience when running again. params = { 'learning_rates': self.learning_rates, 'weight_decay': self.weight_decay, @@ -333,9 +345,8 @@ def params_for_checkpoint(self): def save_params_to_json(self): """ - Utility method to save the parameters to a json file in the same - folder as the experiment. Suggestion, call this after instantiating - your trainer. + Save the trainer's parameters to a json file in the same folder as the + experiment. """ now = datetime.now() json_filename = os.path.join(self.saving_path, "parameters_{}.json" @@ -357,14 +368,12 @@ def save_params_to_json(self): def save_checkpoint(self): """ - Save an experiment checkpoint that can be resumed from. + Saves an experiment checkpoint, with parameters and states. """ logger.debug("Saving checkpoint...") - - # Make checkpoint directory checkpoint_dir = os.path.join(self.saving_path, "checkpoint") - # Backup old checkpoint before saving, and erase it afterwards + # Backup old checkpoint before saving, and erase it afterward. to_remove = None if os.path.exists(checkpoint_dir): to_remove = os.path.join(self.saving_path, "checkpoint_old") @@ -386,34 +395,34 @@ def save_checkpoint(self): shutil.rmtree(to_remove) def _prepare_checkpoint_info(self) -> dict: - # These are parameters that should be updated after instantiating cls. - + """ + To instantiate a Trainer, we need the initialization parameters + (self.params_for_checkpoint), and the states. This method returns + the dictionary of required states. + """ # Note. batch sampler's rng state and batch loader's are the same. current_states = { - # A. Rng value. + # Rng value. 'torch_rng_state': torch.random.get_rng_state(), 'torch_cuda_state': torch.cuda.get_rng_state() if self.use_gpu else None, 'sampler_np_rng_state': self.batch_sampler.np_rng.get_state(), 'loader_np_rng_state': self.batch_loader.np_rng.get_state(), - # B. Current epoch. + # Current epoch. 'current_epoch': self.current_epoch, - # C. Nb of batches per epoch. + # Nb of batches per epoch. 'nb_batches_train': self.nb_batches_train, 'nb_batches_valid': self.nb_batches_valid, - # D. Monitors - 'best_epoch_monitoring_state': self.best_epoch_monitor.get_state(), - 'train_loss_monitor_state': self.train_loss_monitor.get_state(), - 'valid_loss_monitor_state': self.valid_loss_monitor.get_state(), - 'grad_norm_monitor_state': self.grad_norm_monitor.get_state(), - 'training_time_monitor_state': self.training_time_monitor.get_state(), - 'validation_time_monitor_state': self.validation_time_monitor.get_state(), - # E. Comet Experiment + # Comet Experiment 'comet_key': self.comet_key, - # F. Optimizer + # Optimizer 'optimizer_state': self.optimizer.state_dict(), } + # Monitors + for monitor in self.monitors: + current_states[monitor.name + '_state'] = monitor.get_state() + # Additional params are the parameters necessary to load data, batch # samplers/loaders (see the example script dwiml_train_model.py). checkpoint_info = { @@ -433,16 +442,13 @@ def init_from_checkpoint( batch_loader: DWIMLAbstractBatchLoader, checkpoint_state: dict, new_patience, new_max_epochs, log_level): """ - During save_checkpoint(), checkpoint_state.pkl is saved. Loading it - back offers a dict that can be used to instantiate an experiment and - set it at the same state as previously. (Current_epoch is updated +1). - - Hint: If you want to use this in your child class, use: - experiment, checkpoint_state = super(cls, cls).init_from_checkpoint(... + Loads checkpoint information (parameters and states) to instantiate + a Trainer. Current_epoch is updated +1. """ trainer_params = checkpoint_state['params_for_init'] trainer = cls(model=model, experiments_path=experiments_path, - experiment_name=experiment_name, batch_sampler=batch_sampler, + experiment_name=experiment_name, + batch_sampler=batch_sampler, batch_loader=batch_loader, from_checkpoint=True, log_level=log_level, **trainer_params) @@ -450,6 +456,9 @@ def init_from_checkpoint( trainer.max_epochs = new_max_epochs # Save params to json to help user remember. + current_states = checkpoint_state['current_states'] + trainer._update_states_from_checkpoint(current_states) + if new_patience: trainer.best_epoch_monitor.patience = new_patience logger.info("Starting from checkpoint! Starting from epoch #{}.\n" @@ -460,11 +469,24 @@ def init_from_checkpoint( trainer.best_epoch_monitor.best_epoch, trainer.best_epoch_monitor.n_bad_epochs)) - current_states = checkpoint_state['current_states'] - trainer._update_states_from_checkpoint(current_states) return trainer + @staticmethod + def load_params_from_checkpoint(experiments_path: str, experiment_name: str): + total_path = os.path.join( + experiments_path, experiment_name, "checkpoint", + "checkpoint_state.pkl") + if not os.path.isfile(total_path): + raise FileNotFoundError('Checkpoint was not found! ({})' + .format(total_path)) + checkpoint_state = torch.load(total_path) + + return checkpoint_state + def _update_states_from_checkpoint(self, current_states): + """ + Updates all states from the checkpoint dictionary of states. + """ # A. Rng value. # RNG: # - numpy @@ -483,24 +505,21 @@ def _update_states_from_checkpoint(self, current_states): self.nb_batches_train = current_states['nb_batches_train'] self.nb_batches_valid = current_states['nb_batches_valid'] - # D. Monitors - self.best_epoch_monitor.set_state(current_states['best_epoch_monitoring_state']) - self.train_loss_monitor.set_state(current_states['train_loss_monitor_state']) - self.valid_loss_monitor.set_state(current_states['valid_loss_monitor_state']) - self.grad_norm_monitor.set_state(current_states['grad_norm_monitor_state']) - self.training_time_monitor.set_state(current_states['training_time_monitor_state']) - self.validation_time_monitor.set_state(current_states['validation_time_monitor_state']) - - # E. Comet Experiment + # D. Comet Experiment # Experiment will be instantiated in train(). self.comet_key = current_states['comet_key'] - # F. Optimizer + # E. Optimizer self.optimizer.load_state_dict(current_states['optimizer_state']) + # F. Monitors + for monitor in self.monitors: + monitor.set_state(current_states[monitor.name + '_state']) + def _init_comet(self): """ - For more information on comet, see our doc/Getting Started + Initialize comet's experiment. User's account and workspace must be + already set. """ try: if self.comet_key: @@ -543,11 +562,10 @@ def _init_comet(self): def estimate_nb_batches_per_epoch(self): """ - Please override in your child class if you have a better way to - define the epochs sizes. + Counts the number of training / validation batches required to see all + the data (up to the maximum number of allowed batches). - Returns: - (nb_training_batches_per_epoch, nb_validation_batches_per_epoch) + Data must be already loaded to access the information. """ streamline_group = self.batch_sampler.streamline_group_idx train_set = self.batch_sampler.dataset.training_set @@ -578,16 +596,16 @@ def estimate_nb_batches_per_epoch(self): def train_and_validate(self): """ - Train + validates the model (+ computes loss) + Trains + validates the model. Computes the training loss at each + training loop, and many validation metrics at each validation loop. - Starts comet, - Creates DataLoaders from the BatchSamplers, - For each epoch - uses _train_one_epoch and _validate_one_epoch, - - checks for earlyStopping if the loss is bad, + - saves a checkpoint, + - checks for earlyStopping if the loss is bad or patience is reached, - saves the model if the loss is good. - - Checks if allowed training time is exceeded. - """ logger.debug("Trainer {}: \nRunning the model {}.\n\n" .format(type(self), type(self.model))) @@ -616,7 +634,7 @@ def train_and_validate(self): if self.comet_exp: self.comet_exp.set_epoch(epoch) - logger.info("******* STARTING : Epoch {} (i.e. #{}) *******" + logger.info("\n\n******* STARTING : Epoch {} (i.e. #{}) *******" .format(epoch, epoch + 1)) # Set learning rate to either current value or last value @@ -670,26 +688,38 @@ def train_and_validate(self): break def _get_latest_loss_to_supervise_best(self): + """ + Defines the metric to be used to define the best model. Override if + you have other validation metrics. + """ if self.use_validation: - mean_epoch_loss = self.valid_loss_monitor.average_per_epoch[-1] + mean_epoch_loss = self.valid_local_loss_monitor.average_per_epoch[-1] else: mean_epoch_loss = self.train_loss_monitor.average_per_epoch[-1] return mean_epoch_loss def save_local_logs(self): - self._save_log_locally(self.train_loss_monitor.average_per_epoch, - "training_loss_per_epoch.npy") - self._save_log_locally(self.valid_loss_monitor.average_per_epoch, - "validation_loss_per_epoch.npy") - self._save_log_locally(self.grad_norm_monitor.average_per_epoch, - "gradient_norm.npy") - self._save_log_locally(self.training_time_monitor.epoch_durations, - "training_epochs_duration") - self._save_log_locally(self.validation_time_monitor.epoch_durations, - "validation_epochs_duration") + """ + Save logs locally as numpy arrays. + """ + for monitor in self.monitors: + if isinstance(monitor, BatchHistoryMonitor): + self._save_log_locally(monitor.average_per_epoch, + monitor.name + '_per_epoch.npy') + elif isinstance(monitor, TimeMonitor): + self._save_log_locally(monitor.epoch_durations, + monitor.name + '_duration.npy') + + def _save_log_locally(self, array: np.ndarray, fname: str): + np.save(os.path.join(self.log_dir, fname), array) def _clear_handles(self): + """ + Trying to improve the handles management. + Todo. Improve again. CPU multiprocessing fails because of handles + management. + """ # Make sure there are no existing HDF handles if using parallel workers if (self.nb_cpu_processes > 0 and self.batch_sampler.context_subset.is_lazy): @@ -700,11 +730,15 @@ def back_propagation(self, loss): logger.debug('*** Computing back propagation') loss.backward() - self.fix_parameters() # Ex: clip gradients + # Any other steps. Ex: clip gradients. Not implemented here. + # See Learn2track's Trainer for an example. + self.fix_parameters() + + # Supervizing the gradient's norm. grad_norm = compute_gradient_norm(self.model.parameters()) # Update parameters - # toDo. We could update only every n steps. + # Future work: We could update only every n steps. # Effective batch size is n time bigger. # See here https://towardsdatascience.com/optimize-pytorch-performance-for-speed-and-memory-efficiency-2022-84f453916ea6 self.optimizer.step() @@ -717,11 +751,11 @@ def back_propagation(self, loss): def train_one_epoch(self, epoch): """ - Train one epoch of the model: loop on all batches (forward + backward). + Trains one epoch of the model: loops on all batches + (forward + backpropagation). """ - self.training_time_monitor.start_new_epoch() - self.train_loss_monitor.start_new_epoch() - self.grad_norm_monitor.start_new_epoch() + for monitor in self.training_monitors: + monitor.start_new_epoch() # Setting contexts self.batch_loader.set_context('training') @@ -754,13 +788,9 @@ def train_one_epoch(self, epoch): # Enable gradients for backpropagation. Uses torch's module # train(), which "turns on" the training mode. with grad_context(): - mean_loss, n = self.run_one_batch(data) + mean_loss = self.train_one_batch(data, epoch) grad_norm = self.back_propagation(mean_loss) - - # Update information and logs - mean_loss = mean_loss.cpu().item() - self.train_loss_monitor.update(mean_loss, weight=n) self.grad_norm_monitor.update(grad_norm) # Explicitly delete iterator to kill threads and free memory before @@ -768,9 +798,8 @@ def train_one_epoch(self, epoch): del train_iterator # Saving epoch's information - self.train_loss_monitor.end_epoch() - self.grad_norm_monitor.end_epoch() - self.training_time_monitor.end_epoch() + for monitor in self.training_monitors: + monitor.end_epoch() self._update_comet_after_epoch('training', epoch) all_n = self.train_loss_monitor.current_epoch_batch_weights @@ -779,10 +808,10 @@ def train_one_epoch(self, epoch): def validate_one_epoch(self, epoch): """ - Validate one epoch of the model: loop on all batches. + Validates one epoch of the model: loops on all batches. """ - self.validation_time_monitor.start_new_epoch() - self.valid_loss_monitor.start_new_epoch() + for monitor in self.validation_monitors: + monitor.start_new_epoch() # Setting contexts # Turn gradients off (no back-propagation) @@ -791,7 +820,6 @@ def validate_one_epoch(self, epoch): self.batch_sampler.set_context('validation') self.model.set_context('validation') self.model.eval() - grad_context = torch.no_grad # Make sure there are no existing HDF handles if using parallel workers if (self.nb_cpu_processes > 0 and @@ -813,49 +841,66 @@ def validate_one_epoch(self, epoch): break # Validate this batch: forward propagation + loss - with grad_context(): - mean_loss, n = self.validate_one_batch(data, epoch) - - mean_loss = mean_loss.cpu().item() - - self.valid_loss_monitor.update(mean_loss, weight=n) + with torch.no_grad(): + self.validate_one_batch(data, epoch) # Explicitly delete iterator to kill threads and free memory before # running training again del valid_iterator # Save info - self.valid_loss_monitor.end_epoch() - self.validation_time_monitor.end_epoch() + for monitor in self.validation_monitors: + monitor.end_epoch() self._update_comet_after_epoch('validation', epoch) - def validate_one_batch(self, data, epoch): + def train_one_batch(self, data, epoch): + """ + Computes the loss for the current batch and updates monitors. + Returns the loss to be used for backpropagation. + """ # Encapsulated for easier management of child classes. - mean_loss, n = self.run_one_batch(data) - return mean_loss, n + mean_local_loss, n = self.run_one_batch(data) + self.train_loss_monitor.update(mean_local_loss.cpu().item(), + weight=n) + return mean_local_loss - def _update_comet_after_epoch(self, context: str, epoch: int): + def validate_one_batch(self, data, epoch): + """ + Computes the loss(es) for the current batch and updates monitors. """ - Update logs: - - logging to user - - get values from monitors and save final log locally. - - send mean data to comet + mean_local_loss, n = self.run_one_batch(data) + self.valid_local_loss_monitor.update(mean_local_loss.cpu().item(), + weight=n) - local_context: prefix when saving log. Training_ or Validate_ for - instance. + def _update_comet_after_epoch(self, context: str, epoch: int): + """ + Sends monitors information to comet. """ if context == 'training': - loss = self.train_loss_monitor.average_per_epoch[-1] + monitors = self.training_monitors elif context == 'validation': - loss = self.valid_loss_monitor.average_per_epoch[-1] + monitors = self.validation_monitors else: - raise ValueError("Unexpected context.") - logger.info(" Mean loss for this epoch: {}".format(loss)) + raise ValueError("Unexpected context ({}). Expecting " + "training or validation.") + + logs = [] + for monitor in monitors: + if isinstance(monitor, BatchHistoryMonitor): + value = monitor.average_per_epoch[-1] + elif isinstance(monitor, TimeMonitor): + value = monitor.epoch_durations[-1] + else: + continue + logger.info(" Mean {} for this epoch: {}" + .format(monitor.name, value)) + logs.append((value, monitor.name)) if self.comet_exp: + # Comet context: will add train_(loss) or valid_(loss) to the + # monitors name in comet. if context == 'training': comet_context = self.comet_exp.train - self._update_gradnorm_logs_after_epoch(comet_context, epoch) else: # context == 'validation': comet_context = self.comet_exp.validate @@ -865,18 +910,15 @@ def _update_comet_after_epoch(self, context: str, epoch: int): # Cheating. To have a correct plotting per epoch (no step) # using step = epoch. In comet_ml, it is intended to be # step = batch. - self.comet_exp.log_metric("loss_per_epoch", loss, epoch=0, - step=epoch) - - def _update_gradnorm_logs_after_epoch(self, comet_context, epoch: int): - if self.comet_exp: - with comet_context(): - self.comet_exp.log_metric( - "mean_gradient_norm_per_epoch", - self.grad_norm_monitor.average_per_epoch[epoch], - epoch=None, step=epoch) + for log in logs: + self.comet_exp.log_metric( + log[1], log[0], epoch=0, step=epoch) def _save_best_model(self): + """ + Saves the current state of the model in the best_model folder. + Saves the loss to a json folder. + """ logger.info(" Best epoch yet! Saving model and loss history.") # Save model @@ -897,25 +939,18 @@ def _save_best_model(self): def run_one_batch(self, data): """ - Run a batch of data through the model (calling its forward method) - and return the mean loss. If training, run the backward method too. + Runs a batch of data through the model (calling its forward method) + and returns the mean loss. Parameters ---------- data : tuple of (List[StatefulTractogram], dict) - This is the output of the AbstractBatchLoader's - load_batch_streamlines() method. data is a tuple + Output of the batch loader's collate_fn. + With our basic BatchLoader class, data is a tuple - batch_sfts: one sft per subject - final_streamline_ids_per_subj: the dict of streamlines ids from the list of all streamlines (if we concatenate all sfts' streamlines) - - Returns - ------- - mean_loss : float - The mean loss of the provided batch. - n: int - Total number of points for this batch. """ raise NotImplementedError @@ -930,31 +965,19 @@ def fix_parameters(self): """ pass - - def _save_log_locally(self, array: np.ndarray, fname: str): - np.save(os.path.join(self.log_dir, fname), array) - - @staticmethod - def load_params_from_checkpoint(experiments_path: str, experiment_name: str): - total_path = os.path.join( - experiments_path, experiment_name, "checkpoint", - "checkpoint_state.pkl") - if not os.path.isfile(total_path): - raise FileNotFoundError('Checkpoint was not found! ({})' - .format(total_path)) - checkpoint_state = torch.load(total_path) - - return checkpoint_state - @staticmethod def check_stopping_cause(checkpoint_state, new_patience=None, new_max_epochs=None): - + """ + This method should be used before starting the training. Verifies that + it makes sense to continue training based on number of epochs and + patience. + """ current_epoch = checkpoint_state['current_states']['current_epoch'] # 1. Check if early stopping had been triggered. best_monitor_state = \ - checkpoint_state['current_states']['best_epoch_monitoring_state'] + checkpoint_state['current_states']['best_epoch_monitor_state'] bad_epochs = best_monitor_state['n_bad_epochs'] if new_patience is None: # No new patience: checking if early stopping had been triggered. @@ -994,7 +1017,7 @@ def check_stopping_cause(checkpoint_state, new_patience=None, class DWIMLTrainerOneInput(DWIMLAbstractTrainer): batch_loader: DWIMLBatchLoaderOneInput - def run_one_batch(self, data): + def run_one_batch(self, data, average_results=True): """ Run a batch of data through the model (calling its forward method) and return the mean loss. If training, run the backward method too. @@ -1008,7 +1031,9 @@ def run_one_batch(self, data): - batch_sfts: one sft per subject - final_streamline_ids_per_subj: the dict of streamlines ids from the list of all streamlines (if we concatenate all sfts' - streamlines) + streamlines). + average_results: bool + If true, returns the averaged loss (as defined by the model). Returns ------- @@ -1057,12 +1082,14 @@ def run_one_batch(self, data): logger.debug('*** Computing loss') if self.model.loss_uses_streamlines: - mean_loss, n = self.model.compute_loss(model_outputs, targets) + results = self.model.compute_loss(model_outputs, targets, + average_results=average_results) else: - mean_loss, n = self.model.compute_loss(model_outputs) + results = self.model.compute_loss(model_outputs, + average_results=average_results) if self.use_gpu: log_gpu_memory_usage(logger) # The mean tensor is a single value. Converting to float using item(). - return mean_loss, n + return results diff --git a/dwi_ml/training/utils/batch_loaders.py b/dwi_ml/training/utils/batch_loaders.py index 3f0d8e7d..4a9652b3 100644 --- a/dwi_ml/training/utils/batch_loaders.py +++ b/dwi_ml/training/utils/batch_loaders.py @@ -4,7 +4,8 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput +from dwi_ml.training.with_generation.batch_loader import \ + DWIMLBatchLoaderWithConnectivity def add_args_batch_loader(p: argparse.ArgumentParser): @@ -37,7 +38,7 @@ def add_args_batch_loader(p: argparse.ArgumentParser): def prepare_batch_loader(dataset, model, args, sub_loggers_level): # Preparing the batch loader. with Timer("\nPreparing batch loader...", newline=True, color='pink'): - batch_loader = DWIMLBatchLoaderOneInput( + batch_loader = DWIMLBatchLoaderWithConnectivity( dataset=dataset, model=model, input_group_name=args.input_group_name, streamline_group_name=args.streamline_group_name, diff --git a/dwi_ml/training/utils/monitoring.py b/dwi_ml/training/utils/monitoring.py index 44c6cf2a..89fd9eac 100644 --- a/dwi_ml/training/utils/monitoring.py +++ b/dwi_ml/training/utils/monitoring.py @@ -8,7 +8,8 @@ class TimeMonitor(object): - def __init__(self): + def __init__(self, name): + self.name = name self.epoch_durations = [] self._start_time = None @@ -49,7 +50,8 @@ class BatchHistoryMonitor(object): loss_monitor.epochs_means # returns the loss curve as a list """ - def __init__(self, weighted: bool = False): + def __init__(self, name, weighted: bool = False): + self.name = name self.is_weighted = weighted # State: @@ -58,7 +60,7 @@ def __init__(self, weighted: bool = False): self.average_per_epoch = [] self.current_epoch = -1 - def update(self, value, weight=None): + def update(self, value, weight=1): """ Note. Does not save the update if value is inf. @@ -127,7 +129,7 @@ class BestEpochMonitor(object): number of epochs ("patience"). """ - def __init__(self, patience: int, patience_delta: float = 1e-6): + def __init__(self, name, patience: int, patience_delta: float = 1e-6): """ Parameters ----------- @@ -137,6 +139,7 @@ def __init__(self, patience: int, patience_delta: float = 1e-6): Precision term to define what we consider as "improving": when the loss is at least min_eps smaller than the previous best loss. """ + self.name = name self.patience = patience self.min_eps = patience_delta diff --git a/dwi_ml/training/utils/trainer.py b/dwi_ml/training/utils/trainer.py index 1eb30b67..78ed94a0 100644 --- a/dwi_ml/training/utils/trainer.py +++ b/dwi_ml/training/utils/trainer.py @@ -54,7 +54,13 @@ def add_training_args(p: argparse.ArgumentParser, training_group.add_argument( '--tracking_phase_frequency', type=int, default=5) training_group.add_argument( - '--tracking_mask') + '--tracking_mask', + help="Volume group to use as tracking mask during the generation " + "phase.") + training_group.add_argument( + '--tracking_phase_nb_steps_init', type=int, default=5, + help="Number of segments copied from the 'real' streamlines " + "before starting propagation during generation phases.") comet_g = p.add_argument_group("Comet") comet_g.add_argument( diff --git a/dwi_ml/training/with_generation/__init__.py b/dwi_ml/training/with_generation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dwi_ml/training/with_generation/batch_loader.py b/dwi_ml/training/with_generation/batch_loader.py new file mode 100644 index 00000000..b1e69b5a --- /dev/null +++ b/dwi_ml/training/with_generation/batch_loader.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +from typing import List, Dict + +import torch + +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput + + +class DWIMLBatchLoaderWithConnectivity(DWIMLBatchLoaderOneInput): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data_contains_connectivity = \ + self.dataset.streamlines_contain_connectivity[self.streamline_group_idx] + + def load_batch_connectivity_matrices( + self, streamline_ids_per_subj: Dict[int, slice]): + if not self.data_contains_connectivity: + raise ValueError("No connectivity matrix in this dataset.") + + # The batch's streamline ids will change throughout processing because + # of data augmentation, so we need to do it subject by subject to + # keep track of the streamline ids. These final ids will correspond to + # the loaded, processed streamlines, not to the ids in the hdf5 file. + subjs = list(streamline_ids_per_subj.keys()) + nb_subjs = len(subjs) + matrices = [None] * nb_subjs + volume_sizes = [None] * nb_subjs + downsampled_sizes = [None] * nb_subjs + for i, subj in enumerate(subjs): + # No cache for the sft data. Accessing it directly. + # Note: If this is used through the dataloader, multiprocessing + # is used. Each process will open a handle. + subj_data = \ + self.context_subset.subjs_data_list.get_subj_with_handle(subj) + subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx] + + # We could access it only at required index, maybe. Loading the + # whole matrix here. + matrices[i], volume_sizes[i], downsampled_sizes[i] = \ + subj_sft_data.connectivity_matrix_and_info() + + return matrices, volume_sizes, downsampled_sizes diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py new file mode 100644 index 00000000..7f3d5e80 --- /dev/null +++ b/dwi_ml/training/with_generation/trainer.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +""" +Adds a tracking step to verify the generation process. Metrics on the +streamlines are: + +- Very good / acceptable / very far IS threshold: + Percentage of streamlines ending inside a radius of 15 / 25 / 40 voxels of + the expected endpoint. This metric has the drawback that streamlines + following a correct path different from the "true" validation streamline + contribute negatively to the metric. +- 'diverg': + The point where the streamline becomes significantly far (i.e. > 25 voxels) + from the "true" path. Values range between 100 (100% bad, i.e. diverging + from the start) to 0 (0% bad; ended correclty). If the generated streamline + is longer than the "true" one, values range between 0 (0% bad) and infinit + (ex: 100% = went 100% too far before becoming far from the expected point. + I.e. the generated streamline is at least twice as long as expected). Same + drawback as above. +- Mean distance from expected endpoint: + In voxel space. Same drawback as above. Also, a single bad streamline may + contribute intensively to the score. +- Idem, clipped. + Distances are clipped at 25. We consider that bad streamlines are bad, no + matter if they end up near or far. +- Connectivity fit: + Percentage of streamlines ending in a block of the volume indeed connected + in the validation subject. Real connectivity matrices must be saved in the + hdf5. Right now, volumes are simply downsampled (the same way as in the + hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on + anatomical ROIs. It has the advantage that it does not rely on the quality + of segmentation. It had the drawback that a generated streamline ending + very close to the "true" streamline, but in another block, if the + expected endpoint is close to the border of the block, contributes + negatively to the metric. It does not however have the drawback of other + metrics stated before. +""" +import logging +from typing import List + +import h5py +import numpy as np +import torch +from torch.nn import PairwiseDistance + +from dwi_ml.data.processing.streamlines.post_processing import \ + compute_triu_connectivity +from dwi_ml.models.main_models import ModelWithDirectionGetter +from dwi_ml.tracking.propagation import propagate_multiple_lines +from dwi_ml.tracking.projects.utils import prepare_tracking_mask +from dwi_ml.training.trainers import DWIMLTrainerOneInput +from dwi_ml.training.utils.monitoring import BatchHistoryMonitor +from dwi_ml.training.with_generation.batch_loader import \ + DWIMLBatchLoaderWithConnectivity + +logger = logging.getLogger('train_logger') + +# Emma tests in ISMRM: sphere of 50 mm of diameter (in MI-Brain, imagining a +# sphere encapsulated in a cube box of 50x50x50) around any point seems to +# englobe mostly acceptable streamlines. +# So half of it, a ray of 25 mm seems ok. +VERY_CLOSE_THRESHOLD = 15.0 +ACCEPTABLE_THRESHOLD = 25.0 +VERY_FAR_THRESHOLD = 40.0 + + +class DWIMLTrainerForTrackingOneInput(DWIMLTrainerOneInput): + model: ModelWithDirectionGetter + batch_loader: DWIMLBatchLoaderWithConnectivity + + def __init__(self, add_a_tracking_validation_phase: bool = False, + tracking_phase_frequency: int = 1, + tracking_phase_nb_steps_init: int = 5, + tracking_phase_mask_group: str = None, *args, **kw): + """ + Parameters + ---------- + add_a_tracking_validation_phase: bool + If true, the validation phase is extended with a generation (i.e. + tracking) step: the first N points of the validation streamlines + are kept as is, and streamlines are propagated through tractography + until they get out of the mask, or until the EOS criteria is + reached (if any) (threshold = 0.5). + In current implementation, the metric defining the best model is + the connectivity metric. + tracking_phase_frequency: int + There is the possibility to compute this additional step only every + X epochs. + tracking_phase_nb_steps_init: int + Number of initial points to keep in the validation step. Adding + enough should ensure that the generated streamlines go in the same + direction as the "true" validation streamline to generate good + metrics. + tracking_phase_mask_group: str + Name of the volume group to use as tracking mask. + """ + super().__init__(*args, **kw) + + self.add_a_tracking_validation_phase = add_a_tracking_validation_phase + self.tracking_phase_frequency = tracking_phase_frequency + self.tracking_phase_nb_steps_init = tracking_phase_nb_steps_init + self.tracking_mask_group = tracking_phase_mask_group + + self.compute_connectivity = self.batch_loader.data_contains_connectivity + + # -------- Monitors + # At training time: only the one metric used for training. + # At validation time: A lot of exploratory metrics monitors. + + # Percentage of streamlines inside a radius + self.tracking_very_good_IS_monitor = BatchHistoryMonitor( + 'tracking_very_good_IS_monitor', weighted=True) + self.tracking_acceptable_IS_monitor = BatchHistoryMonitor( + 'tracking_acceptable_IS_monitor', weighted=True) + self.tracking_very_far_IS_monitor = BatchHistoryMonitor( + 'tracking_very_far_IS_monitor', weighted=True) + + # Point where the streamline starts diverging from "acceptable" + self.tracking_valid_diverg_monitor = BatchHistoryMonitor( + 'tracking_valid_diverg_monitor', weighted=True) + + # Final distance from expected point + self.tracking_mean_final_distance_monitor = BatchHistoryMonitor( + 'tracking_mean_final_distance_monitor', weighted=True) + self.tracking_clipped_final_distance_monitor = BatchHistoryMonitor( + 'tracking_clipped_final_distance_monitor', weighted=True) + + # Connectivity matrix accordance + self.tracking_connectivity_score_monitor = BatchHistoryMonitor( + 'tracking_connectivity_score_monitor', weighted=True) + + if self.add_a_tracking_validation_phase: + new_monitors = [self.tracking_very_good_IS_monitor, + self.tracking_acceptable_IS_monitor, + self.tracking_very_far_IS_monitor, + self.tracking_valid_diverg_monitor, + self.tracking_mean_final_distance_monitor, + self.tracking_clipped_final_distance_monitor, + self.tracking_connectivity_score_monitor] + self.monitors += new_monitors + self.validation_monitors += new_monitors + + @property + def params_for_checkpoint(self): + p = super().params_for_checkpoint + p.update({ + 'add_a_tracking_validation_phase': self.add_a_tracking_validation_phase, + 'tracking_phase_frequency': self.tracking_phase_frequency, + 'tracking_phase_nb_steps_init': self.tracking_phase_nb_steps_init, + 'tracking_phase_mask_group': self.tracking_mask_group, + }) + + return p + + def _get_latest_loss_to_supervise_best(self): + """Using the connectivity score, if available.""" + if (self.use_validation and self.add_a_tracking_validation_phase and + self.compute_connectivity): + # Choosing connectivity. + mean_epoch_loss = \ + self.tracking_connectivity_score_monitor.average_per_epoch[-1] + return mean_epoch_loss + else: + return super()._get_latest_loss_to_supervise_best() + + def validate_one_batch(self, data, epoch): + # 1. Compute the local loss as usual. + super().validate_one_batch(data, epoch) + + # 2. Compute generation losses. + if self.add_a_tracking_validation_phase: + if (epoch + 1) % self.tracking_phase_frequency == 0: + logger.debug("Additional tracking-like generation validation " + "from batch.") + (gen_n, mean_final_dist, mean_clipped_final_dist, + percent_IS_very_good, percent_IS_acceptable, percent_IS_very_far, + diverging_pnt, connectivity) = self.validation_generation_one_batch( + data, compute_all_scores=True) + + self.tracking_very_good_IS_monitor.update( + percent_IS_very_good, weight=gen_n) + self.tracking_acceptable_IS_monitor.update( + percent_IS_acceptable, weight=gen_n) + self.tracking_very_far_IS_monitor.update( + percent_IS_very_far, weight=gen_n) + + self.tracking_mean_final_distance_monitor.update( + mean_final_dist, weight=gen_n) + self.tracking_clipped_final_distance_monitor.update( + mean_clipped_final_dist, weight=gen_n) + self.tracking_valid_diverg_monitor.update( + diverging_pnt, weight=gen_n) + + self.tracking_connectivity_score_monitor.update( + connectivity, weight=gen_n) + elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0: + logger.info("Skipping tracking-like generation validation from " + "batch. No values yet: adding fake initial values.") + # Fake values at the beginning + # Bad IS = 100% + self.tracking_very_good_IS_monitor.update(100.0) + self.tracking_acceptable_IS_monitor.update(100.0) + self.tracking_very_far_IS_monitor.update(100.0) + + # Bad diverging = very far from 0. Either 100% (if diverged at + # first point) or anything >0 if diverged further than expected + # point. + self.tracking_valid_diverg_monitor.update(100.0) + + # Bad mean dist = very far. ex, 100, or clipped. + self.tracking_mean_final_distance_monitor.update(100.0) + self.tracking_clipped_final_distance_monitor.update( + ACCEPTABLE_THRESHOLD) + + self.tracking_connectivity_score_monitor.update(1) + else: + logger.info("Skipping tracking-like generation validation from " + "batch. Copying previous epoch's values.") + # Copy previous value + for monitor in [self.tracking_very_good_IS_monitor, + self.tracking_acceptable_IS_monitor, + self.tracking_very_far_IS_monitor, + self.tracking_valid_diverg_monitor, + self.tracking_mean_final_distance_monitor, + self.tracking_clipped_final_distance_monitor, + self.tracking_connectivity_score_monitor]: + monitor.update(monitor.average_per_epoch[-1]) + + def validation_generation_one_batch(self, data, compute_all_scores=False): + """ + Use tractography to generate streamlines starting from the "true" + seeds and first few segments. Expected results are the batch's + validation streamlines. + """ + real_lines, ids_per_subj = data + + # Possibly sending again to GPU even if done in the local loss + # computation, but easier with current implementation. + real_lines = [line.to(self.device, non_blocking=True, dtype=torch.float) + for line in real_lines] + last_pos = torch.vstack([line[-1, :] for line in real_lines]) + + # Starting from the n first points + lines = [s[0:min(len(s), self.tracking_phase_nb_steps_init), :] + for s in real_lines] + + # Propagation: no backward tracking. + previous_context = self.model.context + self.model.set_context('tracking') + lines = self.propagate_multiple_lines(lines, ids_per_subj) + self.model.set_context(previous_context) + + # 1. Final distance compared to expected point. + computed_last_pos = torch.vstack([line[-1, :] for line in lines]) + l2_loss = PairwiseDistance(p=2) + final_dist = l2_loss(computed_last_pos, last_pos) + + if not compute_all_scores: + return final_dist + else: + # 1. (bis) Also clipping final dist + final_dist_clipped = torch.clip(final_dist, min=None, + max=ACCEPTABLE_THRESHOLD) + final_dist_clipped = torch.mean(final_dist_clipped) + + # 2. Connectivity scores, if available (else None) + connectivity_score = self._compare_connectivity(lines, ids_per_subj) + + # 3. "IS ratio", i.e. percentage of streamlines ending inside a + # predefined radius. + invalid_ratio_severe = torch.sum( + final_dist > VERY_CLOSE_THRESHOLD) / len(lines) * 100 + invalid_ratio_acceptable = torch.sum( + final_dist > ACCEPTABLE_THRESHOLD) / len(lines) * 100 + invalid_ratio_loose = torch.sum( + final_dist > VERY_FAR_THRESHOLD) / len(lines) * 100 + final_dist = torch.mean(final_dist) + + # 4. Verify point where streamline starts diverging. + # abs(100 - score): 0 = good. 100 = bad (either abs(100) -> diverged + # at first point or abs(-100) = diverged after twice the expected + # length. + total_point = 0 + for line, real_line in zip(lines, real_lines): + expected_nb = len(real_line) + diff_nb = abs(len(real_line) - len(line)) + if len(line) < expected_nb: + diff_nb = len(real_line) - len(line) + line = torch.vstack((line, line[-1, :].repeat(diff_nb, 1))) + elif len(line) > expected_nb: + real_line = torch.vstack( + (real_line, real_line[-1, :].repeat(diff_nb, 1))) + dist = l2_loss(line, real_line).cpu().numpy() + point, = np.where(dist > ACCEPTABLE_THRESHOLD) + if len(point) > 0: # (else: score = 0. Never out of range). + div_point = point[0] / expected_nb * 100.0 + total_point += abs(100 - div_point) + diverging_point = total_point / len(lines) + + invalid_ratio_severe = invalid_ratio_severe.cpu().numpy().astype(np.float32) + invalid_ratio_acceptable = invalid_ratio_acceptable.cpu().numpy().astype(np.float32) + invalid_ratio_loose = invalid_ratio_loose.cpu().numpy().astype(np.float32) + final_dist = final_dist.cpu().numpy().astype(np.float32) + final_dist_clipped = final_dist_clipped.cpu().numpy().astype(np.float32) + diverging_point = np.asarray(diverging_point, dtype=np.float32) + return (len(lines), final_dist, final_dist_clipped, + invalid_ratio_severe, invalid_ratio_acceptable, + invalid_ratio_loose, diverging_point, + connectivity_score) + + def _compare_connectivity(self, lines, ids_per_subj): + """ + If available, computes connectivity matrices for the batch and + compares with expected values for the subject. + """ + if self.compute_connectivity: + connectivity_matrices, volume_sizes, downsampled_sizes = \ + self.batch_loader.load_batch_connectivity_matrices(ids_per_subj) + + score = 0.0 + for i, subj in enumerate(ids_per_subj.keys()): + real_matrix = connectivity_matrices[i] + volume_size = volume_sizes[i] + downsampled_size = downsampled_sizes[i] + _lines = lines[ids_per_subj[subj]] + + batch_matrix = compute_triu_connectivity( + _lines, volume_size, downsampled_size, + binary=False, to_sparse_tensor=False, device=self.device) + + # Where our batch has a 1, if there was really a one: score should + # be 0. Else, score should be 1. + # If two streamlines in a voxel, score is 0 or 2. + + # Real matrices are saved as binary in create_hdf5. + where_one = np.where(batch_matrix > 0) + score += np.sum(batch_matrix[where_one] * + (1.0 - real_matrix[where_one])) + + # Average for batch + score = score / len(lines) + else: + score = None + return score + + def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): + """ + Tractography propagation of 'lines'. + """ + assert self.model.step_size is not None, \ + "We can't propagate compressed streamlines." + + def update_memory_after_removing_lines(_, __): + pass + + def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): + n_last_pos = [pos[None, :] for pos in n_last_pos] + batch_inputs = self.batch_loader.load_batch_inputs( + n_last_pos, ids_per_subj) + + if self.model.forward_uses_streamlines: + model_outputs = self.model(batch_inputs, n_last_pos) + else: + model_outputs = self.model(batch_inputs) + + next_dirs = self.model.get_tracking_directions( + model_outputs, algo='det', eos_stopping_thresh=0.5) + return next_dirs + + theta = 2 * np.pi # theta = 360 degrees + max_nbr_pts = int(200 / self.model.step_size) + + # Looping on subjects because current implementation requires a single + # tracking mask. But all the rest (get_dirs_at_last_pos, particularly) + # work on multiple subjects because the batch loader loads input + # according to subject id. Could refactor "propagate_multiple_line" to + # accept multiple masks or manage it differently. + final_lines = [] + for subj_idx, line_idx in ids_per_subj.items(): + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + subj_id = self.batch_loader.context_subset.subjects[subj_idx] + logging.debug("Loading subj {} ({})'s tracking mask." + .format(subj_idx, subj_id)) + tracking_mask, _ = prepare_tracking_mask( + hdf_handle, self.tracking_mask_group, subj_id=subj_id, + mask_interp='nearest') + tracking_mask.move_to(self.device) + + final_lines.extend(propagate_multiple_lines( + lines[line_idx], update_memory_after_removing_lines, + get_dirs_at_last_pos, theta=theta, + step_size=self.model.step_size, verify_opposite_direction=False, + mask=tracking_mask, max_nbr_pts=max_nbr_pts, + append_last_point=False, normalize_directions=True)) + + return final_lines diff --git a/dwi_ml/unit_tests/test_connectivity_matrix.py b/dwi_ml/unit_tests/test_connectivity_matrix.py new file mode 100644 index 00000000..9fbac217 --- /dev/null +++ b/dwi_ml/unit_tests/test_connectivity_matrix.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import numpy as np + +from dwi_ml.data.processing.streamlines.post_processing import compute_triu_connectivity + + +def test_connectivity(): + # Ex: Volume is 16 x 16 + + # Streamline starting at the lowest left side to the highest right side. + streamline = [[0, 0], [15.9, 15.9]] + streamlines = [streamline, streamline] + + # to a 4x4 matrix, should have two values from "ROI" 0 to "ROI" 15. + expected_m = np.zeros((16, 16), dtype=int) + expected_m[0, 15] = 2 + # expected_m[15, 0] = 2 ----> but triu + print("Expected connectivity matrix: {}".format(expected_m)) + + m = compute_triu_connectivity(streamlines, (16, 16), (4, 4)) + print("Got {}".format(m)) + assert np.array_equal(m, expected_m) + + m = compute_triu_connectivity(streamlines, (16, 16), (4, 4), + to_sparse_tensor=True) + m2 = m.to_dense().numpy().astype(int) + print("Converting to sparse and back to dense: {}".format(m2)) + assert np.array_equal(m2, expected_m) diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index 28b2cad0..bdfac4ba 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -140,7 +140,7 @@ def compute_loss(self, model_outputs: List[torch.Tensor], def get_tracking_directions(self, regressed_dirs, algo): if algo == 'det': - return regressed_dirs.detach() + return regressed_dirs elif algo == 'prob': raise NotImplementedError( "Our test model uses (fake) regression and does not allow " diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 94759b27..1902548f 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -29,7 +29,7 @@ from dipy.io.stateful_tractogram import set_sft_logger_level from dwi_ml.data.hdf5.utils import ( - add_basic_args, add_mri_processing_args, add_streamline_processing_args, + add_hdf5_creation_args, add_mri_processing_args, add_streamline_processing_args, prepare_hdf5_creator) from dwi_ml.experiment_utils.timer import Timer @@ -38,7 +38,7 @@ def _parse_args(): p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - add_basic_args(p) + add_hdf5_creation_args(p) add_mri_processing_args(p) add_streamline_processing_args(p) add_overwrite_arg(p) @@ -74,6 +74,13 @@ def main(): "received {}".format(ext)) assert_outputs_exist(p, args, args.out_hdf5_file) + # Default value with arparser '+' not possible. Setting manually. + if args.compute_connectivity_matrix: + if args.connectivity_downsample_size is None: + args.connectivity_downsample_size = 20 + elif len(args.connectivity_downsample_size) == 1: + args.connectivity_downsample_size = args.connectivity_downsample_size[0] + # Prepare creator and load config file. creator = prepare_hdf5_creator(args) diff --git a/scripts_python/dwiml_divide_volume_into_blocs.py b/scripts_python/dwiml_divide_volume_into_blocs.py new file mode 100644 index 00000000..c5b3b472 --- /dev/null +++ b/scripts_python/dwiml_divide_volume_into_blocs.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import argparse + +import nibabel as nib +import numpy as np + +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist, \ + add_overwrite_arg + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_image', metavar='IN_FILE', + help='Input file name, in nifti format.') + + p.add_argument( + 'out', metavar='OUT_FILE', dest='out_filename', + help='name of the output file, which will be saved as a text file.') + + add_overwrite_arg(p) + + return p + + +def color_mri_connectivity_blocs(downsampled_volume_size, volume_size): + + # For tracking coordinates: we can work with float. + # Here, dividing as ints. + volume_size = np.asarray(volume_size) + downsampled_volume_size = np.asarray(downsampled_volume_size) + sizex, sizey, sizez = (volume_size / downsampled_volume_size).astype(int) + print("Coloring into blocs of size: ", sizex, sizey, sizez) + + final_volume = np.zeros(volume_size) + for i in range(downsampled_volume_size[0]): + for j in range(downsampled_volume_size[1]): + for k in range(downsampled_volume_size[2]): + final_volume[i*sizex: (i+1)*sizex, + j*sizey: (j+1)*sizey, + k*sizez: (k+1)*sizez] = i + 10*j + 100*k + + return final_volume + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + assert_inputs_exist(parser, args.in_image) + assert_outputs_exist(parser, args, required=args.out_filename) + + volume = nib.load(args.in_image) + final_volume = color_mri_connectivity_blocs([6, 6, 6], volume.shape) + img = nib.Nifti1Image(final_volume, volume.affine) + nib.save(img, args.out_filename) + + +if __name__ == '__main__': + main() diff --git a/scripts_python/l2t_train_from_pretrained.py b/scripts_python/l2t_train_from_pretrained.py new file mode 100644 index 00000000..5fe5bc16 --- /dev/null +++ b/scripts_python/l2t_train_from_pretrained.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Train a model for Learn2Track +""" +import argparse +import logging +import os + +# comet_ml not used, but comet_ml requires to be imported before torch. +# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 +# Importing now to solve issues later. +import comet_ml +import torch + +from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist + +from dwi_ml.data.dataset.utils import prepare_multisubjectdataset +from dwi_ml.experiment_utils.prints import format_dict_to_str +from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_logging_arg, add_memory_args +from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer +from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, + prepare_batch_sampler) +from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader, + prepare_batch_loader) +from dwi_ml.training.utils.experiment import ( + add_mandatory_args_training_experiment) +from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ + format_lr + + +def prepare_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + add_mandatory_args_training_experiment(p) + p.add_argument('pretrained_model', + help="Name of the pretrained experiment (from the same " + "experiments path) from which to load the model. " + "Should contain a 'best_model' folder with pickle " + "information to load the model") + add_args_batch_sampler(p) + add_args_batch_loader(p) + training_group = add_training_args(p, add_a_tracking_validation_phase=True) + add_memory_args(p, add_lazy_options=True, add_rng=True) + add_logging_arg(p) + + # Additional arg for projects + training_group.add_argument( + '--clip_grad', type=float, default=None, + help="Value to which the gradient norms to avoid exploding gradients." + "\nDefault = None (not clipping).") + + return p + + +def init_from_args(args, sub_loggers_level): + torch.manual_seed(args.rng) # Set torch seed + + # Prepare the dataset + dataset = prepare_multisubjectdataset(args, load_testing=False, + log_level=sub_loggers_level) + + # Loading an existing model + logging.info("Loading existing model") + best_model_path = os.path.join(args.experiments_path, + args.pretrained_model, 'best_model') + model = Learn2TrackModel.load_params_and_state( + best_model_path, sub_loggers_level) + + # Preparing the batch samplers + batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) + batch_loader = prepare_batch_loader(dataset, model, args, sub_loggers_level) + + # Instantiate trainer + with Timer("\n\nPreparing trainer", newline=True, color='red'): + lr = format_lr(args.learning_rate) + trainer = Learn2TrackTrainer( + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, + # COMET + comet_project=args.comet_project, + comet_workspace=args.comet_workspace, + # TRAINING + learning_rates=lr, weight_decay=args.weight_decay, + optimizer=args.optimizer, max_epochs=args.max_epochs, + max_batches_per_epoch_training=args.max_batches_per_epoch_training, + max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, + patience=args.patience, patience_delta=args.patience_delta, + from_checkpoint=False, clip_grad=args.clip_grad, + # (generation validation:) + add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, + tracking_phase_frequency=args.tracking_phase_frequency, + tracking_phase_nb_steps_init=args.tracking_phase_nb_steps_init, + tracking_phase_mask_group=args.tracking_mask, + # MEMORY + nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, + log_level=args.logging) + logging.info("Trainer params : " + + format_dict_to_str(trainer.params_for_checkpoint)) + + return trainer + + +def main(): + p = prepare_arg_parser() + args = p.parse_args() + + # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, + # but we will set trainer to user-defined level. + sub_loggers_level = args.logging + if args.logging == 'DEBUG': + sub_loggers_level = 'INFO' + + logging.getLogger().setLevel(level=logging.INFO) + + # Check that all files exist + assert_inputs_exist(p, [args.hdf5_file]) + assert_outputs_exist(p, args, args.experiments_path) + + # Verify if a checkpoint has been saved. Else create an experiment. + if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, + "checkpoint")): + raise FileExistsError("This experiment already exists. Delete or use " + "script l2t_resume_training_from_checkpoint.py.") + + trainer = init_from_args(args, sub_loggers_level) + + # Supervising that we loaded everything correctly. + print("Validation 0 = Initial verification: pre-trained results!") + trainer.validate_one_epoch(-1) + + print("Now starting training") + run_experiment(trainer) + + +if __name__ == '__main__': + main() diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index 38cc1215..184f28a6 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -125,7 +125,7 @@ def init_from_args(args, sub_loggers_level): # (generation validation:) add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, tracking_phase_frequency=args.tracking_phase_frequency, - tracking_phase_nb_steps_init=5, # args.tracking_phase_nb_steps_init + tracking_phase_nb_steps_init=args.tracking_phase_nb_steps_init, tracking_phase_mask_group=args.tracking_mask, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, diff --git a/scripts_python/tto_train_model.py b/scripts_python/tto_train_model.py index b3fdb266..96d1eee8 100755 --- a/scripts_python/tto_train_model.py +++ b/scripts_python/tto_train_model.py @@ -43,7 +43,7 @@ def prepare_arg_parser(): add_logging_arg(p) add_args_batch_sampler(p) add_args_batch_loader(p) - add_training_args(p) + add_training_args(p, add_a_tracking_validation_phase=True) # Specific to Transformers: gt = add_abstract_model_args(p) @@ -102,8 +102,9 @@ def init_from_args(args, sub_loggers_level): with Timer("\n\nPreparing trainer", newline=True, color='red'): lr = format_lr(args.learning_rate) trainer = TransformerTrainer( - model, args.experiments_path, args.experiment_name, - batch_sampler, batch_loader, + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, # COMET comet_project=args.comet_project, comet_workspace=args.comet_workspace, @@ -114,6 +115,11 @@ def init_from_args(args, sub_loggers_level): max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, patience=args.patience, patience_delta=args.patience_delta, from_checkpoint=False, + # (generation validation:) + add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, + tracking_phase_frequency=args.tracking_phase_frequency, + tracking_phase_nb_steps_init=args.tracking_phase_nb_steps_init, + tracking_phase_mask_group=args.tracking_mask, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, log_level=args.logging) diff --git a/scripts_python/ttst_train_model.py b/scripts_python/ttst_train_model.py index d917e837..8a8a9227 100755 --- a/scripts_python/ttst_train_model.py +++ b/scripts_python/ttst_train_model.py @@ -48,7 +48,7 @@ def prepare_arg_parser(): add_logging_arg(p) add_args_batch_sampler(p) add_args_batch_loader(p) - add_training_args(p) + add_training_args(p, add_a_tracking_validation_phase=True) # Specific to Transformers: gt = add_abstract_model_args(p) @@ -110,8 +110,9 @@ def init_from_args(args, sub_loggers_level): with Timer("\n\nPreparing trainer", newline=True, color='red'): lr = format_lr(args.learning_rate) trainer = TransformerTrainer( - model, args.experiments_path, args.experiment_name, - batch_sampler, batch_loader, + model=model, experiments_path=args.experiments_path, + experiment_name=args.experiment_name, batch_sampler=batch_sampler, + batch_loader=batch_loader, # COMET comet_project=args.comet_project, comet_workspace=args.comet_workspace, @@ -122,6 +123,11 @@ def init_from_args(args, sub_loggers_level): max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, patience=args.patience, patience_delta=args.patience_delta, from_checkpoint=False, + # (generation validation:) + add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, + tracking_phase_frequency=args.tracking_phase_frequency, + tracking_phase_nb_steps_init=args.tracking_phase_nb_steps_init, + tracking_phase_mask_group=args.tracking_mask, # MEMORY nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, log_level=args.logging)