refactor to torch.distributions
chris-santiago committed Sep 21, 2023
1 parent 6dd60e9 · commit 4e38e6d
Showing 1 changed file with 8 additions and 8 deletions.
autoencoders/models/vae.py (16 changes: 8 additions & 8 deletions)
@@ -31,14 +31,14 @@ def __init__(
         self.log_var = nn.Linear(latent_dim, dist_dim)
         self.dist_decoder = nn.Linear(dist_dim, latent_dim)

-    def _encode_dist(self, x):
+    def encode_dist(self, x):
         """Alt version to return distribution directly."""
         mu = self.mu(x)
         log_var = self.log_var(x)
         sigma = torch.exp(log_var * 0.5)
         return torch.distributions.Normal(mu, sigma)

-    def encode_dist(self, x):
+    def _encode_dist(self, x):
         """Basic version that requires re-parameterization when sampling."""
         mu = self.mu(x)
         log_var = self.log_var(x)
@@ -48,21 +48,21 @@ def decode_dist(self, z):
         x = self.dist_decoder(z)
         return self.decoder(x)

-    def _forward(self, x):
+    def forward(self, x):
         """Alt version to return reconstruction and encoded distribution."""
         x = self.encoder(x)
-        q_z = self._encode_dist(x)
+        q_z = self.encode_dist(x)
         z = q_z.rsample()
         return self.decode_dist(z), q_z

-    def forward(self, x):
+    def _forward(self, x):
         """
         Basic version that completes re-parameterization trick to allow gradient flow to
         mu and log_var params.
         """
         # Don't fully encode the distribution here so that encoder can be used for downstream tasks
         x = self.encoder(x)
-        mu, log_var = self.encode_dist(x)
+        mu, log_var = self._encode_dist(x)
         sigma = torch.exp(log_var * 0.5)
         eps = torch.randn_like(sigma)
         z = mu + sigma * eps
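
For context on the swap in this hunk: Normal(mu, sigma).rsample() applies the same re-parameterization trick as the hand-written mu + sigma * eps, so gradients still reach mu and log_var either way. A minimal standalone sketch of that equivalence (the tensors below are illustrative and not part of the repository's code):

import torch

mu = torch.zeros(4, 2, requires_grad=True)
log_var = torch.zeros(4, 2, requires_grad=True)
sigma = torch.exp(log_var * 0.5)

# Manual re-parameterization: sample eps ~ N(0, 1), then shift and scale it.
eps = torch.randn_like(sigma)
z_manual = mu + sigma * eps

# torch.distributions equivalent: rsample() uses the same trick internally,
# so the sample stays differentiable with respect to mu and sigma.
q_z = torch.distributions.Normal(mu, sigma)
z_dist = q_z.rsample()

# Gradients flow back to mu and log_var through both paths.
(z_manual.sum() + z_dist.sum()).backward()
print(mu.grad is not None, log_var.grad is not None)  # True True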
@@ -72,8 +72,8 @@ def training_step(self, batch, idx):
         original = batch[0]

         # TODO these are alternative methods for forward operation
-        reconstructed, q_z = self._forward(original)
-        # reconstructed, mu, log_var = self(original)
+        reconstructed, q_z = self(original)
+        # reconstructed, mu, log_var = self._forward(original)

         # TODO these are alternative KLD losses based on returns from forward operation
         kl_loss = torch.distributions.kl_divergence(q_z, self.norm).mean()
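The training_step change leans on torch.distributions.kl_divergence, which for two Normal distributions evaluates the same closed-form KL term that VAE code usually writes out by hand. A minimal sketch of that equivalence, assuming a standard-normal prior like the self.norm attribute referenced above (shapes and names here are illustrative):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 2)
log_var = torch.randn(4, 2)
sigma = torch.exp(log_var * 0.5)

q_z = Normal(mu, sigma)  # approximate posterior q(z|x)
prior = Normal(torch.zeros_like(mu), torch.ones_like(sigma))  # p(z) = N(0, I)

# Library KL: element-wise closed form for Gaussian vs. Gaussian.
kl_lib = kl_divergence(q_z, prior).mean()

# Hand-written analytic KL that the manual path would otherwise use.
kl_manual = (-0.5 * (1 + log_var - mu.pow(2) - log_var.exp())).mean()

print(torch.allclose(kl_lib, kl_manual))  # True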
