diff --git a/dwi_ml/training/projects/learn2track_trainer.py b/dwi_ml/training/projects/learn2track_trainer.py index 5267b3ce..37dcc5e3 100644 --- a/dwi_ml/training/projects/learn2track_trainer.py +++ b/dwi_ml/training/projects/learn2track_trainer.py @@ -2,11 +2,13 @@ import logging from typing import List +import h5py import numpy as np import torch from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.tracking.propagation import propagate_multiple_lines +from dwi_ml.tracking.utils import prepare_tracking_mask from dwi_ml.training.with_generation.trainer import \ DWIMLTrainerForTrackingOneInput @@ -92,10 +94,24 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): theta = 2 * np.pi # theta = 360 degrees max_nbr_pts = int(200 / self.model.step_size) - results = propagate_multiple_lines( - lines, update_memory_after_removing_lines, get_dirs_at_last_pos, - theta=theta, step_size=self.model.step_size, - verify_opposite_direction=False, mask=self.tracking_mask, - max_nbr_pts=max_nbr_pts, append_last_point=False, - normalize_directions=True) - return results + + final_lines = [] + for subj_idx, line_idx in ids_per_subj.items(): + + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + subj_id = self.batch_loader.context_subset.subjects[subj_idx] + logging.debug("Loading subj {} ({})'s tracking mask." + .format(subj_idx, subj_id)) + tracking_mask, _ = prepare_tracking_mask( + hdf_handle, self.tracking_mask_group, subj_id=subj_id, + mask_interp='nearest') + tracking_mask.move_to(self.device) + + final_lines.extend(propagate_multiple_lines( + lines[line_idx], update_memory_after_removing_lines, + get_dirs_at_last_pos, theta=theta, + step_size=self.model.step_size, verify_opposite_direction=False, + mask=tracking_mask, max_nbr_pts=max_nbr_pts, + append_last_point=False, normalize_directions=True)) + + return final_lines diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index 086c0965..ef28eaa0 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -100,26 +100,7 @@ def __init__(self, add_a_tracking_validation_phase: bool = False, self.tracking_phase_nb_steps_init = tracking_phase_nb_steps_init self.tracking_mask_group = tracking_phase_mask_group - self.tracking_mask = None - if add_a_tracking_validation_phase: - # Right now, using any subject's, and supposing that they are all - # in the same space. Else, code would need refactoring to allow - # tracking on multiple subjects. Or we can loop on each subject. - logging.warning("***************\n" - "CODE NEEDS REFACTORING. USING THE SAME TRACKING " - "MASK FOR ALL SUBJECTS.\n" - "***************\n") - any_subj = self.batch_loader.dataset.training_set.subjects[0] - if tracking_phase_mask_group is not None: - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') \ - as hdf_handle: - logging.info("Loading tracking mask.") - self.tracking_mask, _ = prepare_tracking_mask( - hdf_handle, tracking_phase_mask_group, subj_id=any_subj, - mask_interp='nearest') - self.tracking_mask.move_to(self.device) - - self.compute_connectivity = self.batch_loader.data_contains_connectivity + self.compute_connectivity = self.batch_loader.data_contains_connectivity # -------- Monitors # At training time: only the one metric used for training. @@ -387,9 +368,28 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): theta = 2 * np.pi # theta = 360 degrees max_nbr_pts = int(200 / self.model.step_size) - return propagate_multiple_lines( - lines, update_memory_after_removing_lines, get_dirs_at_last_pos, - theta=theta, step_size=self.model.step_size, - verify_opposite_direction=False, mask=self.tracking_mask, - max_nbr_pts=max_nbr_pts, append_last_point=False, - normalize_directions=True) + + # Looping on subjects because current implementation requires a single + # tracking mask. But all the rest (get_dirs_at_last_pos, particularly) + # work on multiple subjects because the batch loader loads input + # according to subject id. Could refactor "propagate_multiple_line" to + # accept multiple masks or manage it differently. + final_lines = [] + for subj_idx, line_idx in ids_per_subj.items(): + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + subj_id = self.batch_loader.context_subset.subjects[subj_idx] + logging.debug("Loading subj {} ({})'s tracking mask." + .format(subj_idx, subj_id)) + tracking_mask, _ = prepare_tracking_mask( + hdf_handle, self.tracking_mask_group, subj_id=subj_id, + mask_interp='nearest') + tracking_mask.move_to(self.device) + + final_lines.extend(propagate_multiple_lines( + lines[line_idx], update_memory_after_removing_lines, + get_dirs_at_last_pos, theta=theta, + step_size=self.model.step_size, verify_opposite_direction=False, + mask=tracking_mask, max_nbr_pts=max_nbr_pts, + append_last_point=False, normalize_directions=True)) + + return final_lines