Commit

adding masked patch pretraining code
SD3004 committed Feb 12, 2024
1 parent 7f65a73 commit b599fbc
Showing 3 changed files with 661 additions and 0 deletions.
89 changes: 89 additions & 0 deletions config/SiT/pretraining/mpp.yml
@@ -0,0 +1,89 @@
# @Author: Simon Dahan
# @Last Modified time: 2022-01-12 14:11:23

SSL: mpp

resolution:
  ico: 6 ## full mesh resolution
  sub_ico: 2 ## patching grid resolution

data:
  data_path: ../data/{}/{}
  task: scan_age #scan_age # birth_age
  configuration: template #template # native
  dataset: dHCP

logging:
  folder_to_save_model: "../logs/SiT/"

training:
  LR: 0.0003
  bs: 32
  bs_val: 1
  epochs: 20
  gpu: 0
  l1loss: False
  testing: False
  val_epoch: 5
  load_weights_ssl: False
  load_weights_imagenet: False
  save_ckpt: True
  finetuning: True
  dataset_ssl: 'dhcp'

weights:
  ssl_mpp: '..' # path to .pt checkpoint
  imagenet: 'vit_tiny_patch16_224' #ViT(dim=192, depth=12, heads=3,mlp_dim=768,dim_head=64)
  #imagenet: 'vit_small_patch16_224' #ViT(dim=384, depth=12, heads=6,mlp_dim=1536,dim_head=64)
  #imagenet: 'vit_base_patch16_224' #ViT(dim=768, depth=12, heads=12,mlp_dim=3072,dim_head=64)

transformer:
  dim: 192 #192, 384, 768
  depth: 12 #12, 12, 12
  heads: 3 #3, 6, 12
  mlp_dim: 768 #768, 1536, 3072 ## 4*dim according to DeiT
  pool: 'cls' # or 'mean'
  num_classes: 1
  num_channels: 4
  dim_head: 64 #64
  dropout: 0.0
  emb_dropout: 0.0
  model: SiT

pretraining_mpp:
  mask_prob: 0.75 #0.5
  replace_prob: 0.8 #0.8
  swap_prob: 0.02 #0.02

optimisation:
  optimiser: SGD

Adam:
  weight_decay: 0.

AdamW:
  weight_decay: 0.

SGD:
  weight_decay: 0.
  momentum: 0.9
  nesterov: False

StepLR:
  stepsize: 1000
  decay: 0.5

CosineDecay:
  T_max: 5000
  eta_min: 0.0001

sub_ico_0:
  num_patches: 20
  num_vertices: 2145

sub_ico_1:
  num_patches: 80
  num_vertices: 561

sub_ico_2:
  num_patches: 320
  num_vertices: 153
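
For orientation, the sketch below shows one way this config could be read: `resolution.sub_ico` selects the matching `sub_ico_*` block, which fixes the number of patches and the number of vertices per patch, and `num_channels * num_vertices` gives the flattened patch dimension reconstructed during masked-patch pretraining (see models/mpp.py below). The loader itself is not part of this commit; the path and helper name are illustrative assumptions.

# Illustrative only: this commit does not include a config loader.
import yaml

def load_mpp_config(path='config/SiT/pretraining/mpp.yml'):   # hypothetical helper
    with open(path) as f:
        config = yaml.safe_load(f)

    # the patching grid resolution picks the matching sub_ico_* block
    grid = config['sub_ico_{}'.format(config['resolution']['sub_ico'])]
    num_patches = grid['num_patches']      # 320 for sub_ico_2
    num_vertices = grid['num_vertices']    # 153 for sub_ico_2

    # flattened patch dimension: channels * vertices per patch
    dim_patch = config['transformer']['num_channels'] * num_vertices
    return config, num_patches, num_vertices, dim_patch
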
135 changes: 135 additions & 0 deletions models/mpp.py
@@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
# @Author: Your name
# @Date: 1970-01-01 01:00:00
# @Last Modified by: Your name
# @Last Modified time: 2022-02-14 17:50:22
#
# Created on Mon Oct 18 2021
#
# by Simon Dahan @SD3004
#
# Copyright (c) 2021 MeTrICS Lab
#


import math
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat


def get_mask_from_prob(inputs, prob):
    '''
    Creates a boolean mask over the sequence of tokens, per sample,
    based on the masking probability.
    return: a boolean mask of shape (batch, seq_len).
    '''
    batch, seq_len, _, device = *inputs.shape, inputs.device
    max_masked = math.ceil(prob * seq_len)

    rand = torch.rand((batch, seq_len), device=device)
    _, sampled_indices = rand.topk(max_masked, dim=-1)

    new_mask = torch.zeros((batch, seq_len), device=device)
    new_mask.scatter_(1, sampled_indices, 1)
    return new_mask.bool()

def prob_mask_like(inputs, prob):
    batch, seq_length, _ = inputs.shape
    return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob


class masked_patch_pretraining(nn.Module):

    def __init__(
            self,
            transformer,
            dim_in,
            dim_out,
            device,
            mask_prob=0.15,
            replace_prob=0.5,
            swap_prob=0.3,
            channels=4,
            num_vertices=561,):

        super().__init__()
        self.transformer = transformer

        self.dim_out = dim_out
        self.dim_in = dim_in

        # projection from the transformer embedding back to the original patch dimension
        self.to_original = nn.Linear(dim_in, dim_out)
        self.to_original.to(device)

        self.mask_prob = mask_prob
        self.replace_prob = replace_prob
        self.swap_prob = swap_prob

        # learnable mask token used to replace masked patches
        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * num_vertices))


    def forward(self, batch, **kwargs):

        transformer = self.transformer

        # flatten each patch into a token: (B, C, N, V) -> (B, N, V*C); `batch` is kept as the reconstruction target
        batch = rearrange(batch, 'b c n v -> b n (v c)')

        # boolean mask of the patches to corrupt; also selects the positions used for the loss
        corrupted_sequence = get_mask_from_prob(batch, self.mask_prob)

        corrupted_batch = batch.clone().detach()

        # randomly swap patches in the sequence
        if self.swap_prob > 0:
            random_patch_sampling_prob = self.swap_prob / (
                1 - self.replace_prob)

            random_patch_prob = prob_mask_like(batch,
                random_patch_sampling_prob).to(corrupted_sequence.device)

            bool_random_patch_prob = corrupted_sequence * (random_patch_prob == True)

            random_patches = torch.randint(0,
                                           batch.shape[1],
                                           (batch.shape[0], batch.shape[1]),
                                           device=batch.device)
            # entirely shuffled copy of the corrupted batch
            randomized_input = corrupted_batch[
                torch.arange(corrupted_batch.shape[0]).unsqueeze(-1),
                random_patches]
            corrupted_batch[bool_random_patch_prob] = randomized_input[bool_random_patch_prob]

        # replace a fraction of the corrupted patches with the learnable mask token
        tokens_to_mask = prob_mask_like(batch, self.replace_prob).to(corrupted_sequence.device)

        bool_mask_replace = (corrupted_sequence * tokens_to_mask) == True
        corrupted_batch[bool_mask_replace] = self.mask_token.to(corrupted_sequence.device)

        # linear embedding of patches
        corrupted_batch = transformer.to_patch_embedding[-1](corrupted_batch)
        emb_masked_sequence = corrupted_batch.clone().detach()

        # add cls token to input sequence
        b, n, _ = corrupted_batch.shape
        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
        corrupted_batch = torch.cat((cls_tokens, corrupted_batch), dim=1)

        # add positional embeddings to input
        corrupted_batch += transformer.pos_embedding[:, :(n + 1)]
        corrupted_batch = transformer.dropout(corrupted_batch)

        # encode the corrupted sequence and project back to patch space
        batch_out = transformer.transformer(corrupted_batch, **kwargs)
        batch_out = self.to_original(batch_out[:, 1:, :])

        # reconstruction loss on the corrupted patches only
        mpp_loss = F.mse_loss(batch_out[corrupted_sequence], batch[corrupted_sequence])

        return mpp_loss, batch_out
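
As a usage illustration (not part of this commit), the sketch below wires masked_patch_pretraining to a stand-in encoder exposing only the attributes the forward pass above actually touches: an indexable to_patch_embedding whose last element is the linear patch embedding, cls_token, pos_embedding, dropout, and a transformer module mapping (B, N+1, dim) to (B, N+1, dim). The TinyEncoder class, its layer choices, and the shapes are assumptions; the real SiT model is defined elsewhere in the repository.

# Hypothetical sketch: TinyEncoder stands in for the SiT model; only the attribute
# names it exposes are taken from models/mpp.py above.
import torch
import torch.nn as nn
from models.mpp import masked_patch_pretraining   # assumes the repository root is on PYTHONPATH

num_patches, num_vertices, channels, dim = 320, 153, 4, 192   # sub_ico_2 values from mpp.yml

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # last element of to_patch_embedding must be the linear patch embedding
        self.to_patch_embedding = nn.Sequential(nn.Linear(channels * num_vertices, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(0.0)
        # any module mapping (B, N+1, dim) -> (B, N+1, dim) works here
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=3, dim_feedforward=768,
                                       batch_first=True),
            num_layers=2)

encoder = TinyEncoder()
ssl = masked_patch_pretraining(transformer=encoder,
                               dim_in=dim,
                               dim_out=channels * num_vertices,
                               device='cpu',
                               mask_prob=0.75, replace_prob=0.8, swap_prob=0.02,
                               channels=channels, num_vertices=num_vertices)

# one pretraining step on a random batch shaped (B, channels, patches, vertices)
batch = torch.randn(8, channels, num_patches, num_vertices)
loss, reconstruction = ssl(batch)
loss.backward()
print(loss.item(), reconstruction.shape)   # reconstruction: (8, 320, 612)
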
