diff --git a/src/abaco/ABaCo.py b/src/abaco/ABaCo.py
index 60c83d8..675c3f5 100644
--- a/src/abaco/ABaCo.py
+++ b/src/abaco/ABaCo.py
@@ -4418,11 +4418,22 @@ def fit(
         adv_lr=1e-3,
     ):
         # Define optimizer
+        if isinstance(self.vae.prior, MoCPPrior):
+            prior_params = self.vae.prior.parameters()
+
+        elif isinstance(self.vae.prior, VMMPrior):
+            prior_params = [self.vae.prior.u, self.vae.prior.var]
+
+        else:
+            raise NotImplementedError(
+                "metaABaCo prior distribution can only be 'MoG' or 'VMM'"
+            )
+
         vae_optimizer_1 = torch.optim.Adam(
             [
                 {"params": self.vae.encoder.parameters()},
                 {"params": self.vae.decoder.parameters()},
-                {"params": self.vae.prior.parameters()},
+                {"params": prior_params},
             ],
             lr=phase_1_vae_lr,
         )
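
For context on why the `VMMPrior` branch passes `u` and `var` explicitly instead of calling `.parameters()`: a likely reason (assumed here, since `VMMPrior`'s definition is not part of this hunk) is that those tensors are stored as plain attributes with `requires_grad=True` rather than as registered `nn.Parameter`s, so `Module.parameters()` would not yield them and the optimizer would silently skip them. A minimal sketch with a hypothetical `ToyVMMPrior` stand-in:

```python
import torch
import torch.nn as nn


class ToyVMMPrior(nn.Module):
    """Hypothetical stand-in for VMMPrior; its real internals are not shown in the diff."""

    def __init__(self, n_components=3, latent_dim=2):
        super().__init__()
        # Plain tensors assigned as attributes are not registered by nn.Module,
        # so .parameters() will not return them even though they require gradients.
        self.u = torch.randn(n_components, latent_dim, requires_grad=True)
        self.var = torch.ones(n_components, latent_dim, requires_grad=True)


prior = ToyVMMPrior()
print(list(prior.parameters()))  # [] -> an optimizer built from .parameters() would skip u and var

# Passing the tensors explicitly ensures they actually receive updates.
optimizer = torch.optim.Adam([prior.u, prior.var], lr=1e-3)
```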