Skip to content

Commit

Permalink
Add the possibility to resample streamlines to a fixed number of points per streamline
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbore committed Sep 25, 2024
1 parent 93cec1e commit 550ae33
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 33 deletions.
43 changes: 26 additions & 17 deletions dwi_ml/data/hdf5/hdf5_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ class HDF5Creator:
def __init__(self, root_folder: Path, out_hdf_filename: Path,
training_subjs: List[str], validation_subjs: List[str],
testing_subjs: List[str], groups_config: dict,
step_size: float = None, compress: float = None,
step_size: float = None,
nb_points: int = None,
compress: float = None,
remove_invalid: bool = False,
enforce_files_presence: bool = True,
save_intermediate: bool = False,
intermediate_folder: Path = None):
Expand All @@ -156,8 +159,12 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
Information from json file loaded as a dict.
step_size: float
Step size to resample streamlines. Default: None.
nb_points: int
Number of points per streamline. Default: None.
compress: float
Compress streamlines. Default: None.
remove_invalid: bool
Remove invalid streamline. Default: False
enforce_files_presence: bool
If true, will stop if some files are not available for a subject.
Default: True.
Expand All @@ -175,7 +182,9 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path,
self.testing_subjs = testing_subjs
self.groups_config = groups_config
self.step_size = step_size
self.nb_points = nb_points
self.compress = compress
self.remove_invalid = remove_invalid

# Optional
self.save_intermediate = save_intermediate
Expand Down Expand Up @@ -359,6 +368,8 @@ def create_database(self):
hdf_handle.attrs['testing_subjs'] = self.testing_subjs
hdf_handle.attrs['step_size'] = self.step_size if \
self.step_size is not None else 'Not defined by user'
hdf_handle.attrs['nb_points'] = self.nb_points if \
self.nb_points is not None else 'Not defined by user'
hdf_handle.attrs['compress'] = self.compress if \
self.compress is not None else 'Not defined by user'

