[Feature Request] Inconsistent batch broadcasting when computing log_marginal of non-Gaussian likelihoods #2556

Open
shixinxing opened this issue Aug 4, 2024 · 0 comments

@shixinxing

🚀 Feature Request

When I call log_marginal() and expected_log_prob(), Gaussian and non-Gaussian likelihoods handle batch shapes inconsistently. GaussianLikelihood appears to accept observations with extra leading batch dimensions, whereas non-Gaussian likelihoods (e.g. LaplaceLikelihood) do not.
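Both behaviours come down to broadcasting. PyTorch broadcasts tensor shapes by NumPy's rules: shapes are aligned from the right, and each pair of sizes must be equal or one of them must be 1. Below is a minimal, dependency-free sketch of that rule. The helper `broadcast_shape` is hypothetical and written only for illustration; PyTorch's own equivalent is `torch.broadcast_shapes`.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b under NumPy/PyTorch rules.

    Shapes are aligned from the right; missing leading dims count as size 1.
    Each aligned pair must be equal, or one of the two sizes must be 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        # If x is 1, the other size wins; otherwise x == y or y == 1, so x wins.
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


# y of shape [4, 3, 8] against a q(f) mean of shape [3, 8]: the extra leading dim is fine.
print(broadcast_shape((4, 3, 8), (3, 8)))  # (4, 3, 8)

# Two leading dims of different sizes (20 vs. 4, chosen arbitrarily) cannot broadcast.
try:
    broadcast_shape((20, 3, 8), (4, 3, 8))
except ValueError as err:
    print(err)
```

Under this rule, an extra leading dimension on `y` is harmless only if the other operand has no dimension at that position, or has size 1 there.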

Problem

In the following code snippet, the observation $y$ has shape [4,3,8], and the latent function has a Gaussian distribution $q(f)$ with batch shape [3] and event shape [8], so its mean has shape [3,8]. For GaussianLikelihood, both the log-marginal and the expected log-likelihood broadcast over the extra leading batch dimension of $y$ and return results of shape [4,3,8]. For LaplaceLikelihood, both calls raise a "not broadcastable" error.

```python
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood, LaplaceLikelihood

y = torch.rand(4, 3, 8)  # observations with an extra leading batch dim of size 4
mean, L = torch.randn(3, 8), torch.rand(3, 8, 8)
q_f = MultivariateNormal(mean=mean, covariance_matrix=L @ L.mT)  # batch shape [3], event shape [8]

# Gaussian likelihood
lk = GaussianLikelihood(batch_shape=[3])  # batched noise

# expected_log_prob
print(lk.expected_log_prob(target=y, input=q_f).shape)  # [4,3,8]
# log_marginal
print(lk.log_marginal(y, q_f).shape)  # [4,3,8]

# Non-Gaussian likelihood (Laplace)
laplace_lk = LaplaceLikelihood(batch_shape=[3])

# expected_log_prob
print(laplace_lk.expected_log_prob(observations=y, function_dist=q_f).shape)  # Error: not broadcastable
# log_marginal
print(laplace_lk.log_marginal(y, q_f).shape)  # Error: not broadcastable
```
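A likely cause (not verified here against the GPyTorch source): for non-Gaussian likelihoods such as LaplaceLikelihood, the expectation over $q(f)$ is approximated numerically, so `log_prob` is evaluated on function values that carry an extra leading dimension (quadrature nodes or samples), and that dimension then collides with the leading `4` of `y`. The behaviour requested here is that this numerical dimension be reduced independently of any extra batch dimensions of `y`. The pure-Python sketch below illustrates those intended semantics for the Laplace case with a 3-point Gauss-Hermite rule. The names `GH_NODES`, `GH_WEIGHTS`, `laplace_log_prob` and `expected_log_prob_gh`, and the example values, are hypothetical illustrations, not GPyTorch API.

```python
import math

# 3-point Gauss-Hermite rule for integrals of the form ∫ exp(-x^2) g(x) dx
GH_NODES = (-math.sqrt(1.5), 0.0, math.sqrt(1.5))
GH_WEIGHTS = (math.sqrt(math.pi) / 6, 2 * math.sqrt(math.pi) / 3, math.sqrt(math.pi) / 6)


def laplace_log_prob(y, loc, scale):
    """Log density of Laplace(loc, scale) evaluated at y."""
    return -math.log(2 * scale) - abs(y - loc) / scale


def expected_log_prob_gh(y, mean, var, scale):
    """Approximate E_{f ~ N(mean, var)}[log Laplace(y | f, scale)].

    Uses f = mean + sqrt(2 var) x, so the Gaussian integral becomes
    (1 / sqrt(pi)) * sum_i w_i g(f_i). The quadrature axis is summed out
    here, per element, so it never meets y's batch dimensions.
    """
    total = 0.0
    for x, w in zip(GH_NODES, GH_WEIGHTS):
        f = mean + math.sqrt(2 * var) * x
        total += w * laplace_log_prob(y, f, scale)
    return total / math.sqrt(math.pi)


# Hypothetical q(f) marginals for one latent output at 2 inputs.
means, variances = [0.0, 1.0], [0.5, 0.2]
# y with an extra leading batch dimension of size 3: shape [3, 2].
ys = [[0.1, 0.9], [-0.3, 1.2], [0.4, 0.8]]

# Desired result: q(f) broadcast over y's extra batch dim, giving shape [3, 2].
result = [
    [expected_log_prob_gh(y, m, v, scale=1.0) for y, m, v in zip(row, means, variances)]
    for row in ys
]
print(len(result), len(result[0]))  # 3 2
```

Because the quadrature sum is taken before any comparison with `y`'s leading dimensions, the output simply takes `y`'s shape, which matches what GaussianLikelihood returns in the snippet above.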