From 6106e84184fe7ae9325244d2f5cfbb52b290b9e5 Mon Sep 17 00:00:00 2001
From: EmmaRenauld
Date: Tue, 12 Mar 2024 14:12:11 -0400
Subject: [PATCH] WIP

---
 dwi_ml/models/direction_getter_models.py | 11 ++++++-----
 dwi_ml/models/utils/fisher_von_mises.py  |  3 +--
 dwi_ml/training/trainers.py              |  1 +
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/dwi_ml/models/direction_getter_models.py b/dwi_ml/models/direction_getter_models.py
index de41cdb5..5c6bb500 100644
--- a/dwi_ml/models/direction_getter_models.py
+++ b/dwi_ml/models/direction_getter_models.py
@@ -1176,9 +1176,7 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
         Returns: list[Tensors], the directions.
         """
         # Need to normalize before adding EOS labels (dir = 0,0,0)
-        if self.normalize_targets is not None:
-            target_dirs = normalize_directions(target_dirs,
-                                                new_norm=self.normalize_targets)
+        target_dirs = normalize_directions(target_dirs)
 
         return add_label_as_last_dim(target_dirs, add_sos=False,
                                      add_eos=self.add_eos)
@@ -1209,7 +1207,7 @@ def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
     def stack_batch(outputs, target_dirs):
         target_dirs = torch.vstack(target_dirs)
         mu = torch.vstack(outputs[0])
-        kappa = torch.vstack(outputs[1])
+        kappa = torch.hstack(outputs[1])  # Not vstack: they are vectors
         return (mu, kappa), target_dirs
 
     def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
@@ -1227,13 +1225,16 @@ def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
         mu = mu[:, 0:3]
 
         # 1. Main loss
-        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs)
+        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs[:, 0:3])
         nll_loss = -log_prob
 
         n = 1
         if average_results:
             nll_loss, n = _mean_and_weight(nll_loss)
 
+        # logging.warning("Batch Prob: {}. Log: {}. NLL: {}"
+        #                 .format(torch.mean(torch.exp(log_prob)), torch.mean(log_prob), nll_loss))
+
         # 2. EOS loss:
         if self.add_eos:
             # Binary cross-entropy
diff --git a/dwi_ml/models/utils/fisher_von_mises.py b/dwi_ml/models/utils/fisher_von_mises.py
index 16a14ad6..75daccea 100644
--- a/dwi_ml/models/utils/fisher_von_mises.py
+++ b/dwi_ml/models/utils/fisher_von_mises.py
@@ -22,8 +22,7 @@ def fisher_von_mises_log_prob(mus, kappa, targets, eps=1e-5):
 
     eps = torch.as_tensor(eps, device=kappa.device, dtype=torch.float32)
 
-    # Add an epsilon in case kappa is too small (i.e. a uniform
-    # distribution)
+    # Add an epsilon in case kappa is too small (i.e. a uniform distribution)
     log_diff_exp_kappa = torch.log(
         torch.maximum(eps, torch.exp(kappa) - torch.exp(-kappa)))
 
diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py
index ed28fb47..c2ef6300 100644
--- a/dwi_ml/training/trainers.py
+++ b/dwi_ml/training/trainers.py
@@ -761,6 +761,7 @@ def back_propagation(self, loss):
         # Any other steps. Ex: clip gradients. Not implemented here.
         # See Learn2track's Trainer for an example.
         unclipped_grad_norm = self.fix_parameters()
+        # logging.warning(" Unclipped grad norm: {}".format(unclipped_grad_norm))
 
         # Supervizing the gradient's norm.
         grad_norm = compute_gradient_norm(self.model.parameters())
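
For reference, the loss touched by the direction_getter_models.py and fisher_von_mises.py hunks is the 3D Fisher-von Mises (von Mises-Fisher) negative log-likelihood. Below is a minimal standalone sketch of that log-density, assuming mu and the targets are (N, 3) unit vectors and kappa is a 1D vector of N positive concentrations (which is consistent with the "Not vstack: they are vectors" comment above); the function and variable names are illustrative only, not the repository's code.

import math
import torch


def fvm_log_prob_sketch(mus, kappas, targets, eps=1e-5):
    """Illustrative 3D Fisher-von Mises log-density:
    log p(x) = log(kappa) - log(2*pi) - log(exp(kappa) - exp(-kappa))
               + kappa * <mu, x>
    mus, targets: (N, 3) unit vectors; kappas: (N,) positive concentrations.
    """
    eps = torch.as_tensor(eps, device=kappas.device, dtype=torch.float32)
    # Same guard as in the patch: for kappa ~ 0 (a near-uniform distribution),
    # exp(kappa) - exp(-kappa) is ~ 0 and its log would diverge.
    log_diff_exp_kappa = torch.log(
        torch.maximum(eps, torch.exp(kappas) - torch.exp(-kappas)))
    log_norm = torch.log(kappas) - math.log(2 * math.pi) - log_diff_exp_kappa
    dot = torch.sum(mus * targets, dim=-1)   # cosine between mu and target
    return log_norm + kappas * dot           # shape (N,)


# Toy usage mirroring _compute_loss: targets may carry an extra EOS column,
# so only the first three components enter the directional loss.
mus = torch.nn.functional.normalize(torch.randn(5, 3), dim=-1)
kappas = torch.hstack([torch.rand(2) * 10 + 1, torch.rand(3) * 10 + 1])  # 1D
targets_eos = torch.hstack(
    [torch.nn.functional.normalize(torch.randn(5, 3), dim=-1),
     torch.zeros(5, 1)])
nll = -fvm_log_prob_sketch(mus, kappas, targets_eos[:, 0:3])
print(nll.mean())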
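
The trainers.py hunk only adds a commented-out log of the unclipped gradient norm; the surrounding code then recomputes a norm via compute_gradient_norm. As an illustration of what such a monitor typically measures (an assumption, not dwi_ml's implementation), a global L2 norm over all parameter gradients can be computed as:

import torch


def global_grad_norm(parameters):
    """Global L2 norm over all parameter gradients (illustrative helper)."""
    total = 0.0
    for p in parameters:
        if p.grad is not None:
            total += p.grad.detach().norm(2).item() ** 2
    return total ** 0.5


# Example: inspect the norm right after backward(), before any clipping.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
print(global_grad_norm(model.parameters()))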