Merge pull request #206 from EmmaRenauld/Test_Gaussian
Finalize Gaussian model.
EmmaRenauld authored Oct 18, 2023
2 parents b7ac03a + 36b1c00 commit b07544b
Showing 22 changed files with 398 additions and 213 deletions.
5 changes: 3 additions & 2 deletions dwi_ml/data/processing/streamlines/post_processing.py
@@ -125,7 +125,7 @@ def compute_directions(streamlines):
return batch_directions


def normalize_directions(directions):
def normalize_directions(directions, new_norm=1.0):
"""
Params
------
@@ -136,8 +136,9 @@ def normalize_directions(directions):
# propagation will fail.
directions = directions / torch.linalg.norm(directions, dim=-1,
keepdim=True)
directions *= new_norm
else:
directions = [s / torch.linalg.norm(s, dim=-1, keepdim=True)
directions = [s / torch.linalg.norm(s, dim=-1, keepdim=True) * new_norm
for s in directions]

return directions
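For context, a small usage sketch of the new new_norm parameter (toy data; the assumption that a single (n, 3) tensor takes the first branch above is ours). Each direction is rescaled to the requested norm instead of unit length:

    import torch
    from dwi_ml.data.processing.streamlines.post_processing import \
        normalize_directions

    dirs = torch.tensor([[3., 0., 0.], [0., 4., 0.]])
    rescaled = normalize_directions(dirs, new_norm=0.5)
    print(torch.linalg.norm(rescaled, dim=-1))  # tensor([0.5000, 0.5000])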
220 changes: 147 additions & 73 deletions dwi_ml/models/direction_getter_models.py

Large diffs are not rendered by default.
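Since this file's diff is not rendered, here is a rough sketch, not the repository's actual code, of the loss that the new --add_entropy_to_gauss option (below) describes: a Gaussian negative log-likelihood with a weighted entropy term added. Shapes and sign conventions are assumptions:

    import torch
    from torch.distributions import Normal

    def gaussian_nll_with_entropy(means, sigmas, targets, entropy_weight=1.0):
        # Independent Gaussian per coordinate; sum over the (x, y, z) axis.
        dist = Normal(means, sigmas)
        nll = -dist.log_prob(targets).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return nll + entropy_weight * entropy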

22 changes: 18 additions & 4 deletions dwi_ml/models/projects/learn2track_model.py
@@ -162,6 +162,13 @@ def __init__(self, experiment_name,
if dropout < 0 or dropout > 1:
raise ValueError('The dropout rate must be between 0 and 1.')

if start_from_copy_prev and 'gaussian' in dg_key:
raise ValueError("'start_from_copy_prev' makes no sense with "
"Gaussian direction getters.")
if start_from_copy_prev and 'fisher' in dg_key:
raise ValueError("'start_from_copy_prev' makes no sense with "
"Fisher von Mises direction getters.")

# ---------- Instantiations
# 1. Previous dirs embedding: prepared by super.

@@ -366,18 +373,25 @@ def forward(self, x: List[torch.tensor],
x = self.direction_getter(x)

# Adding either prev_dir or 0.
x = x + copy_prev_dir
if self.start_from_copy_prev:
x = x + copy_prev_dir

# Unpacking.
if not self._context == 'tracking':
# (during tracking: keeping as one single tensor.)
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
# Separating mean, sigmas (gaussian) or mean, kappa (fisher)
x, x2 = x

x2 = PackedSequence(x2, batch_sizes)
x2 = faster_unpack_sequence(x2)
x2 = [x2[i] for i in unsorted_indices]
x = PackedSequence(x, batch_sizes)
x = faster_unpack_sequence(x)
x = [x[i] for i in unsorted_indices]

assert x[0].shape[-1] == self.direction_getter.output_size, \
"Expecting output size of {}. Got {}" \
.format(self.direction_getter.output_size, x[0].shape[-1])
if 'gaussian' in self.dg_key or 'fisher' in self.dg_key:
x = (x, x2)

if return_hidden:
# Return the hidden states too. Necessary for the generative
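A minimal sketch of the two-tensor unpacking pattern above, using stock PyTorch utilities instead of the repository's faster_unpack_sequence (toy shapes; the (means, sigmas) pairing matches the diff, the rest is illustrative):

    import torch
    from torch.nn.utils.rnn import pack_sequence, unpack_sequence

    means = [torch.rand(5, 3), torch.rand(8, 3)]   # one tensor per streamline
    sigmas = [torch.rand(5, 3), torch.rand(8, 3)]

    packed_m = pack_sequence(means, enforce_sorted=False)
    packed_s = pack_sequence(sigmas, enforce_sorted=False)

    # Gaussian / Fisher models carry two tensors, so each one is unpacked
    # separately, then re-paired, as forward() does above.
    x = (unpack_sequence(packed_m), unpack_sequence(packed_s))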
50 changes: 35 additions & 15 deletions dwi_ml/models/utils/direction_getters.py
@@ -35,23 +35,29 @@ def add_direction_getter_args(p: ArgumentParser, gaussian_fisher_args=True):

# Gaussian models, Fisher-von-Mises models
if gaussian_fisher_args:
p.add_argument(
'--add_entropy_to_gauss', nargs='?', const=1.0, type=float,
metavar='f',
help="For GAUSSIAN models: If set, adds the entropy to the negative "
"log-likelihood \nloss. By default, the weight is 1.0, but a "
"value > 1 can be used \nto increase its influence.")
p.add_argument(
'--dg_nb_gaussians', type=int, metavar='n',
help="Number of gaussians in the case of a Gaussian Mixture model "
"for the direction \ngetter. [3]")
help="For GAUSSIAN models: Number of gaussians in the case of a "
"mixture model. [3]")
p.add_argument(
'--dg_nb_clusters', type=int,
help="Number of clusters in the case of a Fisher von Mises "
"Mixture model for the direction \ngetter. [3].")
help="For FISHER VON MISES models: Number of clusters in the case "
"of a mixture model for the direction \ngetter. [3]")
p.add_argument(
'--normalize_targets', const=1., nargs='?', type=float,
metavar='norm',
help="For REGRESSION models: If set, target directions will be "
help="For REGRESSION models: If set, target directions will be "
"normalized before \ncomputing the loss. Default norm: 1.")
p.add_argument(
'--normalize_outputs', const=1., nargs='?', type=float,
metavar='norm',
help="For REGRESSION models: If set, model outputs will be "
help="For REGRESSION models: If set, model outputs will be "
"normalized. Default norm: 1.")

# EOS
@@ -82,14 +88,22 @@ def check_args_direction_getter(args):
if args.dg_dropout < 0 or args.dg_dropout > 1:
raise ValueError('The dg dropout rate must be between 0 and 1.')

# Gaussian additional arg = nb_gaussians.
# Gaussian additional arg = nb_gaussians and entropy_weight.
if args.dg_key == 'gaussian-mixture':
if args.dg_nb_gaussians:
dg_args.update({'nb_gaussians': args.dg_nb_gaussians})
elif args.dg_nb_gaussians:
logging.warning("You have provided a value for --dg_nb_gaussians but "
"the chosen direction getter is not the gaussian "
"mixture. Ignored.")
if args.add_entropy_to_gauss:
dg_args.update({'entropy_weight': args.add_entropy_to_gauss})

else:
if args.dg_nb_gaussians:
logging.warning("You have provided a value for --dg_nb_gaussians "
"but the chosen direction getter is not the "
"gaussian mixture. Ignored.")
if args.add_entropy_to_gauss:
logging.warning("You have provided a value for --add_entropy_to_gauss "
"but the chosen direction getter is not the "
"gaussian mixture. Ignored.")

# Fisher additional arg = nb_clusters
if args.dg_key == 'fisher-von-mises-mixture':
@@ -101,10 +115,16 @@
"Mises mixture. Ignored.")

# Regression and normalisation
if 'regression' in args.dg_key or 'gaussian' in args.dg_key:
dg_args['normalize_targets'] = args.normalize_targets
elif args.normalize_targets:
raise ValueError("--normalize_targets is only an option for "
"regression and gaussian models.")

if 'regression' in args.dg_key:
dg_args.update({
'normalize_targets': args.normalize_targets,
'normalize_outputs': args.normalize_outputs,
})
dg_args['normalize_outputs'] = args.normalize_outputs
elif args.normalize_outputs is not None:
raise ValueError("--normalize_outputs is only an option for "
"regression models.")

return dg_args
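A hedged end-to-end sketch of the argument flow above; the assumption that --dg_key and --dg_dropout are registered by the same helper is ours (their registration is not shown in this diff):

    from argparse import ArgumentParser
    from dwi_ml.models.utils.direction_getters import (
        add_direction_getter_args, check_args_direction_getter)

    p = ArgumentParser()
    add_direction_getter_args(p)
    # Assumes --dg_key / --dg_dropout exist on the parser (not shown here).
    args = p.parse_args(['--dg_key', 'gaussian-mixture',
                         '--dg_nb_gaussians', '4',
                         '--add_entropy_to_gauss', '1.5'])
    dg_args = check_args_direction_getter(args)
    # Under these assumptions, dg_args now holds
    # {'nb_gaussians': 4, 'entropy_weight': 1.5, ...}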
20 changes: 17 additions & 3 deletions dwi_ml/testing/testers.py
@@ -119,7 +119,7 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
force_compress_loss: bool
If true, compresses the loss even if the model's parameter
does not ask for it.
change_weight_with_angle: bool
weight_with_angle: bool
If true, modifies the model's weight_loss_with_angle param.
"""
if uncompress_loss and force_compress_loss:
@@ -141,9 +141,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
batch_size = self.batch_size or len(sft)
nb_batches = int(np.ceil(len(sft) / batch_size))

if 'gaussian' in self.model.direction_getter.key:
outputs = ([], [])
elif 'fisher' in self.model.direction_getter.key:
raise NotImplementedError
else:
outputs = []

losses = []
compressed_n = []
outputs = []
batch_start = 0
batch_end = batch_size
with torch.no_grad():
@@ -180,7 +186,15 @@ def run_model_on_sft(self, sft, add_zeros_if_no_eos=True,
losses.extend([line_loss.cpu() for line_loss in
tmp_losses])

outputs.extend([o.cpu() for o in tmp_outputs])
# ToDo. See if we can simplify to fit with all models
if 'gaussian' in self.model.direction_getter.key:
tmp_means, tmp_sigmas = tmp_outputs
outputs[0].extend([m.cpu() for m in tmp_means])
outputs[1].extend([s.cpu() for s in tmp_sigmas])
elif 'fisher' in self.model.direction_getter.key:
raise NotImplementedError
else:
outputs.extend([o.cpu() for o in tmp_outputs])

batch_start = batch_end
batch_end = min(batch_start + batch_size, len(sft))
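In short, the accumulation above keeps a pair of lists for Gaussian models and a flat list otherwise. A toy sketch of the Gaussian branch (stand-in data, not the tester's real loop):

    import torch

    batches = [([torch.rand(5, 3)], [torch.rand(5, 3)])]  # (means, sigmas)
    outputs = ([], [])
    for tmp_means, tmp_sigmas in batches:
        outputs[0].extend(m.cpu() for m in tmp_means)
        outputs[1].extend(s.cpu() for s in tmp_sigmas)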
28 changes: 23 additions & 5 deletions dwi_ml/testing/visu_loss.py
@@ -161,6 +161,7 @@ def combine_displacement_with_ref(out_dirs, sft, step_size_mm=None):
color_x = []
color_y = []
color_z = []

for i, s in enumerate(sft.streamlines):
this_s_len = len(s)

@@ -253,20 +254,37 @@ def run_visu_save_colored_displacement(
save_tractogram(worst_sft, worst_sft_name)

# Save displacement
args.pick_idx = list(range(10))
if args.out_displacement_sft:
if args.out_colored_sft:
# We have run model on all streamlines. Picking a few now.
sft, idx = pick_a_few(
sft, best_idx, worst_idx, args.pick_at_random,
args.pick_best_and_worst, args.pick_idx)
outputs = [outputs[i] for i in idx]

# Either concat, run, split or (chosen:) loop
# ToDo. See if we can simplify to fit with all models
if 'gaussian' in model.direction_getter.key:
means, sigmas = outputs
means = [means[i] for i in idx]
lengths = [len(line) for line in means]
outputs = (torch.vstack(means),
torch.vstack([sigmas[i] for i in idx]))

elif 'fisher' in model.direction_getter.key:
raise NotImplementedError
else:
outputs = [outputs[i] for i in idx]
lengths = [len(line) for line in outputs]
outputs = torch.vstack(outputs)

# Use eos_thresh of 1 to be sure we don't output a NaN
with torch.no_grad():
out_dirs = [model.get_tracking_directions(
s_output, algo='det', eos_stopping_thresh=1.0).numpy()
for s_output in outputs]
out_dirs = model.get_tracking_directions(
outputs, algo='det', eos_stopping_thresh=1.0)

out_dirs = torch.split(out_dirs, lengths)

out_dirs = [o.numpy() for o in out_dirs]

# Save error together with ref
sft = combine_displacement_with_ref(out_dirs, sft, model.step_size)
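A minimal sketch of the vstack/split round-trip used above (toy data): per-line tensors are stacked so get_tracking_directions runs once on a single tensor, then the result is split back into lines of their original lengths:

    import torch

    lines = [torch.rand(5, 3), torch.rand(8, 3)]
    lengths = [len(line) for line in lines]
    flat = torch.vstack(lines)              # shape (13, 3)
    restored = torch.split(flat, lengths)   # back to one tensor per line
    assert all(r.shape == s.shape for r, s in zip(restored, lines))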
@@ -114,6 +114,12 @@ def prepare_seed_generator(parser, args, hdf_handle):
Prepares a SeedGenerator from scilpy's library. Returns also some header
information to allow verifications.
"""
if args.subj_id not in hdf_handle:
raise ValueError("Subject {} not found in the HDF5 file."
.format(args.subj_id))
if args.seeding_mask_group not in hdf_handle[args.subj_id]:
raise ValueError("Seeding mask {} not found the subject's HDF group."
.format(args.seeding_mask_group))
seeding_group = hdf_handle[args.subj_id][args.seeding_mask_group]
seed_data = np.array(seeding_group['data'], dtype=np.float32)
seed_res = np.array(seeding_group.attrs['voxres'], dtype=np.float32)
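The new checks rely on h5py's membership test: 'in' on a File or Group verifies a key exists before indexing. A self-contained sketch with a throwaway file (names are hypothetical):

    import h5py

    with h5py.File('example.hdf5', 'w') as f:
        f.create_group('subj1').create_group('seeding_mask')

    with h5py.File('example.hdf5', 'r') as f:
        assert 'subj1' in f
        assert 'seeding_mask' in f['subj1']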
40 changes: 2 additions & 38 deletions dwi_ml/training/projects/learn2track_trainer.py
@@ -7,7 +7,7 @@
import torch

from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.tracking.projects.utils import prepare_tracking_mask
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines
from dwi_ml.training.with_generation.trainer import \
DWIMLTrainerForTrackingOneInput
@@ -22,45 +22,9 @@ class Learn2TrackTrainer(DWIMLTrainerForTrackingOneInput):
"""
model: Learn2TrackModel

def __init__(self, clip_grad: float = None, **kwargs):
"""
Init trainer.
Additional values compared to super:
clip_grad : float
The value to which to clip gradients after the backward pass.
There is no good value here. Default: 1000.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)

self.clip_grad = clip_grad

@property
def params_for_checkpoint(self):
params = super().params_for_checkpoint
params.update({
'clip_grad': self.clip_grad
})
return params

def _prepare_checkpoint_info(self) -> dict:
checkpoint_state = super()._prepare_checkpoint_info()
checkpoint_state['params_for_init'].update({
'clip_grad': self.clip_grad
})

return checkpoint_state

def fix_parameters(self):
"""
In our case, clipping gradients to avoid exploding gradients in RNN
"""
if self.clip_grad is not None:
total_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.clip_grad)
if torch.isnan(total_norm):
raise ValueError("Exploding gradients. Experiment failed.")

def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj):
assert self.model.step_size is not None, \
"We can't propagate compressed streamlines."
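For reference, a minimal sketch of the clipping step removed here; whether it was hoisted into the parent trainer is an assumption, as that file is not shown in this diff:

    import torch

    model = torch.nn.Linear(3, 3)
    loss = model(torch.rand(4, 3)).sum()
    loss.backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                max_norm=1000.0)
    if torch.isnan(total_norm):
        raise ValueError("Exploding gradients. Experiment failed.")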
2 changes: 1 addition & 1 deletion dwi_ml/training/projects/transformer_trainer.py
@@ -6,7 +6,7 @@
import numpy as np
import torch

from dwi_ml.tracking.projects.utils import prepare_tracking_mask
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines

from dwi_ml.training.with_generation.trainer import \