add simsiam v2 and sidae v2
chris-santiago committed Sep 18, 2023
1 parent 664fe4e commit 8f4bdcc
Showing 30 changed files with 1,743 additions and 31 deletions.
4 changes: 2 additions & 2 deletions autoencoders/conf/comps.yaml
@@ -86,7 +86,7 @@ models:
latent_dim: 1024
dim: 1024
pred_dim: 512
- ckpt_path: SimSiam/train/2023-09-07/13-23-54/checkpoints/epoch=40-step=9635.ckpt
+ ckpt_path: SimSiam/train/2023-09-17/21-40-54/checkpoints/epoch=39-step=9400.ckpt
sidae:
name: SiDAE
module:
@@ -103,4 +103,4 @@ models:
latent_dim: 512
dim: 512
pred_dim: 512
- ckpt_path: SiDAE/train/2023-09-15/22-33-09/checkpoints/epoch=44-step=10575.ckpt
+ ckpt_path: SiDAE2/train/2023-09-17/23-30-16/checkpoints/epoch=38-step=18291.ckpt
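
These ckpt_path entries point at PyTorch Lightning checkpoints. A minimal sketch of inspecting one such file, not part of the commit; Lightning checkpoints are plain dicts with bookkeeping keys plus the weights under "state_dict":

```python
import torch

# Minimal sketch, not from the repo: peek inside one of the checkpoints
# referenced above. "epoch" and "global_step" match the filename.
ckpt = torch.load(
    "SiDAE2/train/2023-09-17/23-30-16/checkpoints/epoch=38-step=18291.ckpt",
    map_location="cpu",
)
print(ckpt["epoch"], ckpt["global_step"])  # 38, 18291 per the filename
state_dict = ckpt["state_dict"]  # pass to a constructed model's load_state_dict(...)
```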
25 changes: 25 additions & 0 deletions autoencoders/conf/data/sidae2.yaml
@@ -0,0 +1,25 @@
batch_size: 128
n_workers: 6  # Multirun launcher requires n_workers=0; otherwise it fails
name: mnist

train:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: autoencoders.data.SiDAEDataset2
    dataset:
      _target_: autoencoders.data.get_mnist_dataset
      train: True
  batch_size: ${data.batch_size}
  shuffle: True
  num_workers: ${data.n_workers}

valid:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: autoencoders.data.SiDAEDataset2
    dataset:
      _target_: autoencoders.data.get_mnist_dataset
      train: False
  batch_size: ${data.batch_size}
  shuffle: False
  num_workers: ${data.n_workers}
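
For reference, a hedged sketch of how Hydra consumes this file: instantiate walks the nested `_target_` keys recursively, so `get_mnist_dataset` feeds `SiDAEDataset2`, which feeds the `DataLoader`. The conf path and root config name below are assumptions about the repo layout:

```python
import hydra
from hydra import compose, initialize

# Minimal sketch, not from the repo: compose the root config with this data
# group selected, then let instantiate build the nested objects recursively.
# The ${data.*} interpolations only resolve inside the composed config,
# not in this file alone.
with initialize(config_path="autoencoders/conf", version_base=None):
    cfg = compose(config_name="config", overrides=["data=sidae2"])
train_loader = hydra.utils.instantiate(cfg.data.train)
aug_1, aug_2, inputs = next(iter(train_loader))  # one batch: two views + originals
```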
25 changes: 25 additions & 0 deletions autoencoders/conf/model/sidae2.yaml
@@ -0,0 +1,25 @@
# Defining the optimizer as a group default allows CLI override, e.g.
# python train.py "optimizer@model.optimizer=sgd"
# or via config "override scheduler@model.scheduler: cyclic"
# See https://stackoverflow.com/questions/71438040/overwriting-hydra-configuration-groups-from-cli/71439510#71439510
defaults:
  - /optimizer@optimizer: adam
  - /scheduler@scheduler: plateau

name: SiDAE2

nn:
  _target_: autoencoders.models.sidae.SiDAE2
  encoder:
    _target_: autoencoders.modules.CNNEncoderProjection
    channels_in: 1
    base_channels: 32
    latent_dim: ${model.nn.dim}
  decoder:
    _target_: autoencoders.modules.CNNDecoder
    channels_in: 1
    base_channels: 32
    latent_dim: ${model.nn.dim}
  dim: 512
  pred_dim: 512
  alpha: .25
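
The nn: block maps one-to-one onto the SiDAE2 constructor; a minimal sketch of the direct Python equivalent, with the ${model.nn.dim} interpolations resolved to 512:

```python
from autoencoders.models.sidae import SiDAE2
from autoencoders.modules import CNNDecoder, CNNEncoderProjection

# Minimal sketch, not from the repo: the Python equivalent of the nn: block.
# alpha=0.25 weights the reconstruction term at 25% and the SimSiam term
# at 75% in SiDAE2.training_step (see models/sidae.py below).
model = SiDAE2(
    encoder=CNNEncoderProjection(channels_in=1, base_channels=32, latent_dim=512),
    decoder=CNNDecoder(channels_in=1, base_channels=32, latent_dim=512),
    dim=512,
    pred_dim=512,
    alpha=0.25,
)
```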
18 changes: 14 additions & 4 deletions autoencoders/data.py
@@ -8,7 +8,7 @@
from omegaconf import DictConfig

import autoencoders.constants
-from autoencoders.modules import WhiteNoise
+from autoencoders.modules import RandWhiteNoise

constants = autoencoders.constants.Constants()

@@ -76,9 +76,9 @@ def __init__(
        factor: float = 1.0,
    ):
        super().__init__(dataset, transform, num_ops)
-        self.noise = WhiteNoise(loc, scale, factor)
-        self.augment_1 = T.RandomPerspective()
-        self.augment_2 = T.GaussianBlur(3)
+        self.noise = RandWhiteNoise(loc, scale, factor=(0.10, 0.75))
+        self.augment_1 = T.RandomPerspective(p=1.0)
+        self.augment_2 = T.ElasticTransform(alpha=100.0)

    def __getitem__(self, idx):
        inputs = self.dataset.data.__getitem__(idx)
@@ -92,6 +92,16 @@ def __getitem__(self, idx):
        return aug_1, aug_2, self.noise(inputs).unsqueeze(0), inputs.unsqueeze(0)


class SiDAEDataset2(SimSiamDataset):
    def __getitem__(self, idx):
        inputs = self.dataset.data.__getitem__(idx)
        aug_1, aug_2 = self.augment_1(inputs.unsqueeze(0)), self.augment_2(inputs.unsqueeze(0))
        if self.transform:
            aug_1, aug_2 = self.transform(aug_1), self.transform(aug_2)
            inputs = self.transform(inputs)
        return aug_1, aug_2, inputs.unsqueeze(0)


class AugmentedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset: DictConfig, augments: List[DictConfig], transform=scale_mnist):
        self.dataset = hydra.utils.call(dataset)
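
The new SiDAEDataset2 drops the separate noisy view that SiDAEDataset returns, yielding two augmented views plus the untransformed original. A hedged usage sketch; constructor defaults inherited from SimSiamDataset are assumed, since the config above passes only `dataset`:

```python
from autoencoders.data import SiDAEDataset2, get_mnist_dataset

# Minimal sketch, not from the repo: SiDAEDataset2 yields three tensors per
# index -- two augmented views for the SimSiam branch and the original image
# for the reconstruction branch (no separate noisy copy, unlike SiDAEDataset).
ds = SiDAEDataset2(get_mnist_dataset(train=True))
aug_1, aug_2, inputs = ds[0]
print(aug_1.shape, aug_2.shape, inputs.shape)
```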
38 changes: 37 additions & 1 deletion autoencoders/models/sidae.py
@@ -33,10 +33,46 @@ def training_step(self, batch, idx):
        z_2, p_2 = self(x_2)
        siam_loss = -0.5 * (self.loss_func(p_1, z_2).mean() + self.loss_func(p_2, z_1).mean())

        # todo note this deviates from the original in that a different noisy
        # todo augment is used for the recon loss instead of x_1, x_2
        recon = self.decoder(self.encoder(x_noise))
        recon_loss = self.recon_loss_func(recon, x)

-        total_loss = siam_loss + recon_loss
+        total_loss = siam_loss + recon_loss  # todo add weight param

        self.log("loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        metrics = {"train-loss": total_loss}
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return total_loss


class SiDAE2(SiDAE):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        dim: int = 2048,
        pred_dim: int = 512,
        loss_func: nn.Module = nn.CosineSimilarity(),
        optim: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        alpha: float = 0.5,
    ):
        super().__init__(encoder, decoder, dim, pred_dim, loss_func, optim, scheduler)
        self.alpha = alpha

    def training_step(self, batch, idx):
        x_1, x_2, x = batch

        z_1, p_1 = self(x_1)
        z_2, p_2 = self(x_2)
        siam_loss = -0.5 * (self.loss_func(p_1, z_2).mean() + self.loss_func(p_2, z_1).mean())

        recon_1 = self.decoder(self.encoder(x_1))
        recon_2 = self.decoder(self.encoder(x_2))
        recon_loss = 0.5 * (self.recon_loss_func(recon_1, x) + self.recon_loss_func(recon_2, x))

        total_loss = (siam_loss * (1 - self.alpha)) + (recon_loss * self.alpha)

        self.log("loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
        metrics = {"train-loss": total_loss}
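
SiDAE2 replaces SiDAE's unweighted sum with a convex combination, total = (1 - alpha) * siam + alpha * recon, and computes the reconstruction loss from both augmented views rather than a dedicated noisy input. A standalone sketch of the arithmetic; MSE is assumed for recon_loss_func, which is defined in the parent class and not shown in this diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch, not from the repo: the SiDAE2 loss arithmetic in isolation,
# with random tensors standing in for projector/predictor outputs and decodes.
cos = nn.CosineSimilarity(dim=1)
p_1, z_1, p_2, z_2 = (torch.randn(8, 512) for _ in range(4))
x = torch.rand(8, 1, 28, 28)
recon_1, recon_2 = torch.rand_like(x), torch.rand_like(x)

siam_loss = -0.5 * (cos(p_1, z_2).mean() + cos(p_2, z_1).mean())
recon_loss = 0.5 * (F.mse_loss(recon_1, x) + F.mse_loss(recon_2, x))
alpha = 0.25  # matches model/sidae2.yaml; alpha=0.5 is the constructor default
total_loss = (1 - alpha) * siam_loss + alpha * recon_loss
```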
99 changes: 99 additions & 0 deletions outputs/SiDAE/train/2023-09-15/23-45-29/.hydra/config.yaml
@@ -0,0 +1,99 @@
data:
  batch_size: 256
  n_workers: 10
  name: mnist
  train:
    _target_: torch.utils.data.DataLoader
    dataset:
      _target_: autoencoders.data.SiDAEDataset
      dataset:
        _target_: autoencoders.data.get_mnist_dataset
        train: true
      num_ops: 1
      loc: 0
      scale: 1
      factor: 0.05
    batch_size: ${data.batch_size}
    shuffle: true
    num_workers: ${data.n_workers}
  valid:
    _target_: torch.utils.data.DataLoader
    dataset:
      _target_: autoencoders.data.SiDAEDataset
      dataset:
        _target_: autoencoders.data.get_mnist_dataset
        train: false
      num_ops: 1
      loc: 0
      scale: 1
      factor: 1.0
    batch_size: ${data.batch_size}
    shuffle: false
    num_workers: ${data.n_workers}
model:
  optimizer:
    _target_: torch.optim.Adam
    _partial_: true
    lr: 0.001
    betas:
      - 0.9
      - 0.999
    weight_decay: 0
  scheduler:
    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
    _partial_: true
    mode: min
    factor: 0.1
    patience: 10
  name: SiDAE
  nn:
    _target_: autoencoders.models.sidae.SiDAE
    encoder:
      _target_: autoencoders.modules.CNNEncoderProjection
      channels_in: 1
      base_channels: 32
      latent_dim: ${model.nn.dim}
    decoder:
      _target_: autoencoders.modules.CNNDecoder
      channels_in: 1
      base_channels: 32
      latent_dim: ${model.nn.dim}
    dim: 512
    pred_dim: 512
trainer:
  _target_: pytorch_lightning.Trainer
  max_epochs: 100
  accelerator: mps
  devices: 1
logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  project: autoencoders
  name: null
  id: null
  group: null
  job_type: null
  save_dir: ${hydra:runtime.output_dir}
  log_model: true
  tags: ${tags}
callbacks:
  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
  progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar
    refresh_rate: 5
    leave: true
  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: train-loss
    min_delta: 0.001
    patience: 10
    check_on_train_epoch_end: true
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    dirpath: ${hydra:runtime.output_dir}/checkpoints
    monitor: train-loss
    save_top_k: 1
    save_on_train_epoch_end: true
tags:
  - ${data.name}
  - ${model.name}
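
This dump is the composed job config. A hedged sketch of the kind of train.py entrypoint that consumes it; the repo's actual entrypoint is not shown in this commit, and optimizer/scheduler wiring is omitted:

```python
import hydra

# Minimal sketch, not from the repo: a Hydra entrypoint of the shape this
# composed config implies. Names and paths are assumptions about the repo
# layout; the valid loader and optimizer/scheduler partials are omitted.
@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg):
    model = hydra.utils.instantiate(cfg.model.nn)
    train_loader = hydra.utils.instantiate(cfg.data.train)
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        logger=hydra.utils.instantiate(cfg.logger),
        callbacks=[hydra.utils.instantiate(c) for c in cfg.callbacks.values()],
    )
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
```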