Merge pull request #244 from arnaudbore/add_resample_nb_points_streamlines_hdf5

[ENH] Add resample nb points + dps
EmmaRenauld authored Sep 27, 2024
2 parents 93cec1e + f94453f commit 3ebc194
Showing 10 changed files with 111 additions and 45 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_package.yml
@@ -25,6 +25,7 @@ jobs:

- name: Install dependencies
run: |
export SETUPTOOLS_USE_DISTUTILS=stdlib
pip install --upgrade pip
pip install pytest
pip install -e .
77 changes: 54 additions & 23 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -136,11 +136,17 @@ class HDF5Creator:
def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
step_size: float = None, compress: float = None,
dps_keys: List[str] = [],
step_size: float = None,
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False,
enforce_files_presence: bool = True,
save_intermediate: bool = False,
intermediate_folder: Path = None):
"""
Params step_size, nb_points and compress are mutually exclusive.
Params
------
root_folder: Path
@@ -154,10 +160,16 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
List of subject names for each data set.
groups_config: dict
Information from json file loaded as a dict.
dps_keys: List[str]
List of keys to keep in data_per_streamline. Default: [].
step_size: float
Step size to resample streamlines. Default: None.
nb_points: int
Number of points per streamline. Default: None.
compress: float
Compress streamlines. Default: None.
remove_invalid: bool
Remove invalid streamlines. Default: False.
enforce_files_presence: bool
If true, will stop if some files are not available for a subject.
Default: True.
@@ -174,8 +186,11 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.validation_subjs = validation_subjs
self.testing_subjs = testing_subjs
self.groups_config = groups_config
self.dps_keys = dps_keys
self.step_size = step_size
self.nb_points = nb_points
self.compress = compress
self.remove_invalid = remove_invalid

# Optional
self.save_intermediate = save_intermediate
@@ -188,7 +203,7 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self._analyse_config_file()

# -------- Performing checks

self._check_streamlines_operations()
# Check that all subjects exist.
logging.debug("Preparing hdf5 creator for \n"
" training subjs {}, \n"
@@ -340,6 +355,19 @@ def flatten_list(a_list):
self.enforce_files_presence,
folder=subj_input_dir)

def _check_streamlines_operations(self):
valid = True
if self.step_size and self.nb_points:
valid = False
elif self.step_size and self.compress:
valid = False
elif self.nb_points and self.compress:
valid = False
if not valid:
raise ValueError(
"Only one option can be chosen: either resampling to "
"step_size, nb_points or compressing, not both.")

def create_database(self):
"""
Generates a hdf5 dataset from a group of subjects. Hdf5 dataset will
Expand All @@ -359,6 +387,8 @@ def create_database(self):
hdf_handle.attrs['testing_subjs'] = self.testing_subjs
hdf_handle.attrs['step_size'] = self.step_size if \
self.step_size is not None else 'Not defined by user'
hdf_handle.attrs['nb_points'] = self.nb_points if \
self.nb_points is not None else 'Not defined by user'
hdf_handle.attrs['compress'] = self.compress if \
self.compress is not None else 'Not defined by user'

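This guard runs during __init__, before any subject is processed, so conflicting options fail fast. A minimal sketch of that behaviour, assuming hypothetical paths, subject ids and an empty groups config (none of these values come from the commit):

from pathlib import Path

from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator

try:
    # Two resampling modes at once: the new check raises a ValueError.
    HDF5Creator(root_folder=Path('dwi_ml_ready'),       # hypothetical folder
                out_hdf_filename=Path('dataset.hdf5'),  # hypothetical output
                training_subjs=['subj1'], validation_subjs=[],
                testing_subjs=[], groups_config={},
                step_size=0.5, nb_points=128)
except ValueError as e:
    print(e)  # "Only one option can be chosen: ..."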
@@ -598,9 +628,16 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,

if len(sft.data_per_point) > 0:
logging.debug('sft contained data_per_point. Data not kept.')
if len(sft.data_per_streamline) > 0:
logging.debug('sft contained data_per_streamlines. Data not '
'kept.')

for dps_key in self.dps_keys:
if dps_key not in sft.data_per_streamline:
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])

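Each kept dps key becomes a 'dps_<key>' dataset inside the subject's streamline group, alongside the file-level attributes written in create_database above. A hedged read-back sketch with h5py; the file name, subject id, group path and dps key below are all hypothetical placeholders:

import h5py

with h5py.File('dataset.hdf5', 'r') as hdf:
    print(hdf.attrs['nb_points'])     # 'Not defined by user' if unset
    grp = hdf['subj1/streamlines']    # hypothetical subject/group path
    mean_fa = grp['dps_mean_fa'][:]   # hypothetical dps key
    print(mean_fa.shape)              # one entry per streamline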
# Accessing private Dipy values, but necessary.
# We need to deconstruct the streamlines into arrays with
@@ -632,6 +669,8 @@ def _process_one_streamline_group(
Reference used to load and send the streamlines in voxel space and
to create final merged SFT. If the file is a .trk, 'same' is used
instead.
remove_invalid : bool
If True, invalid streamlines will be removed.
Returns
-------
Expand All @@ -642,11 +681,6 @@ def _process_one_streamline_group(
"""
tractograms = self.groups_config[group]['files']

if self.step_size and self.compress:
raise ValueError(
"Only one option can be chosen: either resampling to "
"step_size or compressing, not both.")

# Silencing SFT's logger if our logging is in DEBUG mode, because it
# typically produces a lot of outputs!
set_sft_logger_level('WARNING')
@@ -679,19 +713,12 @@ def _process_one_streamline_group(
if self.save_intermediate:
output_fname = self.intermediate_folder.joinpath(
subj_id + '_' + group + '.trk')
logging.debug(' *Saving intermediate streamline group {} '
'into {}.'.format(group, output_fname))
logging.debug(" *Saving intermediate streamline group {} "
"into {}.".format(group, output_fname))
# Note. Do not remove the str below. Does not work well
# with Path.
save_tractogram(final_sft, str(output_fname))

# Removing invalid streamlines
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))

conn_matrix = None
conn_info = None
if 'connectivity_matrix' in self.groups_config[group]:
@@ -735,9 +762,10 @@ def _load_and_process_sft(self, tractogram_file, header):
"We do not support file's type: {}. We only support .trk "
"and .tck files.".format(tractogram_file))
if file_extension == '.trk':
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible with "
"volume groups\n ({})"
if header and not is_header_compatible(str(tractogram_file),
header):
raise ValueError("Streamlines group is not compatible "
"with volume groups\n ({})"
.format(tractogram_file))
# overriding given header.
header = 'same'
Expand All @@ -748,6 +776,9 @@ def _load_and_process_sft(self, tractogram_file, header):
sft = load_tractogram(str(tractogram_file), header)

# Resample or compress streamlines
sft = resample_or_compress(sft, self.step_size, self.compress)
sft = resample_or_compress(sft, self.step_size,
self.nb_points,
self.compress,
self.remove_invalid)

return sft
16 changes: 10 additions & 6 deletions dwi_ml/data/hdf5/utils.py
@@ -59,11 +59,11 @@ def add_hdf5_creation_args(p: ArgumentParser):
help="A txt file containing the list of subjects ids to "
"use for training. \n(Can be an empty file.)")
p.add_argument('validation_subjs',
help="A txt file containing the list of subjects ids to use "
"for validation. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for validation. \n(Can be an empty file.)")
p.add_argument('testing_subjs',
help="A txt file containing the list of subjects ids to use "
"for testing. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for testing. \n(Can be an empty file.)")

# Optional arguments
p.add_argument('--enforce_files_presence', type=bool, default=True,
@@ -76,9 +76,13 @@
"each subject inside the \nhdf5 folder, in sub-"
"folders named subjid_intermediate.\n"
"(Final concatenated standardized volumes and \n"
"final concatenated resampled/compressed streamlines.)")
"final concatenated resampled/compressed "
"streamlines.)")
p.add_argument('--dps_keys', type=str, nargs='+', default=[],
help="List of keys to keep in data_per_streamline. "
"Default: Empty.")


def add_streamline_processing_args(p: ArgumentParser):
g = p.add_argument_group('Streamlines processing options:')
g = p.add_argument_group('Streamlines processing options')
add_resample_or_compress_arg(g)
25 changes: 20 additions & 5 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
@@ -9,18 +9,33 @@
import numpy as np

from scilpy.tractograms.streamline_operations import \
resample_streamlines_step_size, compress_sft
resample_streamlines_num_points, resample_streamlines_step_size, \
compress_sft


def resample_or_compress(sft, step_size_mm: float = None,
compress: float = None):
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False):
if step_size_mm is not None:
# Note. No matter the chosen space, resampling is done in mm.
logging.debug(" Resampling: {}".format(step_size_mm))
logging.debug(" Resampling (step size): {}mm".format(step_size_mm))
sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
if compress is not None:
logging.debug(" Compressing: {}".format(compress))
elif nb_points is not None:
logging.debug(" Resampling: " +
"{} points per streamline".format(nb_points))
sft = resample_streamlines_num_points(sft, nb_points)
elif compress is not None:
logging.debug(" Compressing: {}".format(compress))
sft = compress_sft(sft, compress)

if remove_invalid:
logging.debug(" Total: {:,.0f} streamlines. Now removing "
"invalid streamlines.".format(len(sft)))
sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(sft)))

return sft

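With the new signature, callers pick at most one of the three geometry options and can clean invalid streamlines in the same pass. A hedged caller-side usage sketch; the tractogram file below is a hypothetical placeholder:

from dipy.io.streamline import load_tractogram

from dwi_ml.data.processing.streamlines.data_augmentation import \
    resample_or_compress

sft = load_tractogram('bundle.trk', 'same')   # hypothetical .trk file
sft = resample_or_compress(sft,
                           nb_points=128,        # fixed number of points
                           remove_invalid=True)  # drop invalid streamlines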

11 changes: 7 additions & 4 deletions dwi_ml/io_utils.py
@@ -2,20 +2,23 @@
import os
from argparse import ArgumentParser

from scilpy.io.utils import add_processes_arg


def add_resample_or_compress_arg(p: ArgumentParser):
p.add_argument("--remove_invalid", action='store_true',
help="If set, remove invalid streamlines.")
g = p.add_mutually_exclusive_group()
g.add_argument(
'--step_size', type=float, metavar='s',
help="Step size to resample the data (in mm). Default: None")
g.add_argument('--nb_points', type=int, metavar='n',
help='Number of points per streamline in the output. '
'Default: None')
g.add_argument(
'--compress', type=float, metavar='r', const=0.01, nargs='?',
dest='compress_th',
help="Compression ratio. Default: None. Default if set: 0.01.\n"
"If neither step_size nor compress are chosen, streamlines "
"will be kept \nas they are.")
"If neither step_size, nb_points nor compress "
"are chosen, \nstreamlines will be kept as they are.")

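Since --step_size, --nb_points and --compress live in a mutually exclusive group, argparse itself rejects combined use before any model code runs. A minimal sketch of the resulting options (not part of the commit):

from argparse import ArgumentParser

from dwi_ml.io_utils import add_resample_or_compress_arg

p = ArgumentParser()
add_resample_or_compress_arg(p)
args = p.parse_args(['--nb_points', '128', '--remove_invalid'])
print(args.nb_points, args.remove_invalid)  # 128 True
# Note: --compress stores into args.compress_th (dest='compress_th').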

def add_arg_existing_experiment_path(p: ArgumentParser):
13 changes: 8 additions & 5 deletions dwi_ml/models/main_models.py
@@ -17,8 +17,8 @@
prepare_neighborhood_vectors
from dwi_ml.experiment_utils.prints import format_dict_to_str
from dwi_ml.io_utils import add_resample_or_compress_arg
from dwi_ml.models.direction_getter_models import keys_to_direction_getters, \
AbstractDirectionGetterModel
from dwi_ml.models.direction_getter_models import keys_to_direction_getters
from dwi_ml.models.embeddings import (keys_to_embeddings, NNEmbedding,
NoEmbedding)
from dwi_ml.models.utils.direction_getters import add_direction_getter_args
@@ -37,6 +36,7 @@ class MainModelAbstract(torch.nn.Module):
def __init__(self, experiment_name: str,
# Target preprocessing params for the batch loader + tracker
step_size: float = None,
nb_points: int = None,
compress_lines: float = False,
# Other
log_level=logging.root.level):
@@ -74,17 +74,20 @@ def __init__(self, experiment_name: str,

# To tell our batch loader how to resample streamlines during training
# (should also be the step size during tractography).
if step_size and compress_lines:
raise ValueError("You may choose either resampling or compressing,"
"but not both.")
if (step_size and compress_lines) or (step_size and nb_points) \
or (nb_points and compress_lines):
raise ValueError("You may choose either resampling (step_size or "
"nb_points) or compressing, but no more than one of "
"these options.")
elif step_size and step_size <= 0:
raise ValueError("Step size can't be 0 or less!")
elif nb_points and nb_points <= 0:
raise ValueError("Number of points can't be 0 or less!")
# Note. When using
# scilpy.tracking.tools.resample_streamlines_step_size, a warning
# is shown if step_size < 0.1 or > np.max(sft.voxel_sizes), saying
# that the value is suspicious. Not raising the same warnings here
# as you may be wanting to test weird things to understand better
# your model.
self.nb_points = nb_points
self.step_size = step_size
self.compress_lines = compress_lines

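The pairwise test above grows quadratically as options are added; an equivalent formulation, shown only to illustrate the design choice (not the commit's code), counts the set options instead:

# Illustrative alternative, not part of this commit:
n_set = sum(bool(v) for v in (step_size, nb_points, compress_lines))
if n_set > 1:
    raise ValueError("Choose at most one of step_size, nb_points "
                     "or compress_lines.")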
1 change: 1 addition & 0 deletions dwi_ml/testing/testers.py
@@ -109,6 +109,7 @@ def run_model_on_sft(self, sft, compute_loss=False):
The mean eos error per line.
"""
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)
sft.to_vox()
sft.to_corner()
1 change: 1 addition & 0 deletions dwi_ml/training/batch_loaders.py
@@ -199,6 +199,7 @@ def _data_augmentation_sft(self, sft):
"the hdf5 dataset. Not compressing again.")
else:
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)

# Splitting streamlines
7 changes: 6 additions & 1 deletion scripts_python/dwiml_create_hdf5_dataset.py
@@ -87,7 +87,12 @@ def prepare_hdf5_creator(args):
# Instantiate a creator and perform checks
creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
training_subjs, validation_subjs, testing_subjs,
groups_config, args.step_size, args.compress_th,
groups_config,
args.dps_keys,
args.step_size,
args.nb_points,
args.compress_th,
args.remove_invalid,
args.enforce_files_presence,
args.save_intermediate, intermediate_subdir)

4 changes: 3 additions & 1 deletion scripts_python/dwiml_visualize_noise_on_streamlines.py
@@ -68,7 +68,9 @@ def main():
subj_sft_data = subj_data.sft_data_list[streamline_group_idx]
sft = subj_sft_data.as_sft()

sft = resample_or_compress(sft, args.step_size, args.compress_th)
sft = resample_or_compress(sft, args.step_size,
args.nb_points,
args.compress_th)
sft.to_vox()
sft.to_corner()

