From 550ae33905b08783cdf36fe66b9b8f0bbf26df2f Mon Sep 17 00:00:00 2001
From: arnaudbore
Date: Wed, 25 Sep 2024 12:09:22 -0400
Subject: [PATCH] add the possibility to resample with nb points per streamline

---
 dwi_ml/data/hdf5/hdf5_creation.py           | 43 +++++++++++--------
 dwi_ml/data/hdf5/utils.py                   | 13 +++---
 .../streamlines/data_augmentation.py        | 25 ++++++++---
 dwi_ml/io_utils.py                          | 11 +++--
 scripts_python/dwiml_create_hdf5_dataset.py |  5 ++-
 5 files changed, 64 insertions(+), 33 deletions(-)

diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index 74715f1e..82414c52 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,7 +136,10 @@ class HDF5Creator:
     def __init__(self, root_folder: Path, out_hdf_filename: Path,
                  training_subjs: List[str], validation_subjs: List[str],
                  testing_subjs: List[str], groups_config: dict,
-                 step_size: float = None, compress: float = None,
+                 step_size: float = None,
+                 nb_points: int = None,
+                 compress: float = None,
+                 remove_invalid: bool = False,
                  enforce_files_presence: bool = True,
                  save_intermediate: bool = False,
                  intermediate_folder: Path = None):
@@ -156,8 +159,12 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
             Information from json file loaded as a dict.
         step_size: float
             Step size to resample streamlines. Default: None.
+        nb_points: int
+            Number of points per streamline. Default: None.
         compress: float
             Compress streamlines. Default: None.
+        remove_invalid: bool
+            Remove invalid streamlines. Default: False.
         enforce_files_presence: bool
             If true, will stop if some files are not available for a
             subject. Default: True.
@@ -175,7 +182,9 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
         self.testing_subjs = testing_subjs
         self.groups_config = groups_config
         self.step_size = step_size
+        self.nb_points = nb_points
         self.compress = compress
+        self.remove_invalid = remove_invalid
 
         # Optional
         self.save_intermediate = save_intermediate
@@ -359,6 +368,8 @@ def create_database(self):
             hdf_handle.attrs['testing_subjs'] = self.testing_subjs
             hdf_handle.attrs['step_size'] = self.step_size if \
                 self.step_size is not None else 'Not defined by user'
+            hdf_handle.attrs['nb_points'] = self.nb_points if \
+                self.nb_points is not None else 'Not defined by user'
             hdf_handle.attrs['compress'] = self.compress if \
                 self.compress is not None else 'Not defined by user'
@@ -632,6 +643,8 @@ def _process_one_streamline_group(
             Reference used to load and send the streamlines in voxel space and
             to create final merged SFT. If the file is a .trk, 'same' is used
             instead.
+        remove_invalid : bool
+            If True, invalid streamlines will be removed.
 
         Returns
         -------
@@ -641,11 +654,10 @@ def _process_one_streamline_group(
             The Euclidean length of each streamline
         """
         tractograms = self.groups_config[group]['files']
-
         if self.step_size and self.compress:
             raise ValueError(
                 "Only one option can be chosen: either resampling to "
-                "step_size or compressing, not both.")
+                "step_size, resampling to nb_points, or compressing.")
 
         # Silencing SFT's logger if our logging is in DEBUG mode, because it
         # typically produces a lot of outputs!
@@ -679,19 +691,12 @@ def _process_one_streamline_group(
         if self.save_intermediate:
             output_fname = self.intermediate_folder.joinpath(
                 subj_id + '_' + group + '.trk')
-            logging.debug(' *Saving intermediate streamline group {} '
-                          'into {}.'.format(group, output_fname))
+            logging.debug(" *Saving intermediate streamline group {} "
+                          "into {}.".format(group, output_fname))
            # Note. Do not remove the str below. Does not work well
            # with Path.
            save_tractogram(final_sft, str(output_fname))
 
-        # Removing invalid streamlines
-        logging.debug(' *Total: {:,.0f} streamlines. Now removing '
-                      'invalid streamlines.'.format(len(final_sft)))
-        final_sft.remove_invalid_streamlines()
-        logging.info(" Final number of streamlines: {:,.0f}."
-                     .format(len(final_sft)))
-
         conn_matrix = None
         conn_info = None
         if 'connectivity_matrix' in self.groups_config[group]:
@@ -735,10 +740,11 @@ def _load_and_process_sft(self, tractogram_file, header):
                 "We do not support file's type: {}. We only support .trk "
                 "and .tck files.".format(tractogram_file))
         if file_extension == '.trk':
-            if not is_header_compatible(str(tractogram_file), header):
-                raise ValueError("Streamlines group is not compatible with "
-                                 "volume groups\n ({})"
-                                 .format(tractogram_file))
+            if header:
+                if not is_header_compatible(str(tractogram_file), header):
+                    raise ValueError("Streamlines group is not compatible "
+                                     "with volume groups\n ({})"
+                                     .format(tractogram_file))
 
             # overriding given header.
             header = 'same'
@@ -748,6 +754,9 @@ def _load_and_process_sft(self, tractogram_file, header):
         sft = load_tractogram(str(tractogram_file), header)
 
         # Resample or compress streamlines
-        sft = resample_or_compress(sft, self.step_size, self.compress)
+        sft = resample_or_compress(sft, self.step_size,
+                                   self.nb_points,
+                                   self.compress,
+                                   self.remove_invalid)
 
         return sft
diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py
index cbd7704b..d8a6d990 100644
--- a/dwi_ml/data/hdf5/utils.py
+++ b/dwi_ml/data/hdf5/utils.py
@@ -59,11 +59,11 @@ def add_hdf5_creation_args(p: ArgumentParser):
                    help="A txt file containing the list of subjects ids to "
                         "use for training. \n(Can be an empty file.)")
     p.add_argument('validation_subjs',
-                   help="A txt file containing the list of subjects ids to use "
-                        "for validation. \n(Can be an empty file.)")
+                   help="A txt file containing the list of subjects ids"
+                        " to use for validation. \n(Can be an empty file.)")
     p.add_argument('testing_subjs',
-                   help="A txt file containing the list of subjects ids to use "
-                        "for testing. \n(Can be an empty file.)")
+                   help="A txt file containing the list of subjects ids"
+                        " to use for testing. \n(Can be an empty file.)")
 
     # Optional arguments
     p.add_argument('--enforce_files_presence', type=bool, default=True,
@@ -76,9 +76,10 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "each subject inside the \nhdf5 folder, in sub-"
                         "folders named subjid_intermediate.\n"
                         "(Final concatenated standardized volumes and \n"
-                        "final concatenated resampled/compressed streamlines.)")
+                        "final concatenated resampled/compressed "
+                        "streamlines.)")
 
 
 def add_streamline_processing_args(p: ArgumentParser):
-    g = p.add_argument_group('Streamlines processing options:')
+    g = p.add_argument_group('Streamlines processing options')
     add_resample_or_compress_arg(g)
diff --git a/dwi_ml/data/processing/streamlines/data_augmentation.py b/dwi_ml/data/processing/streamlines/data_augmentation.py
index 48683def..18428cb9 100644
--- a/dwi_ml/data/processing/streamlines/data_augmentation.py
+++ b/dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -9,18 +9,33 @@
 import numpy as np
 
 from scilpy.tractograms.streamline_operations import \
-    resample_streamlines_step_size, compress_sft
+    resample_streamlines_num_points, resample_streamlines_step_size, \
+    compress_sft
 
 
 def resample_or_compress(sft, step_size_mm: float = None,
-                         compress: float = None):
+                         nb_points: int = None,
+                         compress: float = None,
+                         remove_invalid: bool = False):
     if step_size_mm is not None:
         # Note. No matter the chosen space, resampling is done in mm.
-        logging.debug(" Resampling: {}".format(step_size_mm))
+        logging.debug(" Resampling (step size): {}mm".format(step_size_mm))
         sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
-    if compress is not None:
-        logging.debug(" Compressing: {}".format(compress))
+    elif nb_points is not None:
+        logging.debug(" Resampling: "
+                      "{} points per streamline".format(nb_points))
+        sft = resample_streamlines_num_points(sft, nb_points)
+    elif compress is not None:
+        logging.debug(" Compressing: {}".format(compress))
         sft = compress_sft(sft, compress)
+
+    if remove_invalid:
+        logging.debug(" Total: {:,.0f} streamlines. Now removing "
+                      "invalid streamlines.".format(len(sft)))
+        sft.remove_invalid_streamlines()
+        logging.info(" Final number of streamlines: {:,.0f}."
+                     .format(len(sft)))
+
     return sft
diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py
index b16e5baf..12558acf 100644
--- a/dwi_ml/io_utils.py
+++ b/dwi_ml/io_utils.py
@@ -2,20 +2,23 @@
 import os
 from argparse import ArgumentParser
 
-from scilpy.io.utils import add_processes_arg
-
 
 def add_resample_or_compress_arg(p: ArgumentParser):
+    p.add_argument("--remove_invalid", action='store_true',
+                   help="If set, remove invalid streamlines.")
     g = p.add_mutually_exclusive_group()
     g.add_argument(
         '--step_size', type=float, metavar='s',
         help="Step size to resample the data (in mm). Default: None")
+    g.add_argument('--nb_points', type=int, metavar='n',
+                   help='Number of points per streamline in the output. '
+                        'Default: None.')
     g.add_argument(
         '--compress', type=float, metavar='r', const=0.01, nargs='?',
         dest='compress_th',
         help="Compression ratio. Default: None. Default if set: 0.01.\n"
-             "If neither step_size nor compress are chosen, streamlines "
-             "will be kept \nas they are.")
+             "If none of step_size, nb_points or compress is chosen, "
+             "\nstreamlines will be kept as they are.")
 
 
 def add_arg_existing_experiment_path(p: ArgumentParser):
diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py
index 7266cd15..f0a82f9a 100644
--- a/scripts_python/dwiml_create_hdf5_dataset.py
+++ b/scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,10 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.step_size, args.compress_th,
+                          groups_config, args.step_size,
+                          args.nb_points,
+                          args.compress_th,
+                          args.remove_invalid,
                           args.enforce_files_presence,
                           args.save_intermediate, intermediate_subdir)
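
Usage sketch: a minimal example of how the updated helper can be called once
the patch is applied. The tractogram filename and the value 128 are arbitrary
illustrations; `resample_or_compress` and its parameters come from the diff
above, and `load_tractogram` is dipy's loader, already used in
`hdf5_creation.py`.

    from dipy.io.streamline import load_tractogram

    from dwi_ml.data.processing.streamlines.data_augmentation import \
        resample_or_compress

    # 'same' is a valid reference when loading a .trk file.
    sft = load_tractogram('sub-01__tractogram.trk', 'same')

    # Set at most one of step_size_mm, nb_points or compress; the if/elif
    # chain applies only the first one given. remove_invalid can be combined
    # with any of them and runs after the resampling/compression step.
    sft = resample_or_compress(sft, nb_points=128, remove_invalid=True)

On the command line, the same constraint is enforced by the mutually
exclusive group added in io_utils.py: pass at most one of --step_size,
--nb_points or --compress to dwiml_create_hdf5_dataset.py, optionally
combined with --remove_invalid.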