From 4d59ffc99742360c0cba0c115f066b1393481a2a Mon Sep 17 00:00:00 2001 From: Craig Russell Date: Fri, 5 Jan 2024 19:33:57 +0000 Subject: [PATCH] Cleaning up VAE while I'm here --- bioimage_embed/models/bolts/vae.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bioimage_embed/models/bolts/vae.py b/bioimage_embed/models/bolts/vae.py index 6d816d0e..90d76ae3 100644 --- a/bioimage_embed/models/bolts/vae.py +++ b/bioimage_embed/models/bolts/vae.py @@ -2,6 +2,9 @@ from pythae.models.base.base_utils import ModelOutput from pythae.models.nn import BaseDecoder, BaseEncoder + +from pythae import models +from pythae.models import VQVAEConfig, VAEConfig from pl_bolts.models import autoencoders from pythae.models import VQVAE, VQVAEConfig, VAE, VAEConfig @@ -34,9 +37,7 @@ def forward(self, x): output = ModelOutput() x = self.encoder(x) # x = self.fc1(x) - output["embedding"] = self.embedding(x) - output["log_covariance"] = self.log_var(x) - return output + return ModelOutput(embedding=self.embedding(x), log_covariance=self.log_var(x)) class ResNet50VAEDecoder(BaseDecoder):