From ab828dd8e642cc26bed8151ec3771395a2aa0ecf Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Jun 2023 14:30:35 -0400 Subject: [PATCH 1/3] Refactoring while searching for error --- dwi_ml/tracking/projects/__init__.py | 0 dwi_ml/tracking/projects/utils.py | 196 ++++++++++++++++++ dwi_ml/tracking/propagation.py | 2 +- dwi_ml/tracking/tracker.py | 30 +-- dwi_ml/tracking/tracking_mask.py | 2 +- dwi_ml/tracking/utils.py | 181 ---------------- .../projects/trainers_for_generation.py | 2 +- .../analyze_batch_loader_visually.py | 7 +- scripts_python/l2t_track_from_model.py | 8 +- scripts_python/l2t_train_model.py | 2 - scripts_python/tto_track_from_model.py | 6 +- scripts_python/ttst_track_from_model.py | 6 +- 12 files changed, 224 insertions(+), 218 deletions(-) create mode 100644 dwi_ml/tracking/projects/__init__.py create mode 100644 dwi_ml/tracking/projects/utils.py diff --git a/dwi_ml/tracking/projects/__init__.py b/dwi_ml/tracking/projects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dwi_ml/tracking/projects/utils.py b/dwi_ml/tracking/projects/utils.py new file mode 100644 index 00000000..0c2749dd --- /dev/null +++ b/dwi_ml/tracking/projects/utils.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +import logging +import os + +from dipy.io.stateful_tractogram import (Space, Origin, set_sft_logger_level, + StatefulTractogram) +from dipy.io.streamline import save_tractogram +import nibabel as nib +import numpy as np + +from scilpy.tracking.seed import SeedGenerator + +from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args +from dwi_ml.testing.utils import add_args_testing_subj_hdf5 +from dwi_ml.tracking.tracking_mask import TrackingMask +from dwi_ml.tracking.tracker import DWIMLAbstractTracker + + +ALWAYS_VOX_SPACE = Space.VOX +ALWAYS_CORNER = Origin('corner') + + +def add_tracking_options(p): + + add_arg_existing_experiment_path(p) + add_args_testing_subj_hdf5(p, ask_input_group=True) + + p.add_argument('out_tractogram', + help='Tractogram output file (must be .trk or .tck).') + p.add_argument('seeding_mask_group', + help="Seeding mask's volume group in the hdf5.") + + track_g = p.add_argument_group(' Tracking options') + track_g.add_argument('--algo', choices=['det', 'prob'], default='det', + help="Tracking algorithm (det or prob). Must be " + "implemented in the chosen model. [det]") + track_g.add_argument('--step_size', type=float, + help='Step size in mm. Default: using the step size ' + 'saved in the model parameters.') + track_g.add_argument('--track_forward_only', action='store_true', + help="If set, tracks in one direction only (forward) " + "given the initial \nseed. The direction is " + "randomly drawn from the ODF.") + track_g.add_argument('--mask_interp', default='nearest', + choices=['nearest', 'trilinear'], + help="Mask interpolation: nearest-neighbor or " + "trilinear. [%(default)s]") + track_g.add_argument('--data_interp', default='trilinear', + choices=['nearest', 'trilinear'], + help="Input data interpolation: nearest-neighbor or " + "trilinear. [%(default)s]") + + stop_g = p.add_argument_group("Stopping criteria") + stop_g.add_argument('--min_length', type=float, default=10., + metavar='m', + help='Minimum length of a streamline in mm. ' + '[%(default)s]') + stop_g.add_argument('--max_length', type=float, default=300., + metavar='M', + help='Maximum length of a streamline in mm. ' + '[%(default)s]') + stop_g.add_argument('--tracking_mask_group', + help="Tracking mask's volume group in the hdf5.") + stop_g.add_argument('--theta', metavar='t', type=float, + default=90, + help="The tracking direction at each step being " + "defined by the model, \ntheta arg can't define " + "allowed directions in the tracking field.\n" + "Rather, this new equivalent angle, is used as " + "\na stopping criterion during propagation: " + "tracking \nis stopped when a direction is more " + "than an angle t from preceding direction") + stop_g.add_argument('--eos_stop', metavar='prob', + help="Stopping criterion if a EOS value was learned " + "during training. \nCan either be a probability " + "(default 0.5) or the string 'max', which will " + "\nstop the propagation if the EOS class's " + "probability is the class with maximal " + "probability, no mather its value.") + + r_g = p.add_argument_group(' Random seeding options') + r_g.add_argument('--rng_seed', type=int, + help='Initial value for the random number generator. ' + '[%(default)s]') + r_g.add_argument('--skip', type=int, default=0, + help="Skip the first N random numbers. \n" + "Useful if you want to create new streamlines to " + "add to \na previously created tractogram with a " + "fixed --rng_seed.\nEx: If tractogram_1 was created " + "with -nt 1,000,000, \nyou can create tractogram_2 " + "with \n--skip 1,000,000.") + + # Memory options: + m_g = add_memory_args(p, add_lazy_options=True, + add_multiprocessing_option=True, + add_rng=True) + m_g.add_argument('--simultaneous_tracking', type=int, default=1, + help='Track n streamlines at the same time. Intended for ' + 'GPU usage. Default = 1 (no simultaneous tracking).') + + return track_g + + +def prepare_seed_generator(parser, args, hdf_handle): + """ + Prepares a SeedGenerator from scilpy's library. Returns also some header + information to allow verifications. + """ + seeding_group = hdf_handle[args.subj_id][args.seeding_mask_group] + seed_data = np.array(seeding_group['data'], dtype=np.float32) + seed_res = np.array(seeding_group.attrs['voxres'], dtype=np.float32) + affine = np.array(seeding_group.attrs['affine'], dtype=np.float32) + ref = nib.Nifti1Image(seed_data, affine) + + seed_generator = SeedGenerator(seed_data, seed_res, space=ALWAYS_VOX_SPACE, + origin=ALWAYS_CORNER) + + if len(seed_generator.seeds_vox) == 0: + parser.error('Seed mask "{}" does not have any voxel with value > 0.' + .format(args.in_seed)) + + if args.npv: + # Note. Not really nb seed per voxel, just in average. + nbr_seeds = len(seed_generator.seeds_vox) * args.npv + elif args.nt: + nbr_seeds = args.nt + else: + # Setting npv = 1. + nbr_seeds = len(seed_generator.seeds_vox) + + seed_header = nib.Nifti1Image(seed_data, affine).header + + return seed_generator, nbr_seeds, seed_header, ref + + +def prepare_tracking_mask(hdf_handle, tracking_mask_group, subj_id, mask_interp): + """ + Prepare the tracking mask as a DataVolume from scilpy's library. Returns + also some header information to allow verifications. + """ + if subj_id not in hdf_handle: + raise KeyError("Subject {} not found in {}. Possible subjects are: {}" + .format(subj_id, hdf_handle, list(hdf_handle.keys()))) + if tracking_mask_group not in hdf_handle[subj_id]: + raise KeyError("HDF group '{}' not found for subject {} in hdf file {}" + .format(tracking_mask_group, subj_id, hdf_handle)) + tm_group = hdf_handle[subj_id][tracking_mask_group] + mask_data = np.array(tm_group['data'], dtype=np.float64).squeeze() + # mask_res = np.array(tm_group.attrs['voxres'], dtype=np.float32) + affine = np.array(tm_group.attrs['affine'], dtype=np.float32) + ref = nib.Nifti1Image(mask_data, affine) + + mask = TrackingMask(mask_data.shape, mask_data, mask_interp) + + return mask, ref + + +def track_and_save(tracker: DWIMLAbstractTracker, args, ref): + if args.save_seeds: + name, ext = os.path.splitext(args.out_tractogram) + if ext != '.trk': + raise ValueError("Cannot save seeds! (data per streamline not " + "saved with extension {}). Please change out " + "filename to .trk".format(ext)) + + with Timer("\nTracking...", newline=True, color='blue'): + streamlines, seeds = tracker.track() + + logging.debug("Tracked {} streamlines (out of {} seeds). Now saving..." + .format(len(streamlines), tracker.nbr_seeds)) + + if len(streamlines) == 0: + logging.warning("No streamlines created! Not saving tractogram!") + return + + # save seeds if args.save_seeds is given + # Seeds must be saved in voxel space (ok!), but origin: center, if we want + # to use scripts such as scil_compute_seed_density_map. + if args.save_seeds: + print("Saving seeds in data_per_streamline.") + seeds = [np.asarray(seed) - 0.5 for seed in seeds] # to_center + data_per_streamline = {'seeds': seeds} + else: + data_per_streamline = {} + + # Silencing SFT's logger if our logging is in DEBUG mode, because it + # typically produces a lot of outputs! + set_sft_logger_level('WARNING') + + logging.info("Saving resulting tractogram to {}" + .format(args.out_tractogram)) + sft = StatefulTractogram(streamlines, ref, space=ALWAYS_VOX_SPACE, + origin=ALWAYS_CORNER, + data_per_streamline=data_per_streamline) + save_tractogram(sft, args.out_tractogram, bbox_valid_check=False) diff --git a/dwi_ml/tracking/propagation.py b/dwi_ml/tracking/propagation.py index 902c56f6..ecf95786 100644 --- a/dwi_ml/tracking/propagation.py +++ b/dwi_ml/tracking/propagation.py @@ -199,7 +199,7 @@ def _verify_stopping_criteria(n_last_pos, lines, mask=None, max_nbr_pts=None): # continue. still_on = ~stopping - out_of_mask = ~mask.is_in_mask(n_last_pos[still_on]).cpu().numpy() + out_of_mask = ~mask.is_vox_corner_in_mask(n_last_pos[still_on]).cpu().numpy() if sum(out_of_mask) > 0: logger.debug("{} streamlines stopping out of mask." .format(sum(out_of_mask))) diff --git a/dwi_ml/tracking/tracker.py b/dwi_ml/tracking/tracker.py index 209ced21..68d7b958 100644 --- a/dwi_ml/tracking/tracker.py +++ b/dwi_ml/tracking/tracker.py @@ -7,10 +7,10 @@ import traceback from typing import List -from dipy.io.stateful_tractogram import Space, Origin from dipy.tracking.streamlinespeed import compress_streamlines import numpy as np import torch +from dwi_ml.tracking.utils import prepare_step_size_vox from torch import Tensor from tqdm.contrib.logging import tqdm_logging_redirect @@ -23,7 +23,6 @@ MainModelOneInput from dwi_ml.tracking.propagation import propagate_multiple_lines from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.utils import prepare_step_size_vox logger = logging.getLogger('tracker_logger') @@ -180,11 +179,6 @@ def __init__(self, dataset: MultisubjectSubset, subj_idx: int, self.model.eval() self.grad_context = torch.no_grad() - # Space and origin - # torch trilinear interpolation uses origin='corner', space=vox. - self.origin = Origin('corner') - self.space = Space.VOX - # Nb points if self.min_nbr_pts <= 0: logger.warning("Minimum number of points cannot be 0. Changed to " @@ -374,15 +368,13 @@ def _cpu_tracking(self, chunk_id): streamline = np.array(line, dtype='float32') if self.compression_th and self.compression_th > 0: - # Compressing. Threshold is in mm. Verifying space. - if self.space == Space.VOX: - # Equivalent of sft.to_voxmm: - streamline *= self.seed_generator.voxres - compress_streamlines(streamline, self.compression_th) - # Equivalent of sft.to_vox: - streamline /= self.seed_generator.voxres - else: - compress_streamlines(streamline, self.compression_th) + # Compressing. Threshold is in mm. Considering that we work + # in vox space, changing: + # Equivalent of sft.to_voxmm: + streamline *= self.seed_generator.voxres + compress_streamlines(streamline, self.compression_th) + # Equivalent of sft.to_vox: + streamline /= self.seed_generator.voxres streamlines.append(streamline) @@ -438,14 +430,14 @@ def _get_multiple_lines_both_directions(self, seeds: List[np.ndarray]): logger.debug("Starting forward") self.prepare_forward(seeds) - lines = self.propagate_multiple_lines(lines) + lines = self._propagate_multiple_lines(lines) if not self.track_forward_only: logger.debug("Starting backward") lines, rej_idx = self.prepare_backward(lines) if rej_idx is not None and len(rej_idx) > 0: seeds = [s for i, s in enumerate(seeds) if i not in rej_idx] - lines = self.propagate_multiple_lines(lines) + lines = self._propagate_multiple_lines(lines) # Clean streamlines # Max is already checked as stopping criteria. @@ -456,7 +448,7 @@ def _get_multiple_lines_both_directions(self, seeds: List[np.ndarray]): return clean_lines, clean_seeds - def propagate_multiple_lines(self, lines: List[Tensor]): + def _propagate_multiple_lines(self, lines: List[Tensor]): return propagate_multiple_lines( lines, self.update_memory_after_removing_lines, self.get_next_dirs, self.theta, self.step_size, diff --git a/dwi_ml/tracking/tracking_mask.py b/dwi_ml/tracking/tracking_mask.py index 1424f253..7bd2b11a 100644 --- a/dwi_ml/tracking/tracking_mask.py +++ b/dwi_ml/tracking/tracking_mask.py @@ -48,7 +48,7 @@ def get_value_at_vox_corner_coordinate(self, xyz, interpolation): else: return torch_trilinear_interpolation(self.data, xyz) - def is_in_mask(self, xyz): + def is_vox_corner_in_mask(self, xyz): # Clipping to bound. xyz = torch.maximum(xyz, self.lower_bound) xyz = torch.minimum(xyz, self.higher_bound - eps) diff --git a/dwi_ml/tracking/utils.py b/dwi_ml/tracking/utils.py index bb9c1186..65220a84 100644 --- a/dwi_ml/tracking/utils.py +++ b/dwi_ml/tracking/utils.py @@ -1,154 +1,5 @@ # -*- coding: utf-8 -*- import logging -from typing import Union - -from dipy.io.stateful_tractogram import (Space, Origin, set_sft_logger_level, - StatefulTractogram) -from dipy.io.streamline import save_tractogram -import nibabel as nib -import numpy as np - -from scilpy.tracking.seed import SeedGenerator - -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_arg_existing_experiment_path, add_memory_args -from dwi_ml.testing.utils import add_args_testing_subj_hdf5 -from dwi_ml.tracking.tracking_mask import TrackingMask - - -def add_tracking_options(p): - - add_arg_existing_experiment_path(p) - add_args_testing_subj_hdf5(p, ask_input_group=True) - - p.add_argument('out_tractogram', - help='Tractogram output file (must be .trk or .tck).') - p.add_argument('seeding_mask_group', - help="Seeding mask's volume group in the hdf5.") - - track_g = p.add_argument_group(' Tracking options') - track_g.add_argument('--algo', choices=['det', 'prob'], default='det', - help="Tracking algorithm (det or prob). Must be " - "implemented in the chosen model. [det]") - track_g.add_argument('--step_size', type=float, - help='Step size in mm. Default: using the step size ' - 'saved in the model parameters.') - track_g.add_argument('--track_forward_only', action='store_true', - help="If set, tracks in one direction only (forward) " - "given the initial \nseed. The direction is " - "randomly drawn from the ODF.") - track_g.add_argument('--mask_interp', default='trilinear', - choices=['nearest', 'trilinear'], - help="Mask interpolation: nearest-neighbor or " - "trilinear. [%(default)s]") - track_g.add_argument('--data_interp', default='trilinear', - choices=['nearest', 'trilinear'], - help="Input data interpolation: nearest-neighbor or " - "trilinear. [%(default)s]") - - stop_g = p.add_argument_group("Stopping criteria") - stop_g.add_argument('--min_length', type=float, default=10., - metavar='m', - help='Minimum length of a streamline in mm. ' - '[%(default)s]') - stop_g.add_argument('--max_length', type=float, default=300., - metavar='M', - help='Maximum length of a streamline in mm. ' - '[%(default)s]') - stop_g.add_argument('--tracking_mask_group', - help="Tracking mask's volume group in the hdf5.") - stop_g.add_argument('--theta', metavar='t', type=float, - default=90, - help="The tracking direction at each step being " - "defined by the model, \ntheta arg can't define " - "allowed directions in the tracking field.\n" - "Rather, this new equivalent angle, is used as " - "\na stopping criterion during propagation: " - "tracking \nis stopped when a direction is more " - "than an angle t from preceding direction") - stop_g.add_argument('--eos_stop', metavar='prob', - help="Stopping criterion if a EOS value was learned " - "during training. \nCan either be a probability " - "(default 0.5) or the string 'max', which will " - "\nstop the propagation if the EOS class's " - "probability is the class with maximal " - "probability, no mather its value.") - - r_g = p.add_argument_group(' Random seeding options') - r_g.add_argument('--rng_seed', type=int, - help='Initial value for the random number generator. ' - '[%(default)s]') - r_g.add_argument('--skip', type=int, default=0, - help="Skip the first N random numbers. \n" - "Useful if you want to create new streamlines to " - "add to \na previously created tractogram with a " - "fixed --rng_seed.\nEx: If tractogram_1 was created " - "with -nt 1,000,000, \nyou can create tractogram_2 " - "with \n--skip 1,000,000.") - - # Memory options: - m_g = add_memory_args(p, add_lazy_options=True, - add_multiprocessing_option=True, - add_rng=True) - m_g.add_argument('--simultaneous_tracking', type=int, default=1, - help='Track n streamlines at the same time. Intended for ' - 'GPU usage. Default = 1 (no simultaneous tracking).') - - return track_g - - -def prepare_seed_generator(parser, args, hdf_handle): - """ - Prepares a SeedGenerator from scilpy's library. Returns also some header - information to allow verifications. - """ - seeding_group = hdf_handle[args.subj_id][args.seeding_mask_group] - seed_data = np.array(seeding_group['data'], dtype=np.float32) - seed_res = np.array(seeding_group.attrs['voxres'], dtype=np.float32) - affine = np.array(seeding_group.attrs['affine'], dtype=np.float32) - ref = nib.Nifti1Image(seed_data, affine) - - seed_generator = SeedGenerator(seed_data, seed_res, space=Space.VOX, - origin=Origin('corner')) - - if len(seed_generator.seeds_vox) == 0: - parser.error('Seed mask "{}" does not have any voxel with value > 0.' - .format(args.in_seed)) - - if args.npv: - # Note. Not really nb seed per voxel, just in average. - nbr_seeds = len(seed_generator.seeds_vox) * args.npv - elif args.nt: - nbr_seeds = args.nt - else: - # Setting npv = 1. - nbr_seeds = len(seed_generator.seeds_vox) - - seed_header = nib.Nifti1Image(seed_data, affine).header - - return seed_generator, nbr_seeds, seed_header, ref - - -def prepare_tracking_mask(hdf_handle, tracking_mask_group, subj_id, mask_interp): - """ - Prepare the tracking mask as a DataVolume from scilpy's library. Returns - also some header information to allow verifications. - """ - if subj_id not in hdf_handle: - raise KeyError("Subject {} not found in {}. Possible subjects are: {}" - .format(subj_id, hdf_handle, list(hdf_handle.keys()))) - if tracking_mask_group not in hdf_handle[subj_id]: - raise KeyError("HDF group '{}' not found for subject {} in hdf file {}" - .format(tracking_mask_group, subj_id, hdf_handle)) - tm_group = hdf_handle[subj_id][tracking_mask_group] - mask_data = np.array(tm_group['data'], dtype=np.float64).squeeze() - # mask_res = np.array(tm_group.attrs['voxres'], dtype=np.float32) - affine = np.array(tm_group.attrs['affine'], dtype=np.float32) - ref = nib.Nifti1Image(mask_data, affine) - - mask = TrackingMask(mask_data.shape, mask_data, mask_interp) - - return mask, ref def prepare_step_size_vox(step_size, res): @@ -169,35 +20,3 @@ def prepare_step_size_vox(step_size, res): return step_size_vox_space, normalize_directions - -def track_and_save(tracker, args, ref): - with Timer("\nTracking...", newline=True, color='blue'): - streamlines, seeds = tracker.track() - - logging.debug("Tracked {} streamlines (out of {} seeds). Now saving..." - .format(len(streamlines), tracker.nbr_seeds)) - - if len(streamlines) == 0: - logging.warning("No streamlines created! Not saving tractogram!") - return - - # save seeds if args.save_seeds is given - # Seeds must be saved in voxel space (ok!), but origin: center, if we want - # to use scripts such as scil_compute_seed_density_map. - if args.save_seeds: - print("Saving seeds in data_per_streamline.") - seeds = [np.asarray(seed) - 0.5 for seed in seeds] # to_center - data_per_streamline = {'seeds': seeds} - else: - data_per_streamline = {} - - # Silencing SFT's logger if our logging is in DEBUG mode, because it - # typically produces a lot of outputs! - set_sft_logger_level('WARNING') - - logging.info("Saving resulting tractogram to {}" - .format(args.out_tractogram)) - sft = StatefulTractogram(streamlines, ref, space=Space.VOX, - origin=Origin('corner'), - data_per_streamline=data_per_streamline) - save_tractogram(sft, args.out_tractogram, bbox_valid_check=False) diff --git a/dwi_ml/training/projects/trainers_for_generation.py b/dwi_ml/training/projects/trainers_for_generation.py index 559678a6..4e7e4888 100644 --- a/dwi_ml/training/projects/trainers_for_generation.py +++ b/dwi_ml/training/projects/trainers_for_generation.py @@ -11,7 +11,7 @@ from dwi_ml.experiment_utils.tqdm_logging import tqdm_logging_redirect from dwi_ml.models.main_models import ModelWithDirectionGetter from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.tracking.utils import prepare_tracking_mask +from dwi_ml.tracking.projects.utils import prepare_tracking_mask from dwi_ml.training.trainers import DWIMLTrainerOneInput from dwi_ml.training.utils.monitoring import BatchHistoryMonitor, TimeMonitor diff --git a/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py b/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py index 5d7f8af1..eae3faf8 100755 --- a/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py +++ b/dwi_ml/unit_tests/visual_tests/analyze_batch_loader_visually.py @@ -8,11 +8,12 @@ import nibabel as nib import numpy as np -from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin +from dipy.io.stateful_tractogram import StatefulTractogram from dipy.io.streamline import save_tractogram from dwi_ml.data.dataset.multi_subject_containers import MultiSubjectDataset from dwi_ml.models.main_models import MainModelOneInput +from dwi_ml.tracking.projects.utils import ALWAYS_VOX_SPACE, ALWAYS_CORNER from dwi_ml.unit_tests.utils.data_and_models_for_tests import ( create_test_batch_sampler, create_batch_loader, fetch_testing_data, ModelForTest) @@ -92,8 +93,8 @@ def _load_directly_and_verify(batch_loader, batch_idx_tuples, ref, suffix): # Save the last batch's SFT. logging.info("Saving subj's tractogram {}".format('test_batch1_' + suffix)) - sft = StatefulTractogram(batch_streamlines, reference=ref, space=Space.VOX, - origin=Origin('corner')) + sft = StatefulTractogram(batch_streamlines, reference=ref, + space=ALWAYS_VOX_SPACE, origin=ALWAYS_CORNER) sft.data_per_streamline = {"seeds": [s[0] for s in batch_streamlines]} filename = os.path.join(results_folder, 'test_batch_' + suffix + '.trk') diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py index 24088e25..e4523736 100644 --- a/scripts_python/l2t_track_from_model.py +++ b/scripts_python/l2t_track_from_model.py @@ -28,14 +28,14 @@ from dwi_ml.testing.utils import prepare_dataset_one_subj from dwi_ml.tracking.projects.learn2track_tracker import RecurrentTracker from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.utils import (add_tracking_options, - prepare_seed_generator, - prepare_tracking_mask, track_and_save) +from dwi_ml.tracking.projects.utils import (add_tracking_options, + prepare_seed_generator, + prepare_tracking_mask, track_and_save) # A decision should be made as if we should keep the last point (out of the # tracking mask). Currently keeping this as in Dipy, i.e. True. Could be # an option for the user. -APPEND_LAST_POINT = True +APPEND_LAST_POINT = False def build_argparser(): diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index 6e68653f..38cc1215 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -11,7 +11,6 @@ # comet_ml not used, but comet_ml requires to be imported before torch. # See bug report here https://github.com/Lightning-AI/lightning/issues/5829 # Importing now to solve issues later. -import comet_ml import torch from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist @@ -23,7 +22,6 @@ from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.models.projects.learn2track_utils import add_model_args from dwi_ml.models.utils.direction_getters import check_args_direction_getter -from dwi_ml.tracking.utils import prepare_tracking_mask from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, prepare_batch_sampler) diff --git a/scripts_python/tto_track_from_model.py b/scripts_python/tto_track_from_model.py index b6b8ed09..7acb268f 100644 --- a/scripts_python/tto_track_from_model.py +++ b/scripts_python/tto_track_from_model.py @@ -30,9 +30,9 @@ from dwi_ml.tracking.projects.transformer_tracker import \ TransformerTracker from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.utils import (add_tracking_options, - prepare_seed_generator, - prepare_tracking_mask, track_and_save) +from dwi_ml.tracking.projects.utils import (add_tracking_options, + prepare_seed_generator, + prepare_tracking_mask, track_and_save) # A decision should be made as if we should keep the last point (out of the # tracking mask). Currently keeping this as in Dipy, i.e. True. Could be diff --git a/scripts_python/ttst_track_from_model.py b/scripts_python/ttst_track_from_model.py index f4e06ca9..734eef0d 100644 --- a/scripts_python/ttst_track_from_model.py +++ b/scripts_python/ttst_track_from_model.py @@ -30,9 +30,9 @@ from dwi_ml.tracking.projects.transformer_tracker import \ TransformerTracker from dwi_ml.tracking.tracking_mask import TrackingMask -from dwi_ml.tracking.utils import (add_tracking_options, - prepare_seed_generator, - prepare_tracking_mask, track_and_save) +from dwi_ml.tracking.projects.utils import (add_tracking_options, + prepare_seed_generator, + prepare_tracking_mask, track_and_save) # A decision should be made as if we should keep the last point (out of the # tracking mask). Currently keeping this as in Dipy, i.e. True. Could be From 61ad86f61a87467c8423f9f9e1b0706cfc61fef6 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Jun 2023 14:41:49 -0400 Subject: [PATCH 2/3] only one append_last_point variable --- dwi_ml/tracking/projects/utils.py | 5 +++++ scripts_python/l2t_track_from_model.py | 8 ++------ scripts_python/tto_track_from_model.py | 9 +++------ scripts_python/ttst_track_from_model.py | 9 +++------ 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/dwi_ml/tracking/projects/utils.py b/dwi_ml/tracking/projects/utils.py index 0c2749dd..f75a7d52 100644 --- a/dwi_ml/tracking/projects/utils.py +++ b/dwi_ml/tracking/projects/utils.py @@ -20,6 +20,11 @@ ALWAYS_VOX_SPACE = Space.VOX ALWAYS_CORNER = Origin('corner') +# A decision should be made as if we should keep the last point (out of the +# tracking mask). Currently keeping this as in Dipy, i.e. True. Could be +# an option for the user. +APPEND_LAST_POINT = True # See here: https://github.com/dipy/dipy/discussions/2764 + def add_tracking_options(p): diff --git a/scripts_python/l2t_track_from_model.py b/scripts_python/l2t_track_from_model.py index e4523736..bc2c9172 100644 --- a/scripts_python/l2t_track_from_model.py +++ b/scripts_python/l2t_track_from_model.py @@ -30,12 +30,8 @@ from dwi_ml.tracking.tracking_mask import TrackingMask from dwi_ml.tracking.projects.utils import (add_tracking_options, prepare_seed_generator, - prepare_tracking_mask, track_and_save) - -# A decision should be made as if we should keep the last point (out of the -# tracking mask). Currently keeping this as in Dipy, i.e. True. Could be -# an option for the user. -APPEND_LAST_POINT = False + prepare_tracking_mask, track_and_save, + APPEND_LAST_POINT) def build_argparser(): diff --git a/scripts_python/tto_track_from_model.py b/scripts_python/tto_track_from_model.py index 7acb268f..35fde512 100644 --- a/scripts_python/tto_track_from_model.py +++ b/scripts_python/tto_track_from_model.py @@ -32,12 +32,9 @@ from dwi_ml.tracking.tracking_mask import TrackingMask from dwi_ml.tracking.projects.utils import (add_tracking_options, prepare_seed_generator, - prepare_tracking_mask, track_and_save) - -# A decision should be made as if we should keep the last point (out of the -# tracking mask). Currently keeping this as in Dipy, i.e. True. Could be -# an option for the user. -APPEND_LAST_POINT = True + prepare_tracking_mask, + track_and_save, + APPEND_LAST_POINT) def build_argparser(): diff --git a/scripts_python/ttst_track_from_model.py b/scripts_python/ttst_track_from_model.py index 734eef0d..eb20ebc0 100644 --- a/scripts_python/ttst_track_from_model.py +++ b/scripts_python/ttst_track_from_model.py @@ -32,12 +32,9 @@ from dwi_ml.tracking.tracking_mask import TrackingMask from dwi_ml.tracking.projects.utils import (add_tracking_options, prepare_seed_generator, - prepare_tracking_mask, track_and_save) - -# A decision should be made as if we should keep the last point (out of the -# tracking mask). Currently keeping this as in Dipy, i.e. True. Could be -# an option for the user. -APPEND_LAST_POINT = True + prepare_tracking_mask, + track_and_save, + APPEND_LAST_POINT) def build_argparser(): From 796ac7ecca00764ffddaed605f48fd9ae880793f Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Jun 2023 14:43:30 -0400 Subject: [PATCH 3/3] update github's python version --- .github/workflows/test_package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_package.yml b/.github/workflows/test_package.yml index f1702318..46842849 100644 --- a/.github/workflows/test_package.yml +++ b/.github/workflows/test_package.yml @@ -11,7 +11,7 @@ jobs: # max-parallel: 6 matrix: os: [ubuntu-latest] - python-version: [3.10.11] + python-version: [3.10.12] requires: ['latest'] steps: