From 4e38e6d6a33450ae19c22e4eec32281cf9598f6d Mon Sep 17 00:00:00 2001
From: chris-santiago
Date: Wed, 20 Sep 2023 22:36:47 -0400
Subject: [PATCH] refactor to torch.distributions

---
 autoencoders/models/vae.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/autoencoders/models/vae.py b/autoencoders/models/vae.py
index 14ef0d7..ead3dbe 100644
--- a/autoencoders/models/vae.py
+++ b/autoencoders/models/vae.py
@@ -31,14 +31,14 @@ def __init__(
         self.log_var = nn.Linear(latent_dim, dist_dim)
         self.dist_decoder = nn.Linear(dist_dim, latent_dim)
 
-    def _encode_dist(self, x):
+    def encode_dist(self, x):
         """Alt version to return distribution directly."""
         mu = self.mu(x)
         log_var = self.log_var(x)
         sigma = torch.exp(log_var * 0.5)
         return torch.distributions.Normal(mu, sigma)
 
-    def encode_dist(self, x):
+    def _encode_dist(self, x):
         """Basic version that requires re-parameterization when sampling."""
         mu = self.mu(x)
         log_var = self.log_var(x)
@@ -48,21 +48,21 @@ def decode_dist(self, z):
         x = self.dist_decoder(z)
         return self.decoder(x)
 
-    def _forward(self, x):
+    def forward(self, x):
         """Alt version to return reconstruction and encoded distribution."""
         x = self.encoder(x)
-        q_z = self._encode_dist(x)
+        q_z = self.encode_dist(x)
         z = q_z.rsample()
         return self.decode_dist(z), q_z
 
-    def forward(self, x):
+    def _forward(self, x):
         """
         Basic version that completes re-parameterization trick to allow gradient flow
         to mu and log_var params.
         """
         # Don't fully encode the distribution here so that encoder can be used for downstream tasks
         x = self.encoder(x)
-        mu, log_var = self.encode_dist(x)
+        mu, log_var = self._encode_dist(x)
         sigma = torch.exp(log_var * 0.5)
         eps = torch.randn_like(sigma)
         z = mu + sigma * eps
@@ -72,8 +72,8 @@ def training_step(self, batch, idx):
         original = batch[0]
 
         # TODO these are alternative methods for forward operation
-        reconstructed, q_z = self._forward(original)
-        # reconstructed, mu, log_var = self(original)
+        reconstructed, q_z = self(original)
+        # reconstructed, mu, log_var = self._forward(original)
 
         # TODO these are alternative KLD losses based on returns from forward operation
         kl_loss = torch.distributions.kl_divergence(q_z, self.norm).mean()
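
Context for the patch: after this refactor, forward() returns the reconstruction together with the encoder's posterior q(z|x) as a torch.distributions.Normal, q_z.rsample() performs the reparameterization trick, and training_step computes the KL term with torch.distributions.kl_divergence against the prior stored in self.norm. The sketch below is a minimal, self-contained illustration of that pattern, not the repository's code: the layer sizes, the plain nn.Module base class (the real class also defines training_step, so it is presumably a LightningModule), and the standard-normal choice for self.norm are assumptions; only the encode_dist / forward / kl_divergence structure mirrors the diff.

    # Minimal sketch of the torch.distributions-based VAE pattern used in this patch.
    # Hypothetical class and sizes; only the structure follows the diff above.
    import torch
    import torch.nn as nn

    class TinyVAE(nn.Module):
        def __init__(self, in_dim=784, latent_dim=128, dist_dim=32):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(in_dim, latent_dim), nn.ReLU())
            self.mu = nn.Linear(latent_dim, dist_dim)
            self.log_var = nn.Linear(latent_dim, dist_dim)
            self.dist_decoder = nn.Linear(dist_dim, latent_dim)
            self.decoder = nn.Sequential(nn.ReLU(), nn.Linear(latent_dim, in_dim))
            # Prior p(z); assumed here to be a standard normal, matching the
            # `self.norm` used in the kl_divergence call of training_step.
            self.norm = torch.distributions.Normal(0.0, 1.0)

        def encode_dist(self, x):
            # Return the posterior q(z|x) directly as a Normal distribution.
            mu = self.mu(x)
            sigma = torch.exp(self.log_var(x) * 0.5)
            return torch.distributions.Normal(mu, sigma)

        def forward(self, x):
            h = self.encoder(x)
            q_z = self.encode_dist(h)
            # rsample() is the reparameterized sample, so gradients flow to mu/log_var.
            z = q_z.rsample()
            return self.decoder(self.dist_decoder(z)), q_z

    # Usage: reconstruction loss plus KL(q(z|x) || p(z)), as in training_step.
    model = TinyVAE()
    x = torch.randn(16, 784)
    recon, q_z = model(x)
    recon_loss = nn.functional.mse_loss(recon, x)
    kl_loss = torch.distributions.kl_divergence(q_z, model.norm).mean()
    loss = recon_loss + kl_loss
    loss.backward()

Compared with the commented-out `_forward` path, the distribution-based version keeps the reparameterization inside rsample() and lets torch.distributions compute the Normal-to-Normal KL in closed form, instead of hand-writing mu + sigma * eps and the analytic KL expression.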