refactor VAE for old/new methods
chris-santiago committed Sep 20, 2023
1 parent e113fd0 commit 6dd60e9
Showing 6 changed files with 315 additions and 14 deletions.
7 changes: 4 additions & 3 deletions autoencoders/conf/model/vae.yaml
@@ -4,14 +4,15 @@
 # See https://stackoverflow.com/questions/71438040/overwriting-hydra-configuration-groups-from-cli/71439510#71439510
 defaults:
   - /optimizer@optimizer: adam
-  - /scheduler@scheduler: cyclic
+  - /scheduler@scheduler: plateau
 
 name: VAE
 
 nn:
   _target_: autoencoders.models.vae.VAE
-  base_channels: 32
-  latent_dim: 8
+  base_channels: 16
+  latent_dim: 256
+  dist_dim: 8
   input_channels: 1
   loss_func:
     _target_: torch.nn.MSELoss
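
For orientation, a minimal sketch (not part of the commit) of what the updated nn node resolves to when instantiated; it assumes the autoencoders package is importable and simply mirrors the new constructor signature and the values above.

import torch.nn as nn

from autoencoders.models.vae import VAE

# Hypothetical direct construction mirroring conf/model/vae.yaml
model = VAE(
    base_channels=16,        # CNN encoder/decoder width
    latent_dim=256,          # encoder output size fed to the mu/log_var heads
    dist_dim=8,              # size of the latent Normal posterior
    input_channels=1,        # MNIST is single-channel
    loss_func=nn.MSELoss(),
)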
60 changes: 49 additions & 11 deletions autoencoders/models/vae.py
@@ -12,35 +12,73 @@ def __init__(
         self,
         base_channels: int,
         latent_dim: int,
+        dist_dim: int,
         encoder: nn.Module = CNNEncoder,
         decoder: nn.Module = CNNDecoder,
         input_channels: int = 1,
         loss_func: nn.Module = nn.MSELoss(),
         optim: Optional[torch.optim.Optimizer] = None,
         scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
-        kl_coef: float = 0.1,
+        kl_coef: float = 0.001,  # ~ batch/train
     ):
         super().__init__(
             base_channels, latent_dim, encoder, decoder, input_channels, loss_func, optim, scheduler
         )
         self.kl_coef = kl_coef
         self.norm = torch.distributions.Normal(0, 1)
 
-        self.mu = nn.LazyLinear(latent_dim)
-        self.sigma = nn.LazyLinear(latent_dim)
+        self.mu = nn.Linear(latent_dim, dist_dim)
+        self.log_var = nn.Linear(latent_dim, dist_dim)
+        self.dist_decoder = nn.Linear(dist_dim, latent_dim)
 
+    def _encode_dist(self, x):
+        """Alt version to return distribution directly."""
+        mu = self.mu(x)
+        log_var = self.log_var(x)
+        sigma = torch.exp(log_var * 0.5)
+        return torch.distributions.Normal(mu, sigma)
+
+    def encode_dist(self, x):
+        """Basic version that requires re-parameterization when sampling."""
+        mu = self.mu(x)
+        log_var = self.log_var(x)
+        return mu, log_var
+
+    def decode_dist(self, z):
+        x = self.dist_decoder(z)
+        return self.decoder(x)
+
+    def _forward(self, x):
+        """Alt version to return reconstruction and encoded distribution."""
+        x = self.encoder(x)
+        q_z = self._encode_dist(x)
+        z = q_z.rsample()
+        return self.decode_dist(z), q_z
+
     def forward(self, x):
-        z = self.encoder(x)
-        mu = self.mu(z)
-        sigma = self.sigma(z)
-        q = torch.distributions.Normal(mu, torch.exp(sigma))
-        return self.decoder(z), q
+        """
+        Basic version that completes re-parameterization trick to allow gradient flow to
+        mu and log_var params.
+        """
+        # Don't fully encode the distribution here so that encoder can be used for downstream tasks
+        x = self.encoder(x)
+        mu, log_var = self.encode_dist(x)
+        sigma = torch.exp(log_var * 0.5)
+        eps = torch.randn_like(sigma)
+        z = mu + sigma * eps
+        return self.decode_dist(z), mu, log_var
 
     def training_step(self, batch, idx):
-        original = batch[1]
-        reconstructed, q = self(batch[0])
+        original = batch[0]
+
+        # TODO these are alternative methods for forward operation
+        reconstructed, q_z = self._forward(original)
+        # reconstructed, mu, log_var = self(original)
+
+        # TODO these are alternative KLD losses based on returns from forward operation
+        kl_loss = torch.distributions.kl_divergence(q_z, self.norm).mean()
+        # kl_loss = -0.5 * (1 + log_var - mu**2 - torch.exp(log_var)).mean()
 
-        kl_loss = torch.distributions.kl_divergence(q, self.norm).mean()
         reconstruction_loss = self.loss_func(original, reconstructed)
         total_loss = reconstruction_loss + kl_loss * self.kl_coef
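
The training_step now keeps two interchangeable KLD terms, one commented out. A quick numerical check (not from the commit, shapes chosen arbitrarily) that the closed-form expression matches torch.distributions.kl_divergence against the module's standard-Normal prior:

import torch

mu = torch.randn(32, 8)
log_var = torch.randn(32, 8)
sigma = torch.exp(0.5 * log_var)

# Posterior q(z|x) and the standard-Normal prior (self.norm in the module)
q_z = torch.distributions.Normal(mu, sigma)
p_z = torch.distributions.Normal(0, 1)

kl_from_dist = torch.distributions.kl_divergence(q_z, p_z).mean()
kl_closed_form = (-0.5 * (1 + log_var - mu**2 - torch.exp(log_var))).mean()

assert torch.allclose(kl_from_dist, kl_closed_form, atol=1e-5)

# Note: q_z.rsample() applies the same mu + sigma * eps re-parameterization used in forward(),
# so both forward variants keep gradients flowing to the mu/log_var heads.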
87 changes: 87 additions & 0 deletions outputs/VAE/train/2023-09-20/17-26-29/.hydra/config.yaml
@@ -0,0 +1,87 @@
data:
  batch_size: 256
  n_workers: 10
  name: mnist
  train:
    _target_: torch.utils.data.DataLoader
    dataset:
      _target_: autoencoders.data.AutoEncoderDataset
      dataset:
        _target_: autoencoders.data.get_mnist_dataset
        train: true
    batch_size: ${data.batch_size}
    shuffle: true
    num_workers: ${data.n_workers}
  valid:
    _target_: torch.utils.data.DataLoader
    dataset:
      _target_: autoencoders.data.AutoEncoderDataset
      dataset:
        _target_: autoencoders.data.get_mnist_dataset
        train: false
    batch_size: ${data.batch_size}
    shuffle: false
    num_workers: ${data.n_workers}
model:
  optimizer:
    _target_: torch.optim.Adam
    _partial_: true
    lr: 0.001
    betas:
    - 0.9
    - 0.999
    weight_decay: 0
  scheduler:
    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
    _partial_: true
    mode: min
    factor: 0.1
    patience: 10
  name: VAE
  nn:
    _target_: autoencoders.models.vae.VAE
    base_channels: 16
    latent_dim: 256
    dist_dim: 8
    input_channels: 1
    loss_func:
      _target_: torch.nn.MSELoss
trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 200
  accelerator: mps
  devices: 1
logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  project: autoencoders
  name: null
  id: null
  group: null
  job_type: null
  save_dir: ${hydra:runtime.output_dir}
  log_model: true
  tags: ${tags}
callbacks:
  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
  progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
    refresh_rate: 5
    leave: true
  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: train-loss
    min_delta: 0.001
    patience: 10
    check_on_train_epoch_end: true
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: ${hydra:runtime.output_dir}/checkpoints
    monitor: train-loss
    save_top_k: 1
    save_on_train_epoch_end: true
  log_images:
    _target_: autoencoders.callbacks.LogReconstructedImagesCallback
tags:
- ${data.name}
- ${model.name}
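
As a usage note (not shown in the commit), this saved run config can be fed back through hydra.utils.instantiate to rebuild the model and dataloaders; the path below points at this run's output directory and assumes the snippet is executed from the repository root.

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("outputs/VAE/train/2023-09-20/17-26-29/.hydra/config.yaml")

model = instantiate(cfg.model.nn)           # autoencoders.models.vae.VAE(base_channels=16, ...)
train_loader = instantiate(cfg.data.train)  # DataLoader over the MNIST AutoEncoderDataset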
173 changes: 173 additions & 0 deletions outputs/VAE/train/2023-09-20/17-26-29/.hydra/hydra.yaml
@@ -0,0 +1,173 @@
hydra:
  run:
    dir: outputs/${model.name}/${hydra.job.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: outputs/${model.name}/${hydra.job.name}/multirun
    subdir: ${hydra.job.override_dirname}/${now:%Y-%m-%d}/${now:%H-%M-%S}
  launcher:
    _target_: hydra_plugins.hydra_joblib_launcher.joblib_launcher.JoblibLauncher
    n_jobs: -1
    backend: null
    prefer: processes
    require: null
    verbose: 0
    timeout: null
    pre_dispatch: 2*n_jobs
    batch_size: auto
    temp_folder: null
    max_nbytes: null
    mmap_mode: r
  sweeper:
    _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
    max_batch_size: null
    params: null
  help:
    app_name: ${hydra.job.name}
    header: '${hydra.help.app_name} is powered by Hydra.
      '
    footer: 'Powered by Hydra (https://hydra.cc)
      Use --hydra-help to view Hydra specific help
      '
    template: '${hydra.help.header}
      == Configuration groups ==
      Compose your configuration from those groups (group=option)
      $APP_CONFIG_GROUPS
      == Config ==
      Override anything in the config (foo.bar=value)
      $CONFIG
      ${hydra.help.footer}
      '
  hydra_help:
    template: 'Hydra (${hydra.runtime.version})
      See https://hydra.cc for more info.
      == Flags ==
      $FLAGS_HELP
      == Configuration groups ==
      Compose your configuration from those groups (For example, append hydra/job_logging=disabled
      to command line)
      $HYDRA_CONFIG_GROUPS
      Use ''--cfg hydra'' to Show the Hydra config.
      '
    hydra_help: ???
  hydra_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][HYDRA] %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
    root:
      level: INFO
      handlers:
      - console
    loggers:
      logging_example:
        level: DEBUG
    disable_existing_loggers: false
  job_logging:
    version: 1
    formatters:
      simple:
        format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
    handlers:
      console:
        class: logging.StreamHandler
        formatter: simple
        stream: ext://sys.stdout
      file:
        class: logging.FileHandler
        formatter: simple
        filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
    root:
      level: INFO
      handlers:
      - console
      - file
    disable_existing_loggers: false
  env: {}
  mode: RUN
  searchpath: []
  callbacks: {}
  output_subdir: .hydra
  overrides:
    hydra:
    - hydra.mode=RUN
    task:
    - model=vae
  job:
    name: train
    chdir: null
    override_dirname: model=vae
    id: ???
    num: ???
    config_name: config
    env_set: {}
    env_copy: []
    config:
      override_dirname:
        kv_sep: '='
        item_sep: ','
        exclude_keys: []
  runtime:
    version: 1.3.2
    version_base: '1.3'
    cwd: /Users/chrissantiago/Dropbox/GitHub/autoencoders
    config_sources:
    - path: hydra.conf
      schema: pkg
      provider: hydra
    - path: /Users/chrissantiago/Dropbox/GitHub/autoencoders/autoencoders/conf
      schema: file
      provider: main
    - path: ''
      schema: structured
      provider: schema
    output_dir: /Users/chrissantiago/Dropbox/GitHub/autoencoders/outputs/VAE/train/2023-09-20/17-26-29
    choices:
      experiment: null
      callbacks: encoder
      trainer: default
      model: vae
      scheduler@model.scheduler: plateau
      optimizer@model.optimizer: adam
      data: mnist
      hydra/env: default
      hydra/callbacks: null
      hydra/job_logging: default
      hydra/hydra_logging: default
      hydra/hydra_help: default
      hydra/help: default
      hydra/sweeper: basic
      hydra/launcher: joblib
      hydra/output: default
  verbose: false
@@ -0,0 +1 @@
- model=vae
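
The single recorded override above, together with config_name: config from hydra.yaml, is enough to recompose this run programmatically. A hedged sketch using Hydra's compose API; the relative config_path is an assumption about where the snippet runs.

from hydra import compose, initialize
from hydra.utils import instantiate

# Recompose the run config from the recorded override (config_path assumed relative to the repo root)
with initialize(version_base="1.3", config_path="autoencoders/conf"):
    cfg = compose(config_name="config", overrides=["model=vae"])
    model = instantiate(cfg.model.nn)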
@@ -0,0 +1 @@
/Users/chrissantiago/Dropbox/GitHub/autoencoders/outputs/VAE/train/2023-09-20/17-26-29/checkpoints/epoch=33-step=7990.ckpt: 0.016784192994236946
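
Finally, a hypothetical way (not in the commit) to restore the best checkpoint recorded above via Lightning's load_from_checkpoint; this assumes the VAE module saves its constructor arguments (e.g., with self.save_hyperparameters()), otherwise they must be passed explicitly.

from autoencoders.models.vae import VAE

ckpt = "outputs/VAE/train/2023-09-20/17-26-29/checkpoints/epoch=33-step=7990.ckpt"
model = VAE.load_from_checkpoint(ckpt, map_location="cpu")
model.eval()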
