Commit

WIP
EmmaRenauld committed Mar 12, 2024
1 parent 95b115e commit 6106e84
Showing 3 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions dwi_ml/models/direction_getter_models.py
@@ -1176,9 +1176,7 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
         Returns: list[Tensors], the directions.
         """
         # Need to normalize before adding EOS labels (dir = 0,0,0)
-        if self.normalize_targets is not None:
-            target_dirs = normalize_directions(target_dirs,
-                                                new_norm=self.normalize_targets)
+        target_dirs = normalize_directions(target_dirs)
         return add_label_as_last_dim(target_dirs, add_sos=False,
                                      add_eos=self.add_eos)
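
For context, this normalization step rescales each target step direction (presumably to unit length, now that the new_norm argument is gone) before the EOS label, encoded as a (0, 0, 0) direction, is appended. A minimal sketch of that idea, using a hypothetical unit_normalize helper rather than the repository's normalize_directions:

    from typing import List

    import torch
    from torch import Tensor

    def unit_normalize(target_dirs: List[Tensor], eps: float = 1e-6) -> List[Tensor]:
        # Hypothetical illustration: rescale every 3D step direction to unit norm.
        # The eps guard avoids dividing by zero on degenerate (0, 0, 0) steps.
        out = []
        for dirs in target_dirs:                      # dirs: (n_points, 3)
            norms = torch.linalg.norm(dirs, dim=-1, keepdim=True)
            out.append(dirs / torch.clamp(norms, min=eps))
        return out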

@@ -1209,7 +1207,7 @@ def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
     def stack_batch(outputs, target_dirs):
         target_dirs = torch.vstack(target_dirs)
         mu = torch.vstack(outputs[0])
-        kappa = torch.vstack(outputs[1])
+        kappa = torch.hstack(outputs[1])  # Not vstack: they are vectors
         return (mu, kappa), target_dirs
 
     def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
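
The vstack/hstack distinction matters because each sequence in the batch contributes a (n_points, 3) tensor of mu values but only a 1-D (n_points,) vector of kappa values: vstack would try to stack those vectors as equal-length rows, while hstack concatenates them end to end, one kappa per point. A small illustration with hypothetical lengths:

    import torch

    # Two "streamlines" with 3 and 2 points respectively.
    kappa_a, kappa_b = torch.rand(3), torch.rand(2)

    torch.hstack([kappa_a, kappa_b]).shape   # torch.Size([5]): one kappa per point
    # torch.vstack([kappa_a, kappa_b])       # fails: unequal 1-D tensors cannot
    #                                        # be stacked as rows of a 2-D tensor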
@@ -1227,13 +1225,16 @@ def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
         mu = mu[:, 0:3]
 
         # 1. Main loss
-        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs)
+        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs[:, 0:3])
         nll_loss = -log_prob
 
         n = 1
         if average_results:
             nll_loss, n = _mean_and_weight(nll_loss)
 
+        # logging.warning("Batch Prob: {}. Log: {}. NLL: {}"
+        #                 .format(torch.mean(torch.exp(log_prob)), torch.mean(log_prob), nll_loss))
+
         # 2. EOS loss:
         if self.add_eos:
             # Binary cross-entropy
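
The new slicing reflects that, after add_label_as_last_dim, each target row presumably holds the three direction components plus an extra EOS-label column, so only the first three columns should reach the Fisher-von Mises term. Roughly, with the column layout assumed rather than taken from the repository:

    import torch

    # Hypothetical targets: 3 direction components + 1 EOS label per point.
    target_dirs = torch.tensor([[0.6, 0.8, 0.0, 0.0],
                                [0.0, 0.0, 0.0, 1.0]])   # last point = EOS

    dirs_only = target_dirs[:, 0:3]   # (n_points, 3): what the vMF loss expects
    eos_label = target_dirs[:, -1]    # (n_points,): used by the EOS loss below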
3 changes: 1 addition & 2 deletions dwi_ml/models/utils/fisher_von_mises.py
@@ -22,8 +22,7 @@ def fisher_von_mises_log_prob(mus, kappa, targets, eps=1e-5):

     eps = torch.as_tensor(eps, device=kappa.device, dtype=torch.float32)
 
-    # Add an epsilon in case kappa is too small (i.e. a uniform
-    # distribution)
+    # Add an epsilon in case kappa is too small (i.e. a uniform distribution)
     log_diff_exp_kappa = torch.log(
         torch.maximum(eps, torch.exp(kappa) - torch.exp(-kappa)))

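
For reference, the Fisher-von Mises (von Mises-Fisher) density on the unit sphere is p(x | mu, kappa) = kappa / (2*pi*(e^kappa - e^-kappa)) * exp(kappa * mu . x), so the quantity being guarded here is the log of that normalizing term. A sketch under that formula, not the repository's exact implementation:

    import math

    import torch

    def vmf_log_prob_sketch(mus, kappa, targets, eps=1e-5):
        # log p = log(kappa) - log(2*pi) - log(e^kappa - e^-kappa) + kappa * <mu, x>
        # mus and targets are assumed unit-norm, kappa strictly positive.
        eps = torch.as_tensor(eps, device=kappa.device, dtype=torch.float32)
        log_norm = (torch.log(kappa) - math.log(2 * math.pi)
                    - torch.log(torch.maximum(eps, torch.exp(kappa) - torch.exp(-kappa))))
        dot = torch.sum(mus * targets, dim=-1)
        return log_norm + kappa * dot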
1 change: 1 addition & 0 deletions dwi_ml/training/trainers.py
@@ -761,6 +761,7 @@ def back_propagation(self, loss):
         # Any other steps. Ex: clip gradients. Not implemented here.
         # See Learn2track's Trainer for an example.
         unclipped_grad_norm = self.fix_parameters()
+        # logging.warning(" Unclipped grad norm: {}".format(unclipped_grad_norm))
 
         # Supervizing the gradient's norm.
         grad_norm = compute_gradient_norm(self.model.parameters())
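
The commented logging line exposes the gradient norm returned before any clipping; for context, a total norm over all model parameters is typically computed along these lines (a generic sketch, not necessarily how compute_gradient_norm is written in dwi_ml):

    import torch

    def total_grad_norm(parameters, norm_type: float = 2.0) -> float:
        # Norm of the concatenation of all parameter gradients.
        norms = [p.grad.detach().norm(norm_type)
                 for p in parameters if p.grad is not None]
        if not norms:
            return 0.0
        return torch.norm(torch.stack(norms), norm_type).item()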
