diff --git a/bioimage_embed/models/bolts/vae.py b/bioimage_embed/models/bolts/vae.py index 6d816d0e..90d76ae3 100644 --- a/bioimage_embed/models/bolts/vae.py +++ b/bioimage_embed/models/bolts/vae.py @@ -2,6 +2,9 @@ from pythae.models.base.base_utils import ModelOutput from pythae.models.nn import BaseDecoder, BaseEncoder + +from pythae import models +from pythae.models import VQVAEConfig, VAEConfig from pl_bolts.models import autoencoders from pythae.models import VQVAE, VQVAEConfig, VAE, VAEConfig @@ -34,9 +37,7 @@ def forward(self, x): output = ModelOutput() x = self.encoder(x) # x = self.fc1(x) - output["embedding"] = self.embedding(x) - output["log_covariance"] = self.log_var(x) - return output + return ModelOutput(embedding=self.embedding(x), log_covariance=self.log_var(x)) class ResNet50VAEDecoder(BaseDecoder):