From bd541e44443ff638cdf27c408ade6bad4edd97fd Mon Sep 17 00:00:00 2001 From: edircoli Date: Wed, 13 Aug 2025 16:23:55 +0200 Subject: [PATCH] fixed metaABaCo.fit() with VMM prior --- src/abaco/ABaCo.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/abaco/ABaCo.py b/src/abaco/ABaCo.py index 60c83d8..675c3f5 100644 --- a/src/abaco/ABaCo.py +++ b/src/abaco/ABaCo.py @@ -4418,11 +4418,22 @@ def fit( adv_lr=1e-3, ): # Define optimizer + if isinstance(self.vae.prior, MoCPPrior): + prior_params = self.vae.prior.parameters() + + elif isinstance(self.vae.prior, VMMPrior): + prior_params = [self.vae.prior.u, self.vae.prior.var] + + else: + raise NotImplementedError( + "metaABaCo prior distribution can only be 'MoG' or 'VMM'" + ) + vae_optimizer_1 = torch.optim.Adam( [ {"params": self.vae.encoder.parameters()}, {"params": self.vae.decoder.parameters()}, - {"params": self.vae.prior.parameters()}, + {"params": prior_params}, ], lr=phase_1_vae_lr, )