Expand Down Expand Up @@ -632,6 +643,8 @@ def _process_one_streamline_group(
Reference used to load and send the streamlines in voxel space and
to create final merged SFT. If the file is a .trk, 'same' is used
instead.
remove_invalid : bool
If True, invalid streamlines will be removed
Returns
-------
Expand All @@ -641,11 +654,10 @@ def _process_one_streamline_group(
The Euclidean length of each streamline
"""
tractograms = self.groups_config[group]['files']

if self.step_size and self.compress:
raise ValueError(
"Only one option can be chosen: either resampling to "
"step_size or compressing, not both.")
"step_size, nb_points or compressing, not both.")

# Silencing SFT's logger if our logging is in DEBUG mode, because it
# typically produces a lot of outputs!
Expand Down Expand Up @@ -679,19 +691,12 @@ def _process_one_streamline_group(
if self.save_intermediate:
output_fname = self.intermediate_folder.joinpath(
subj_id + '_' + group + '.trk')
logging.debug(' *Saving intermediate streamline group {} '
'into {}.'.format(group, output_fname))
logging.debug(" *Saving intermediate streamline group {} "
"into {}.".format(group, output_fname))
# Note. Do not remove the str below. Does not work well
# with Path.
save_tractogram(final_sft, str(output_fname))

# Removing invalid streamlines
logging.debug(' *Total: {:,.0f} streamlines. Now removing '
'invalid streamlines.'.format(len(final_sft)))
final_sft.remove_invalid_streamlines()
logging.info(" Final number of streamlines: {:,.0f}."
.format(len(final_sft)))

conn_matrix = None
conn_info = None
if 'connectivity_matrix' in self.groups_config[group]:
Expand Down Expand Up @@ -735,10 +740,11 @@ def _load_and_process_sft(self, tractogram_file, header):
"We do not support file's type: {}. We only support .trk "
"and .tck files.".format(tractogram_file))
if file_extension == '.trk':
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible with "
"volume groups\n ({})"
.format(tractogram_file))
if header:
if not is_header_compatible(str(tractogram_file), header):
raise ValueError("Streamlines group is not compatible "
"with volume groups\n ({})"
.format(tractogram_file))
# overriding given header.
header = 'same'

Expand All @@ -748,6 +754,9 @@ def _load_and_process_sft(self, tractogram_file, header):
sft = load_tractogram(str(tractogram_file), header)

# Resample or compress streamlines
sft = resample_or_compress(sft, self.step_size, self.compress)
sft = resample_or_compress(sft, self.step_size,
self.nb_points,
self.compress,
self.remove_invalid)

return sft
13 changes: 7 additions & 6 deletions dwi_ml/data/hdf5/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ def add_hdf5_creation_args(p: ArgumentParser):
help="A txt file containing the list of subjects ids to "
"use for training. \n(Can be an empty file.)")
p.add_argument('validation_subjs',
help="A txt file containing the list of subjects ids to use "
"for validation. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for validation. \n(Can be an empty file.)")
p.add_argument('testing_subjs',
help="A txt file containing the list of subjects ids to use "
"for testing. \n(Can be an empty file.)")
help="A txt file containing the list of subjects ids"
" to use for testing. \n(Can be an empty file.)")

# Optional arguments
p.add_argument('--enforce_files_presence', type=bool, default=True,
Expand All @@ -76,9 +76,10 @@ def add_hdf5_creation_args(p: ArgumentParser):
"each subject inside the \nhdf5 folder, in sub-"
"folders named subjid_intermediate.\n"
"(Final concatenated standardized volumes and \n"
"final concatenated resampled/compressed streamlines.)")
"final concatenated resampled/compressed "
"streamlines.)")


def add_streamline_processing_args(p: ArgumentParser):
    """Add the streamline processing options (resample/compress/remove
    invalid) to the given parser, inside a dedicated argument group.

    Parameters
    ----------
    p: ArgumentParser
        The parser (or subparser) to which the options are added.
    """
    # NOTE(review): the scraped diff showed both the old (':'-suffixed) and
    # new group titles; keeping the post-commit title only.
    g = p.add_argument_group('Streamlines processing options')
    add_resample_or_compress_arg(g)
25 changes: 20 additions & 5 deletions dwi_ml/data/processing/streamlines/data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,33 @@
import numpy as np

from scilpy.tractograms.streamline_operations import \
resample_streamlines_step_size, compress_sft
resample_streamlines_num_points, resample_streamlines_step_size, \
compress_sft


def resample_or_compress(sft, step_size_mm: float = None,
                         nb_points: int = None,
                         compress: float = None,
                         remove_invalid: bool = False):
    """
    Resample (by step size or by number of points) or compress streamlines,
    then optionally remove invalid streamlines.

    The three modification options are checked in priority order:
    step_size_mm first, then nb_points, then compress — at most one is
    applied. (Callers are expected to set only one; see the mutually
    exclusive argparse group in add_resample_or_compress_arg.)

    Parameters
    ----------
    sft: StatefulTractogram
        The tractogram to process.
    step_size_mm: float
        Step size (in mm) to resample streamlines to. Default: None.
    nb_points: int
        Number of points per streamline in the output. Default: None.
    compress: float
        Compression ratio. Default: None.
    remove_invalid: bool
        If True, remove invalid streamlines after processing.
        Default: False.

    Returns
    -------
    sft: StatefulTractogram
        The processed tractogram.
    """
    if step_size_mm is not None:
        # Note. No matter the chosen space, resampling is done in mm.
        logging.debug("    Resampling (step size): {}mm".format(step_size_mm))
        sft = resample_streamlines_step_size(sft, step_size=step_size_mm)
    elif nb_points is not None:
        logging.debug("    Resampling: " +
                      "{} points per streamline".format(nb_points))
        sft = resample_streamlines_num_points(sft, nb_points)
    elif compress is not None:
        logging.debug("    Compressing: {}".format(compress))
        sft = compress_sft(sft, compress)

    if remove_invalid:
        # remove_invalid_streamlines() mutates the SFT in place.
        logging.debug("    Total: {:,.0f} streamlines. Now removing "
                      "invalid streamlines.".format(len(sft)))
        sft.remove_invalid_streamlines()
        logging.info("    Final number of streamlines: {:,.0f}."
                     .format(len(sft)))

    return sft


Expand Down
11 changes: 7 additions & 4 deletions dwi_ml/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import os
from argparse import ArgumentParser

from scilpy.io.utils import add_processes_arg


def add_resample_or_compress_arg(p: ArgumentParser):
    """Add streamline modification options to the given parser.

    Adds --remove_invalid, plus a mutually exclusive group with
    --step_size, --nb_points and --compress (the latter stored as
    'compress_th', defaulting to 0.01 when the flag is given without a
    value).

    Parameters
    ----------
    p: ArgumentParser
        The parser (or argument group) to which the options are added.
    """
    p.add_argument("--remove_invalid", action='store_true',
                   help="If set, remove invalid streamlines.")
    # Only one resampling / compression strategy may be chosen.
    g = p.add_mutually_exclusive_group()
    g.add_argument(
        '--step_size', type=float, metavar='s',
        help="Step size to resample the data (in mm). Default: None")
    g.add_argument('--nb_points', type=int, metavar='n',
                   # Fixed: a space was missing between the two implicitly
                   # concatenated string literals ("output.Default: None").
                   help='Number of points per streamline in the output. '
                        'Default: None')
    g.add_argument(
        '--compress', type=float, metavar='r', const=0.01, nargs='?',
        dest='compress_th',
        help="Compression ratio. Default: None. Default if set: 0.01.\n"
             "If neither step_size, nb_points nor compress "
             "are chosen, \nstreamlines will be kept as they are.")


def add_arg_existing_experiment_path(p: ArgumentParser):
Expand Down
5 changes: 4 additions & 1 deletion scripts_python/dwiml_create_hdf5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def prepare_hdf5_creator(args):
# Instantiate a creator and perform checks
creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file,
training_subjs, validation_subjs, testing_subjs,
groups_config, args.step_size, args.compress_th,
groups_config, args.step_size,
args.nb_points,
args.compress_th,
args.remove_invalid,
args.enforce_files_presence,
args.save_intermediate, intermediate_subdir)

Expand Down

0 comments on commit 550ae33

Please sign in to comment.