Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions edward2/jax/nn/heteroscedastic_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class MCSoftmaxDenseFA(nn.Module):
share_samples_across_batch: bool = False
logits_only: bool = False
return_locs: bool = False
return_unaveraged_logits: bool = False
eps: float = 1e-7
tune_temperature: bool = False
temperature_lower_bound: Optional[float] = None
Expand Down Expand Up @@ -251,9 +252,10 @@ def _compute_mc_samples(self, inputs, scale, num_samples):
# [B, S, dim] -> [B, S, K]
latents = self._compute_loc_param(latents) # pylint: disable=assignment-from-none

samples = jax.nn.softmax(latents / self.get_temperature())
scaled_latents = latents / self.get_temperature()
samples = jax.nn.softmax(scaled_latents)

return jnp.mean(samples, axis=1)
return jnp.mean(samples, axis=1), jax.nn.log_softmax(scaled_latents)

@nn.compact
def __call__(self, inputs, training=True):
Expand All @@ -278,7 +280,8 @@ def __call__(self, inputs, training=True):
else:
total_mc_samples = self.test_mc_samples

probs_mean = self._compute_mc_samples(inputs, scale, total_mc_samples)
probs_mean, unaveraged_logits = self._compute_mc_samples(
inputs, scale, total_mc_samples)

probs_mean = jnp.clip(probs_mean, a_min=self.eps)
log_probs = jnp.log(probs_mean)
Expand All @@ -288,8 +291,12 @@ def __call__(self, inputs, training=True):
logits = self._compute_loc_param(inputs) # pylint: disable=assignment-from-none

if self.logits_only:
if self.return_unaveraged_logits:
return logits, unaveraged_logits
return logits

if self.return_unaveraged_logits:
return logits, log_probs, probs_mean, unaveraged_logits
return logits, log_probs, probs_mean


Expand Down