Skip to content

Commit

Permalink
Load the tracking mask correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Jun 12, 2023
1 parent 8665675 commit db0345d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 39 deletions.
30 changes: 23 additions & 7 deletions dwi_ml/training/projects/learn2track_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import logging
from typing import List

import h5py
import numpy as np
import torch

from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.tracking.utils import prepare_tracking_mask
from dwi_ml.training.with_generation.trainer import \
DWIMLTrainerForTrackingOneInput

Expand Down Expand Up @@ -92,10 +94,24 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):

theta = 2 * np.pi # theta = 360 degrees
max_nbr_pts = int(200 / self.model.step_size)
results = propagate_multiple_lines(
lines, update_memory_after_removing_lines, get_dirs_at_last_pos,
theta=theta, step_size=self.model.step_size,
verify_opposite_direction=False, mask=self.tracking_mask,
max_nbr_pts=max_nbr_pts, append_last_point=False,
normalize_directions=True)
return results

final_lines = []
for subj_idx, line_idx in ids_per_subj.items():

with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle:
subj_id = self.batch_loader.context_subset.subjects[subj_idx]
logging.debug("Loading subj {} ({})'s tracking mask."
.format(subj_idx, subj_id))
tracking_mask, _ = prepare_tracking_mask(
hdf_handle, self.tracking_mask_group, subj_id=subj_id,
mask_interp='nearest')
tracking_mask.move_to(self.device)

final_lines.extend(propagate_multiple_lines(
lines[line_idx], update_memory_after_removing_lines,
get_dirs_at_last_pos, theta=theta,
step_size=self.model.step_size, verify_opposite_direction=False,
mask=tracking_mask, max_nbr_pts=max_nbr_pts,
append_last_point=False, normalize_directions=True))

return final_lines
30 changes: 24 additions & 6 deletions dwi_ml/training/projects/transformer_trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import h5py
import numpy as np
import torch
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.tracking.utils import prepare_tracking_mask

from dwi_ml.training.with_generation.trainer import \
DWIMLTrainerForTrackingOneInput
Expand Down Expand Up @@ -45,9 +48,24 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):

theta = 2 * np.pi # theta = 360 degrees
max_nbr_pts = int(200 / self.model.step_size)
return propagate_multiple_lines(
lines, update_memory_after_removing_lines, get_dirs_at_last_pos,
theta=theta, step_size=self.model.step_size,
verify_opposite_direction=False, mask=self.tracking_mask,
max_nbr_pts=max_nbr_pts, append_last_point=False,
normalize_directions=True)

final_lines = []
for subj_idx, line_idx in ids_per_subj.items():

with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle:
subj_id = self.batch_loader.context_subset.subjects[subj_idx]
logging.debug("Loading subj {} ({})'s tracking mask."
.format(subj_idx, subj_id))
tracking_mask, _ = prepare_tracking_mask(
hdf_handle, self.tracking_mask_group, subj_id=subj_id,
mask_interp='nearest')
tracking_mask.move_to(self.device)

final_lines.extend(propagate_multiple_lines(
lines[line_idx], update_memory_after_removing_lines,
get_dirs_at_last_pos, theta=theta,
step_size=self.model.step_size, verify_opposite_direction=False,
mask=tracking_mask, max_nbr_pts=max_nbr_pts,
append_last_point=False, normalize_directions=True))

return final_lines
52 changes: 26 additions & 26 deletions dwi_ml/training/with_generation/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,26 +100,7 @@ def __init__(self, add_a_tracking_validation_phase: bool = False,
self.tracking_phase_nb_steps_init = tracking_phase_nb_steps_init
self.tracking_mask_group = tracking_phase_mask_group

self.tracking_mask = None
if add_a_tracking_validation_phase:
# Right now, using any subject's, and supposing that they are all
# in the same space. Else, code would need refactoring to allow
# tracking on multiple subjects. Or we can loop on each subject.
logging.warning("***************\n"
"CODE NEEDS REFACTORING. USING THE SAME TRACKING "
"MASK FOR ALL SUBJECTS.\n"
"***************\n")
any_subj = self.batch_loader.dataset.training_set.subjects[0]
if tracking_phase_mask_group is not None:
with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') \
as hdf_handle:
logging.info("Loading tracking mask.")
self.tracking_mask, _ = prepare_tracking_mask(
hdf_handle, tracking_phase_mask_group, subj_id=any_subj,
mask_interp='nearest')
self.tracking_mask.move_to(self.device)

self.compute_connectivity = self.batch_loader.data_contains_connectivity
self.compute_connectivity = self.batch_loader.data_contains_connectivity

# -------- Monitors
# At training time: only the one metric used for training.
Expand Down Expand Up @@ -387,9 +368,28 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos):

theta = 2 * np.pi # theta = 360 degrees
max_nbr_pts = int(200 / self.model.step_size)
return propagate_multiple_lines(
lines, update_memory_after_removing_lines, get_dirs_at_last_pos,
theta=theta, step_size=self.model.step_size,
verify_opposite_direction=False, mask=self.tracking_mask,
max_nbr_pts=max_nbr_pts, append_last_point=False,
normalize_directions=True)

# Looping on subjects because current implementation requires a single
# tracking mask. But all the rest (get_dirs_at_last_pos, particularly)
# work on multiple subjects because the batch loader loads input
# according to subject id. Could refactor "propagate_multiple_line" to
# accept multiple masks or manage it differently.
final_lines = []
for subj_idx, line_idx in ids_per_subj.items():
with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle:
subj_id = self.batch_loader.context_subset.subjects[subj_idx]
logging.debug("Loading subj {} ({})'s tracking mask."
.format(subj_idx, subj_id))
tracking_mask, _ = prepare_tracking_mask(
hdf_handle, self.tracking_mask_group, subj_id=subj_id,
mask_interp='nearest')
tracking_mask.move_to(self.device)

final_lines.extend(propagate_multiple_lines(
lines[line_idx], update_memory_after_removing_lines,
get_dirs_at_last_pos, theta=theta,
step_size=self.model.step_size, verify_opposite_direction=False,
mask=tracking_mask, max_nbr_pts=max_nbr_pts,
append_last_point=False, normalize_directions=True))

return final_lines

0 comments on commit db0345d

Please sign in to comment.