-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
670 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
utils/__pycache__/* | ||
outputs/* | ||
models/MSG3D/__pycache__/* | ||
logs/* | ||
models/__pycache__/* | ||
data/* | ||
*test* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# @Author: Simon Dahan | ||
# @Last Modified time: 2022-01-12 14:11:23 | ||
|
||
SSL: mpp | ||
|
||
resolution: | ||
ico: 6 ## full mesh resolution | ||
sub_ico: 2 ## patching grid resolution | ||
|
||
data: | ||
data_path: ../data/{}/{} | ||
task: scan_age #scan_age # birth_age | ||
configuration: template #template # native | ||
dataset: dHCP | ||
|
||
logging: | ||
folder_to_save_model: "../logs/SiT/" | ||
|
||
training: | ||
LR: 0.0003 | ||
bs: 256 | ||
bs_val: 1 | ||
epochs: 100 | ||
gpu: 0 | ||
l1loss: False | ||
testing: False | ||
val_epoch: 10 | ||
load_weights_ssl: False | ||
load_weights_imagenet: False | ||
save_ckpt: True | ||
finetuning: True | ||
dataset_ssl: 'dhcp' | ||
|
||
weights: | ||
ssl_mpp: '..' # path to .pt checkpoint | ||
imagenet: 'vit_tiny_patch16_224' #ViT(dim=192, depth=12, heads=3,mlp_dim=768,dim_head=64) | ||
#imagenet: 'vit_small_patch16_224' #ViT(dim=384, depth=12, heads=6,mlp_dim=1536,dim_head=64) | ||
#imagenet: 'vit_base_patch16_224' #ViT(dim=768, depth=12, heads=12,mlp_dim=3072,dim_head=64) | ||
|
||
transformer: | ||
dim: 192 #192, 384, 768 | ||
depth: 12 #12, 12, 12 | ||
heads: 3 #3, 6, 12 | ||
mlp_dim: 768 #768, 1536, 3072 ## 4*dim according to DeiT | ||
pool: 'cls' # or 'mean' | ||
num_classes: 1 | ||
num_channels: 4 | ||
dim_head: 64 #64 | ||
dropout: 0.0 | ||
emb_dropout: 0.0 | ||
model: SiT | ||
|
||
pretraining_mpp: | ||
mask_prob: 0.75 #0.5 | ||
replace_prob: 0.8 #0.8 | ||
swap_prob: 0.02 #0.02 | ||
|
||
optimisation: | ||
optimiser: SGD | ||
|
||
Adam: | ||
weight_decay: 0. | ||
|
||
AdamW: | ||
weight_decay: 0. | ||
SGD: | ||
weight_decay: 0. | ||
momentum: 0.9 | ||
nesterov: False | ||
|
||
StepLR: | ||
stepsize: 1000 | ||
decay: 0.5 | ||
|
||
CosineDecay: | ||
T_max: 5000 | ||
eta_min: 0.0001 | ||
|
||
sub_ico_0: | ||
num_patches: 20 | ||
num_vertices: 2145 | ||
|
||
sub_ico_1: | ||
num_patches: 80 | ||
num_vertices: 561 | ||
|
||
sub_ico_2: | ||
num_patches: 320 | ||
num_vertices: 153 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# -*- coding: utf-8 -*- | ||
# @Author: Your name | ||
# @Date: 1970-01-01 01:00:00 | ||
# @Last Modified by: Your name | ||
# @Last Modified time: 2022-02-14 17:50:22 | ||
# | ||
# Created on Mon Oct 18 2021 | ||
# | ||
# by Simon Dahan @SD3004 | ||
# | ||
# Copyright (c) 2021 MeTrICS Lab | ||
# | ||
|
||
|
||
import math | ||
from random import random | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from einops import rearrange, repeat | ||
|
||
|
||
def get_mask_from_prob(inputs, prob): | ||
''' | ||
This function creates a mask on the sequence of tokens, per sample | ||
Based on the probability of masking. | ||
return: a boolean mask of the shape of the inputs. | ||
''' | ||
batch, seq_len, _, device = *inputs.shape, inputs.device | ||
max_masked = math.ceil(prob * seq_len) | ||
|
||
rand = torch.rand((batch, seq_len), device=device) | ||
_, sampled_indices = rand.topk(max_masked, dim=-1) | ||
|
||
new_mask = torch.zeros((batch, seq_len), device=device) | ||
new_mask.scatter_(1, sampled_indices, 1) | ||
return new_mask.bool() | ||
|
||
def prob_mask_like(inputs, prob): | ||
batch, seq_length, _ = inputs.shape | ||
return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob | ||
|
||
|
||
class masked_patch_pretraining(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
transformer, | ||
dim_in, | ||
dim_out, | ||
device, | ||
mask_prob=0.15, | ||
replace_prob=0.5, | ||
swap_prob=0.3, | ||
channels=4, | ||
num_vertices=561,): | ||
|
||
super().__init__() | ||
self.transformer = transformer | ||
|
||
self.dim_out = dim_out | ||
self.dim_in = dim_in | ||
|
||
self.to_original = nn.Linear(dim_in,dim_out) | ||
self.to_original.to(device) | ||
|
||
self.mask_prob = mask_prob | ||
self.replace_prob = replace_prob | ||
self.swap_prob = swap_prob | ||
|
||
# token ids | ||
self.mask_token = nn.Parameter(torch.randn(1, 1, channels * num_vertices)) | ||
|
||
|
||
def forward(self, batch, **kwargs): | ||
|
||
transformer = self.transformer | ||
|
||
# clone original image for loss | ||
batch = rearrange(batch, | ||
'b c n v -> b n (v c)') | ||
|
||
corrupted_sequence = get_mask_from_prob(batch, self.mask_prob) | ||
|
||
corrupted_batch = batch.clone().detach() | ||
|
||
#randomly swap patches in the sequence | ||
if self.swap_prob > 0: | ||
random_patch_sampling_prob = self.swap_prob / ( | ||
1 - self.replace_prob) | ||
|
||
random_patch_prob = prob_mask_like(batch, | ||
random_patch_sampling_prob).to(corrupted_sequence.device) | ||
|
||
bool_random_patch_prob = corrupted_sequence * (random_patch_prob == True) | ||
|
||
random_patches = torch.randint(0, | ||
batch.shape[1], | ||
(batch.shape[0], batch.shape[1]), | ||
device=batch.device) | ||
#shuffle entierely masked_batch | ||
randomized_input = corrupted_batch[ | ||
torch.arange(corrupted_batch.shape[0]).unsqueeze(-1), | ||
random_patches] | ||
corrupted_batch[bool_random_patch_prob] = randomized_input[bool_random_patch_prob] | ||
|
||
tokens_to_mask = prob_mask_like(batch, self.replace_prob).to(corrupted_sequence.device) | ||
|
||
bool_mask_replace = (corrupted_sequence * tokens_to_mask) == True | ||
corrupted_batch[bool_mask_replace] = self.mask_token.to(corrupted_sequence.device) | ||
|
||
# linear embedding of patches | ||
corrupted_batch = transformer.to_patch_embedding[-1](corrupted_batch) | ||
emb_masked_sequence = corrupted_batch.clone().detach() | ||
|
||
# add cls token to input sequence | ||
b, n, _ = corrupted_batch.shape | ||
cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b) | ||
corrupted_batch = torch.cat((cls_tokens, corrupted_batch), dim=1) | ||
|
||
# add positional embeddings to input | ||
corrupted_batch += transformer.pos_embedding[:, :(n + 1)] | ||
corrupted_batch = transformer.dropout(corrupted_batch) | ||
|
||
# get generator output and get mpp loss | ||
batch_out = transformer.transformer(corrupted_batch, **kwargs) | ||
batch_out = self.to_original(batch_out[:,1:,:]) | ||
|
||
# compute loss | ||
mpp_loss = F.mse_loss(batch_out[corrupted_sequence], batch[corrupted_sequence]) | ||
|
||
return mpp_loss, batch_out | ||
|
Oops, something went wrong.