Finish Fisher von Mises loss #230

Merged: 6 commits, Apr 8, 2024
127 changes: 92 additions & 35 deletions dwi_ml/models/direction_getter_models.py
@@ -41,9 +41,7 @@ def init_2layer_fully_connected(input_size: int, output_size: int):


def binary_cross_entropy_eos(learned_eos, target_eos, average_results=True):
-    reduction = 'none'
-    if average_results:
-        reduction = 'mean'
+    reduction = 'mean' if average_results else 'none'

learned_eos = torch.sigmoid(learned_eos)
losses_eos = torch.nn.functional.binary_cross_entropy(
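The hunk above is truncated, but the pattern is simple: sigmoid the raw EOS logits, then binary cross-entropy against 0/1 labels. A minimal sketch (not part of the diff; hypothetical tensors) of how the two reduction modes relate:

```python
import torch

learned_eos = torch.tensor([-2.0, -1.0, 3.0])  # raw logits, one per point
target_eos = torch.tensor([0.0, 0.0, 1.0])     # 1 marks a streamline's last point

probs = torch.sigmoid(learned_eos)
per_point = torch.nn.functional.binary_cross_entropy(
    probs, target_eos, reduction='none')       # one loss value per point
averaged = torch.nn.functional.binary_cross_entropy(
    probs, target_eos, reduction='mean')       # single scalar
assert torch.isclose(per_point.mean(), averaged)
```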
@@ -80,7 +78,7 @@ class AbstractDirectionGetterModel(torch.nn.Module):
-----------------------
"""
def __init__(self, input_size: int, key: str,
                 supports_compressed_streamlines: bool, dropout: float = None,
compress_loss: bool = False, compress_eps: float = 1e-3,
weight_loss_with_angle: bool = False,
loss_description: str = '', add_eos: bool = False,
@@ -210,8 +208,8 @@ def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
def compute_loss(self, outputs: List[Tensor],
target_streamlines: List[Tensor], average_results=True):
if self.compress_loss and not average_results:
raise ValueError("Current implementation of compress_loss does not "
"allow returning non-averaged loss.")
raise ValueError("Current implementation of compress_loss does "
"not allow returning non-averaged loss.")

# Compute directions
target_dirs = compute_directions(target_streamlines)
@@ -272,8 +270,9 @@ def stack_batch(outputs, target_dirs):
outputs = torch.vstack(outputs)
return outputs, target_dirs

-    def _compute_loss(self, outputs: Tensor, target_dirs: Tensor,
-                      average_results=True) -> Union[Tuple[Tensor, int], Tensor]:
+    def _compute_loss(
+            self, outputs: Tensor, target_dirs: Tensor,
+            average_results=True) -> Union[Tuple[Tensor, int], Tensor]:
"""
Expecting a single tensor.

@@ -762,7 +761,7 @@ def _compute_loss(self, logits_per_class: Tensor, targets_probs: Tensor,
# buggy: reduction is supposed to be a str but if I send 'none', it
# says that it expects an int.)
# Gives the same result as above, but averaged instead of summed.
-        # The real definition is integral (i.e. sum). Typically for our
+        # The real definition is integral (i.e. sum). Typically, for our
# data (724 classes), that's a big difference: from values ~7 to values
# around 0.04. Nicer for visu with sum.
# So, avoiding torch's 'mean' reduction; reducing ourselves.
@@ -775,7 +774,8 @@

# Integral over classes per point.
kl_loss = KLDivLoss(reduction='none', log_target=False)
-        nll_losses = torch.sum(kl_loss(logits_per_class, targets_probs), dim=-1)
+        nll_losses = torch.sum(kl_loss(logits_per_class, targets_probs),
+                               dim=-1)

if average_results:
return _mean_and_weight(nll_losses)
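To make the reduction comment above concrete, here is a small sketch (hypothetical shapes) contrasting the manual per-point sum over classes with torch's elementwise 'mean' reduction:

```python
import torch
from torch.nn import KLDivLoss

n_points, n_classes = 5, 724
logits = torch.log_softmax(torch.randn(n_points, n_classes), dim=-1)
targets = torch.softmax(torch.randn(n_points, n_classes), dim=-1)

kl = KLDivLoss(reduction='none', log_target=False)
per_point = torch.sum(kl(logits, targets), dim=-1)  # integral over classes
loss = per_point.mean()                             # then average over points

# torch's own 'mean' divides by n_points * n_classes instead:
torch_mean = KLDivLoss(reduction='mean', log_target=False)(logits, targets)
assert torch.isclose(loss, torch_mean * n_classes)
```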
@@ -893,8 +893,10 @@ def _compute_loss(self, learned_gaussian_params: Tuple[Tensor, Tensor],
"""
# 1. Main loss
means, sigmas = learned_gaussian_params
-        learned_eos = means[:, -1]
-        means = means[:, 0:3]
+        learned_eos = None
+        if self.add_eos:
+            learned_eos = means[:, -1]
+            means = means[:, 0:3]

        # Create an official torch probability distribution from the means
        # and variances
@@ -905,7 +907,8 @@
if self.entropy_weight > 0:
# Trying to ensure that sigma values are not too small.
# Entropy values range between 0 and log(K). 0 = high probability.
-            # We want a high entropy / low certainty = we will minimize -entropy.
+            # We want a high entropy / low certainty = we will minimize
+            # -entropy.
entropy = distribution.entropy()
logging.info("Computing batch loss with sigma {}, entropy: {}"
.format(torch.mean(sigmas), torch.mean(entropy)))
@@ -918,11 +921,11 @@
# 2. EOS loss:
if self.add_eos:
# Binary cross-entropy
-            loss_eos = binary_cross_entropy_eos(learned_eos, target_dirs[:, -1],
+            loss_eos = binary_cross_entropy_eos(learned_eos,
+                                                target_dirs[:, -1],
average_results)
return nll_loss + self.eos_weight * loss_eos, n
else:
-            n = 1
return nll_loss, n
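A minimal sketch of the Gaussian term assembled above, assuming independent per-axis normals (the model's actual distribution object may differ) and a hypothetical entropy_weight of 0.1:

```python
import torch
from torch.distributions import Normal

means = torch.zeros(10, 3)              # predicted mean, one per point
sigmas = torch.ones(10, 3)              # predicted std, one per point
target_dirs = torch.randn(10, 3)        # normalized in the real pipeline

distribution = Normal(means, sigmas)
# Log-prob summed over the 3 axes: one NLL value per point.
nll_losses = -torch.sum(distribution.log_prob(target_dirs), dim=-1)

# Optional regularizer: minimizing -entropy keeps sigmas from collapsing.
entropy = torch.sum(distribution.entropy(), dim=-1)
loss = nll_losses.mean() - 0.1 * entropy.mean()
```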

def _sample_tracking_direction_prob(
@@ -956,6 +959,7 @@ def _get_tracking_direction_det(self, learned_gaussian_params: Tensor,
        Get the most probable direction: the mean of the distribution.
"""
# Returns the direction of the max of the Gaussian = the mean.
# Not using sigma
means, sigmas = learned_gaussian_params
dirs = means[:, 0:3]

@@ -1156,43 +1160,68 @@ def __init__(self, **kwargs):
loss_description='negative log-likelihood',
**kwargs)

-        if self.add_eos:
-            raise NotImplementedError
-        self.layers_mean = init_2layer_fully_connected(self.input_size, 3)
+        # Layers
+        # 3 values as mean, 1 value as kappa
+        # If EOS: Adding it to the mean layer. Could be separated.
+        oneifeos = 1 if self.add_eos else 0
+        self.layers_mean = init_2layer_fully_connected(self.input_size,
+                                                       3 + oneifeos)
self.layers_kappa = init_2layer_fully_connected(self.input_size, 1)

self.output_size = 4
# Loss will be defined in _compute_loss, using torch distribution

def _prepare_dirs_for_loss(self, target_dirs: List[Tensor]):
"""
Should be called before _compute_loss, before concatenating your
streamlines.

Returns: list[Tensors], the directions.
"""
# Need to normalize before adding EOS labels (dir = 0,0,0)
target_dirs = normalize_directions(target_dirs)
return add_label_as_last_dim(target_dirs, add_sos=False,
add_eos=self.add_eos)
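As a sketch of the intended target layout (assumed from the helper names; `add_label_as_last_dim` is shown here only through its effect): directions are normalized first, then a fourth column carries the EOS flag:

```python
import torch

dirs = torch.tensor([[2.0, 0.0, 0.0],
                     [0.0, 3.0, 0.0]])
dirs = torch.nn.functional.normalize(dirs, dim=-1)  # unit directions

# EOS label: 0 everywhere except the streamline's last direction.
eos = torch.tensor([[0.0], [1.0]])
target = torch.hstack((dirs, eos))
# target[:, 0:3] are unit directions; target[:, -1] is the EOS label,
# matching the target_dirs[:, 0:3] / target_dirs[:, -1] splits below.
```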

def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
"""Run the inputs through the fully-connected layer.

Returns
-------
-        means : torch.Tensor with shape [batch_size x 3]
-            ?
+        mus : torch.Tensor with shape [batch_size x 3]
+            The 3D coordinate of the mean.
         kappas : torch.Tensor with shape [batch_size x 1]
-            ?
+            The kappa concentration parameter.
"""
-        means = self.loop_on_layers(inputs, self.layers_mean)
+        mu = self.loop_on_layers(inputs, self.layers_mean)
+        kappas = self.loop_on_layers(inputs, self.layers_kappa)

# mean should be a unit vector for Fisher Von-Mises distribution
-        means = torch.nn.functional.normalize(means, dim=-1)
+        # (Using [0:3] only; EOS value does not need to be normalized).
+        # Simple code line raises an error: inplace operation
+        # mu[0:3] = torch.nn.functional.normalize(mu[0:3], dim=-1)
+        learned_eos = None
+        if self.add_eos:
+            learned_eos = mu[:, 3][:, None]
+            mu = mu[:, 0:3]
+        mu = torch.nn.functional.normalize(mu, dim=-1)
+        if self.add_eos:
+            mu = torch.hstack((mu, learned_eos))

# Need to restrict kappa to a certain range, e.g. [0, 20]
-        unbound_kappa = self.loop_on_layers(inputs, self.layers_kappa)
-        kappas = torch.sigmoid(unbound_kappa) * 20
+        kappas = torch.sigmoid(kappas) * 20

# Squeeze the trailing dim, the kappa parameter is a scalar
kappas = kappas.squeeze(dim=-1)

-        return means, kappas
+        return mu, kappas
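The slicing-then-hstack dance above exists because the commented one-liner would be an in-place write on a tensor autograd still needs. A standalone sketch (hypothetical tensors) of the same out-of-place pattern:

```python
import torch

mu = torch.randn(4, 4, requires_grad=True)    # 3 mean dims + 1 EOS logit
out = torch.nn.functional.normalize(mu[:, 0:3], dim=-1)  # out-of-place
out = torch.hstack((out, mu[:, 3][:, None]))  # EOS column left untouched

# Kappa squashed into (0, 20) so the concentration stays bounded:
kappas = torch.sigmoid(torch.randn(4, 1)) * 20
kappas = kappas.squeeze(dim=-1)               # one scalar per point
```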

@staticmethod
def stack_batch(outputs, target_dirs):
target_dirs = torch.vstack(target_dirs)
-        mus = torch.vstack(outputs[0])
-        kappas = torch.vstack(outputs[1])
-        return (mus, kappas), target_dirs
+        mu = torch.vstack(outputs[0])
+        kappa = torch.hstack(outputs[1])  # Not vstack: they are vectors
+        return (mu, kappa), target_dirs

def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
target_dirs, average_results=True):
@@ -1202,16 +1231,31 @@ def _compute_loss(self, learned_fisher_params: Tuple[Tensor, Tensor],
See the doc for explanation on the formulas:
https://dwi-ml.readthedocs.io/en/latest/formulas.html
"""
-        # mu.shape : [flattened_sequences, 3]
+        # mu.shape: [all_points, 4]. First 3 values are x, y, z; last is EOS.
mu, kappa = learned_fisher_params
learned_eos = None
if self.add_eos:
learned_eos = mu[:, 3]
mu = mu[:, 0:3]

-        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs)
-        nll_losses = -log_prob
+        # 1. Main loss
+        # Note. Mu was already normalized through the forward method.
+        log_prob = fisher_von_mises_log_prob(mu, kappa, target_dirs[:, 0:3])
+        nll_loss = -log_prob

        n = 1
        if average_results:
-            return _mean_and_weight(nll_losses)
+            nll_loss, n = _mean_and_weight(nll_loss)

# 2. EOS loss:
if self.add_eos:
# Binary cross-entropy
loss_eos = binary_cross_entropy_eos(learned_eos,
target_dirs[:, -1],
average_results)
return nll_loss + self.eos_weight * loss_eos, n
else:
-            return nll_losses
+            return nll_loss, n
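The averaging contract assumed throughout: `_mean_and_weight` (not shown in this diff) apparently returns the mean loss together with the number of points it averaged over. A hypothetical re-implementation, for illustration only:

```python
import torch
from torch import Tensor
from typing import Tuple

def _mean_and_weight(losses: Tensor) -> Tuple[Tensor, int]:
    # Hypothetical re-implementation, matching how it is called above.
    return losses.mean(), len(losses)

losses = torch.tensor([0.5, 1.0, 1.5])
mean_loss, n = _mean_and_weight(losses)  # -> (tensor(1.), 3)
```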

def _sample_tracking_direction_prob(
self, learned_fisher_params: Tuple[Tensor, Tensor],
@@ -1247,7 +1291,20 @@ def _sample_tracking_direction_prob(

def _get_tracking_direction_det(self, learned_fisher_params: Tensor,
eos_stopping_thresh):
-        raise NotImplementedError
        """
        Get the most probable direction: the mean of the distribution.
        """
        # Returns the direction at the max of the Fisher von Mises = the mean.
        # Not using kappa.
mus, kappas = learned_fisher_params
dirs = mus[:, 0:3]

if self.add_eos:
eos_prob = torch.sigmoid(mus[:, -1])
eos_prob = torch.gt(eos_prob, eos_stopping_thresh)
return torch.masked_fill(dirs, eos_prob[:, None], torch.nan)
else:
return dirs
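A sketch of the EOS masking used above (hypothetical tensors): points whose EOS probability exceeds the threshold get a NaN direction, which downstream tracking can read as a stop signal:

```python
import torch

dirs = torch.tensor([[1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0]])
eos_logits = torch.tensor([-3.0, 4.0])

stop = torch.gt(torch.sigmoid(eos_logits), 0.5)  # -> [False, True]
dirs = torch.masked_fill(dirs, stop[:, None], torch.nan)
# First direction is kept; second becomes [nan, nan, nan] = stop here.
```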

@staticmethod
def _sample_weight(kappa):
29 changes: 27 additions & 2 deletions dwi_ml/models/utils/fisher_von_mises.py
@@ -10,6 +10,19 @@


def fisher_von_mises_log_prob_vector(mus, kappa, targets, eps=1e-5):
"""
Same as below, but for a single vector.

Parameters
----------
mus: torch.Tensor
Shape: (3, )
kappa: torch.Tensor
Shape: (1, )
targets: torch.Tensor
Directions. Shape (3, )
eps: float
"""
log_diff_exp_kappa = np.log(
np.maximum(eps, np.exp(kappa) - np.exp(-kappa)))
log_c = np.log(kappa) - np.log(2 * np.pi) - log_diff_exp_kappa
Expand All @@ -18,12 +31,24 @@ def fisher_von_mises_log_prob_vector(mus, kappa, targets, eps=1e-5):


def fisher_von_mises_log_prob(mus, kappa, targets, eps=1e-5):
"""
Fisher von Mises loss for a batch.

Parameters
----------
mus: torch.Tensor
Shape: (n, 3)
kappa: torch.Tensor
Shape: (n, 1)
targets: torch.Tensor
Directions. Shape (n, 3)
eps: float
"""
log_2pi = np.log(2 * np.pi).astype(np.float32)

eps = torch.as_tensor(eps, device=kappa.device, dtype=torch.float32)

-    # Add an epsilon in case kappa is too small (i.e. a uniform
-    # distribution)
+    # Add an epsilon in case kappa is too small (i.e. a uniform distribution)
log_diff_exp_kappa = torch.log(
torch.maximum(eps, torch.exp(kappa) - torch.exp(-kappa)))
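For reference, these two functions implement the log-density of the Fisher von Mises distribution on the unit sphere (derived on the formulas page cited in direction_getter_models.py); the eps only guards the log as kappa approaches 0, i.e. a near-uniform distribution:

```latex
\log f(x \mid \mu, \kappa)
    = \underbrace{\log\kappa - \log(2\pi)
        - \log\left(e^{\kappa} - e^{-\kappa}\right)}_{\log C(\kappa)}
    + \kappa \, \mu^{\top} x
```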

15 changes: 6 additions & 9 deletions dwi_ml/tracking/io_utils.py
@@ -64,13 +64,10 @@ def add_tracking_options(p):
help="Tracking mask's volume group in the hdf5.")
stop_g.add_argument('--theta', metavar='t', type=float,
default=90,
help="The tracking direction at each step being "
"defined by the model, \ntheta arg can't define "
"allowed directions in the tracking field.\n"
"Rather, this new equivalent angle, is used as "
"\na stopping criterion during propagation: "
help="Stopping criterion during propagation: "
"tracking \nis stopped when a direction is more "
"than an angle t from preceding direction")
"than an angle t from \npreceding direction."
"[%(default)s]")
stop_g.add_argument('--eos_stop', metavar='prob',
help="Stopping criterion if a EOS value was learned "
"during training. \nCan either be a probability "
@@ -80,9 +77,9 @@
"probability, no mather its value.")
stop_g.add_argument(
'--discard_last_point', action='store_true',
help="If set, discard the last point (once out of the tracking mask) \n"
"of the streamline. Default: append them. This is the default in \n"
"Dipy too. Note that points obtained after an invalid direction \n"
help="If set, discard the last point (once out of the tracking mask)\n"
"of the streamline. Default: append them. This is the default in\n"
"Dipy too. Note that points obtained after an invalid direction\n"
"(based on the propagator's definition of invalid; ex when \n"
"angle is too sharp of sh_threshold not reached) are never added.")
