diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py
index 82414c52..e8f3c663 100644
--- a/dwi_ml/data/hdf5/hdf5_creation.py
+++ b/dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,6 +136,7 @@ class HDF5Creator:
     def __init__(self, root_folder: Path, out_hdf_filename: Path,
                  training_subjs: List[str], validation_subjs: List[str],
                  testing_subjs: List[str], groups_config: dict,
+                 dps_keys: List[str] = [],
                  step_size: float = None,
                  nb_points: int = None,
                  compress: float = None,
@@ -157,6 +158,8 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
             List of subject names for each data set.
         groups_config: dict
             Information from json file loaded as a dict.
+        dps_keys: List[str]
+            List of keys to keep in data_per_streamline. Default: Empty.
         step_size: float
             Step size to resample streamlines. Default: None.
         nb_points: int
@@ -181,6 +184,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
         self.validation_subjs = validation_subjs
         self.testing_subjs = testing_subjs
         self.groups_config = groups_config
+        self.dps_keys = dps_keys
         self.step_size = step_size
         self.nb_points = nb_points
         self.compress = compress
@@ -609,9 +613,11 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
 
         if len(sft.data_per_point) > 0:
             logging.debug('sft contained data_per_point. Data not kept.')
-        if len(sft.data_per_streamline) > 0:
-            logging.debug('sft contained data_per_streamlines. Data not '
-                          'kept.')
+
+        for dps_key in self.dps_keys:
+            logging.debug("    Include dps \"{}\" in the HDF5.".format(dps_key))
+            streamlines_group.create_dataset('dps_' + dps_key,
+                                             data=sft.data_per_streamline[dps_key])
 
         # Accessing private Dipy values, but necessary.
         # We need to deconstruct the streamlines into arrays with
diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py
index d8a6d990..4d0aca12 100644
--- a/dwi_ml/data/hdf5/utils.py
+++ b/dwi_ml/data/hdf5/utils.py
@@ -78,6 +78,9 @@ def add_hdf5_creation_args(p: ArgumentParser):
                         "(Final concatenated standardized volumes and \n"
                         "final concatenated resampled/compressed "
                         "streamlines.)")
+    p.add_argument('--dps_keys', type=str, nargs='+', default=[],
+                   help="List of keys to keep in data_per_streamline. "
+                        "Default: Empty.")
 
 
 def add_streamline_processing_args(p: ArgumentParser):
diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py
index f0a82f9a..2719449c 100644
--- a/scripts_python/dwiml_create_hdf5_dataset.py
+++ b/scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,9 @@ def prepare_hdf5_creator(args):
     # Instantiate a creator and perform checks
     creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
                           training_subjs, validation_subjs, testing_subjs,
-                          groups_config, args.step_size,
+                          groups_config,
+                          args.dps_keys,
+                          args.step_size,
                           args.nb_points, args.compress_th,
                           args.remove_invalid,
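
For context, a minimal sketch (not part of the patch) of how a kept `data_per_streamline` key could be read back once the HDF5 is created. The file name, subject id, group name, and dps key below are hypothetical; only the `'dps_' + dps_key` dataset naming comes from the `_create_streamline_groups` hunk above, and the `hdf_file[subj_id][group]` layout assumes dwi_ml's usual one-group-per-subject structure.

```python
# Hedged sketch: reading back a dps_* dataset written by the patch above.
# 'my_dataset.hdf5', 'subj1', 'streamlines' and 'commit_weights' are
# hypothetical names, not taken from this PR.
import h5py

with h5py.File('my_dataset.hdf5', 'r') as hdf_file:
    # One group per subject, one sub-group per streamline group.
    streamlines_group = hdf_file['subj1']['streamlines']

    # Each key passed through --dps_keys is stored as a 'dps_<key>'
    # dataset, one row per streamline.
    weights = streamlines_group['dps_commit_weights'][:]
    print(weights.shape)
```

At creation time, the keys would be selected with the new flag, e.g. `dwiml_create_hdf5_dataset.py ... --dps_keys commit_weights` (other arguments omitted; the key name is again hypothetical).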