diff --git a/trax/layers/core.py b/trax/layers/core.py index 07c62937c..9ecb7ec95 100644 --- a/trax/layers/core.py +++ b/trax/layers/core.py @@ -808,8 +808,7 @@ def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name """ a = mu.shape[-1] * jnp.log(2 * jnp.pi) _, b = jnp.linalg.slogdet(sigma) - y = jnp.linalg.solve(sigma, x - mu) - y = jnp.expand_dims(y, axis=-1) + y = jnp.linalg.solve(sigma, (x - mu)[..., None]) xm = jnp.expand_dims(x - mu, axis=-2) c = jnp.matmul(xm, y) c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)