From 1b46913b4bab3a7218f7526526fde3480ecb233b Mon Sep 17 00:00:00 2001 From: Simon Dahan Date: Mon, 12 Feb 2024 16:03:28 +0000 Subject: [PATCH] updates --- .gitignore | 7 + config/SiT/pretraining/mpp.yml | 89 +++++++ config/SiT/training/hparams.yml | 7 +- models/mpp.py | 135 ++++++++++ tools/pretrain.py | 437 ++++++++++++++++++++++++++++++++ tools/train.py | 6 - 6 files changed, 670 insertions(+), 11 deletions(-) create mode 100644 .gitignore create mode 100644 config/SiT/pretraining/mpp.yml create mode 100644 models/mpp.py create mode 100644 tools/pretrain.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4a36a8d --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +utils/__pycache__/* +outputs/* +models/MSG3D/__pycache__/* +logs/* +models/__pycache__/* +data/* +*test* \ No newline at end of file diff --git a/config/SiT/pretraining/mpp.yml b/config/SiT/pretraining/mpp.yml new file mode 100644 index 0000000..3a35509 --- /dev/null +++ b/config/SiT/pretraining/mpp.yml @@ -0,0 +1,89 @@ +# @Author: Simon Dahan +# @Last Modified time: 2022-01-12 14:11:23 + +SSL: mpp + +resolution: + ico: 6 ## full mesh resolution + sub_ico: 2 ## patching grid resolution + +data: + data_path: ../data/{}/{} + task: scan_age #scan_age # birth_age + configuration: template #template # native + dataset: dHCP + +logging: + folder_to_save_model: "../logs/SiT/" + +training: + LR: 0.0003 + bs: 256 + bs_val: 1 + epochs: 100 + gpu: 0 + l1loss: False + testing: False + val_epoch: 10 + load_weights_ssl: False + load_weights_imagenet: False + save_ckpt: True + finetuning: True + dataset_ssl: 'dhcp' + +weights: + ssl_mpp: '..' # path to .pt checkpoint + imagenet: 'vit_tiny_patch16_224' #ViT(dim=192, depth=12, heads=3,mlp_dim=768,dim_head=64) + #imagenet: 'vit_small_patch16_224' #ViT(dim=384, depth=12, heads=6,mlp_dim=1536,dim_head=64) + #imagenet: 'vit_base_patch16_224' #ViT(dim=768, depth=12, heads=12,mlp_dim=3072,dim_head=64) + +transformer: + dim: 192 #192, 384, 768 + depth: 12 #12, 12, 12 + heads: 3 #3, 6, 12 + mlp_dim: 768 #768, 1536, 3072 ## 4*dim according to DeiT + pool: 'cls' # or 'mean' + num_classes: 1 + num_channels: 4 + dim_head: 64 #64 + dropout: 0.0 + emb_dropout: 0.0 + model: SiT + +pretraining_mpp: + mask_prob: 0.75 #0.5 + replace_prob: 0.8 #0.8 + swap_prob: 0.02 #0.02 + +optimisation: + optimiser: SGD + +Adam: + weight_decay: 0. + +AdamW: + weight_decay: 0. +SGD: + weight_decay: 0. 
+    momentum: 0.9
+    nesterov: False
+
+StepLR:
+    stepsize: 1000
+    decay: 0.5
+
+CosineDecay:
+    T_max: 5000
+    eta_min: 0.0001
+
+sub_ico_0:
+    num_patches: 20
+    num_vertices: 2145
+
+sub_ico_1:
+    num_patches: 80
+    num_vertices: 561
+
+sub_ico_2:
+    num_patches: 320
+    num_vertices: 153
\ No newline at end of file
diff --git a/config/SiT/training/hparams.yml b/config/SiT/training/hparams.yml
index 8f11633..96c8c77 100644
--- a/config/SiT/training/hparams.yml
+++ b/config/SiT/training/hparams.yml
@@ -17,7 +17,7 @@ training:
     LR: 0.00001
     bs: 256
     bs_val: 1
-    epochs: 30
+    epochs: 100
     gpu: 0
     l1loss: False
     testing: False
@@ -78,7 +78,4 @@ sub_ico_1:
 
 sub_ico_2:
     num_patches: 320
-    num_vertices: 153
-
-
-
+    num_vertices: 153
\ No newline at end of file
diff --git a/models/mpp.py b/models/mpp.py
new file mode 100644
index 0000000..d273308
--- /dev/null
+++ b/models/mpp.py
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+# @Author: Your name
+# @Date:   1970-01-01 01:00:00
+# @Last Modified by:   Your name
+# @Last Modified time: 2022-02-14 17:50:22
+#
+# Created on Mon Oct 18 2021
+#
+# by Simon Dahan @SD3004
+#
+# Copyright (c) 2021 MeTrICS Lab
+#
+
+
+import math
+from random import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+
+def get_mask_from_prob(inputs, prob):
+    '''
+    Creates a per-sample boolean mask over the sequence of tokens,
+    based on the masking probability.
+    return: a boolean mask with shape (batch, seq_len).
+    '''
+    batch, seq_len, _, device = *inputs.shape, inputs.device
+    max_masked = math.ceil(prob * seq_len)
+
+    rand = torch.rand((batch, seq_len), device=device)
+    _, sampled_indices = rand.topk(max_masked, dim=-1)
+
+    new_mask = torch.zeros((batch, seq_len), device=device)
+    new_mask.scatter_(1, sampled_indices, 1)
+    return new_mask.bool()
+
+def prob_mask_like(inputs, prob):
+    batch, seq_length, _ = inputs.shape
+    return torch.zeros((batch, seq_length)).float().uniform_(0, 1) < prob
+
+
+class masked_patch_pretraining(nn.Module):
+
+    def __init__(
+        self,
+        transformer,
+        dim_in,
+        dim_out,
+        device,
+        mask_prob=0.15,
+        replace_prob=0.5,
+        swap_prob=0.3,
+        channels=4,
+        num_vertices=561,):
+
+        super().__init__()
+        self.transformer = transformer
+
+        self.dim_out = dim_out
+        self.dim_in = dim_in
+
+        self.to_original = nn.Linear(dim_in, dim_out)
+        self.to_original.to(device)
+
+        self.mask_prob = mask_prob
+        self.replace_prob = replace_prob
+        self.swap_prob = swap_prob
+
+        # learnable mask token, in flattened-patch space (channels * num_vertices)
+        self.mask_token = nn.Parameter(torch.randn(1, 1, channels * num_vertices))
+
+
+    def forward(self, batch, **kwargs):
+
+        transformer = self.transformer
+
+        # flatten patches into tokens; `batch` is kept unchanged as the reconstruction target
+        batch = rearrange(batch,
+                          'b c n v -> b n (v c)')
+
+        corrupted_sequence = get_mask_from_prob(batch, self.mask_prob)
+
+        corrupted_batch = batch.clone().detach()
+
+        # randomly swap patches in the sequence
+        if self.swap_prob > 0:
+            random_patch_sampling_prob = self.swap_prob / (
+                1 - self.replace_prob)
+
+            random_patch_prob = prob_mask_like(batch,
+                                               random_patch_sampling_prob).to(corrupted_sequence.device)
+
+            bool_random_patch_prob = corrupted_sequence * (random_patch_prob == True)
+
+            random_patches = torch.randint(0,
+                                           batch.shape[1],
+                                           (batch.shape[0], batch.shape[1]),
+                                           device=batch.device)
+            # re-index the batch with random patch positions (per sample)
+            randomized_input = corrupted_batch[
+                torch.arange(corrupted_batch.shape[0]).unsqueeze(-1),
+                random_patches]
+            corrupted_batch[bool_random_patch_prob] = randomized_input[bool_random_patch_prob]
+
+        tokens_to_mask = prob_mask_like(batch, self.replace_prob).to(corrupted_sequence.device)
+
+        bool_mask_replace = (corrupted_sequence * tokens_to_mask) == True
+        corrupted_batch[bool_mask_replace] = self.mask_token.to(corrupted_sequence.device)
+
+        # linear embedding of patches
+        corrupted_batch = transformer.to_patch_embedding[-1](corrupted_batch)
+        emb_masked_sequence = corrupted_batch.clone().detach()
+
+        # add cls token to input sequence
+        b, n, _ = corrupted_batch.shape
+        cls_tokens = repeat(transformer.cls_token, '() n d -> b n d', b=b)
+        corrupted_batch = torch.cat((cls_tokens, corrupted_batch), dim=1)
+
+        # add positional embeddings to input
+        corrupted_batch += transformer.pos_embedding[:, :(n + 1)]
+        corrupted_batch = transformer.dropout(corrupted_batch)
+
+        # encode the corrupted sequence and project tokens back to patch space
+        batch_out = transformer.transformer(corrupted_batch, **kwargs)
+        batch_out = self.to_original(batch_out[:, 1:, :])
+
+        # reconstruction loss computed on the corrupted patches only
+        mpp_loss = F.mse_loss(batch_out[corrupted_sequence], batch[corrupted_sequence])
+
+        return mpp_loss, batch_out
+
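For reference, a minimal usage sketch of the masked_patch_pretraining wrapper above (not part of the patch): it assumes the SiT constructor and attributes used in tools/pretrain.py below (to_patch_embedding, cls_token, pos_embedding, dropout, transformer) and the sub_ico_2 patching from config/SiT/pretraining/mpp.yml (320 patches of 153 vertices, 4 channels). The wrapper returns the MSE loss over the corrupted patches together with the per-patch reconstructions.

import torch
from models.sit import SiT
from models.mpp import masked_patch_pretraining

device = torch.device('cpu')

# SiT-tiny encoder, as configured in config/SiT/pretraining/mpp.yml
encoder = SiT(dim=192, depth=12, heads=3, mlp_dim=768, pool='cls',
              num_patches=320, num_classes=1, num_channels=4,
              num_vertices=153, dim_head=64, dropout=0.0, emb_dropout=0.0)

# wrap the encoder with the MPP objective (corruption scheme + linear reconstruction head)
ssl = masked_patch_pretraining(transformer=encoder,
                               dim_in=192, dim_out=153 * 4, device=device,
                               mask_prob=0.75, replace_prob=0.8, swap_prob=0.02,
                               channels=4, num_vertices=153)

# dummy batch of patched cortical surfaces: (batch, channels, num_patches, num_vertices)
inputs = torch.randn(2, 4, 320, 153)
mpp_loss, reconstruction = ssl(inputs)   # scalar loss, and (2, 320, 153*4) reconstructions
mpp_loss.backward()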
diff --git a/tools/pretrain.py b/tools/pretrain.py
new file mode 100644
index 0000000..0b363bc
--- /dev/null
+++ b/tools/pretrain.py
@@ -0,0 +1,437 @@
+# -*- coding: utf-8 -*-
+# @Author: Simon Dahan
+#
+# Created on Fri Oct 01 2021
+#
+# by Simon Dahan @SD3004
+#
+# Copyright (c) 2021 MeTrICS Lab
+#
+
+
+'''
+This file implements the self-supervised pre-training procedure for a SiT model,
+currently with masked patch prediction (MPP):
+    - a fraction of the input patches is corrupted (masked or randomly swapped)
+    - the model is trained to reconstruct the original patches (MSE on corrupted patches)
+The encoder can be initialised:
+    - from scratch
+    - from a previous SSL checkpoint or from ImageNet weights
+
+Pretrained ImageNet models are downloaded from the Timm library.
+'''
+
+import os
+import argparse
+import yaml
+import sys
+import timm
+from datetime import datetime
+
+
+sys.path.append('../')
+sys.path.append('./')
+sys.path.append('../../')
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import numpy as np
+import pandas as pd
+
+from torch.optim.lr_scheduler import StepLR
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+
+from models.sit import SiT
+from models.mpp import masked_patch_pretraining
+
+
+from warmup_scheduler import GradualWarmupScheduler
+
+from utils.utils import load_weights_imagenet
+
+from torch.utils.tensorboard import SummaryWriter
+
+
+def train(config):
+
+    gpu = config['training']['gpu']
+    LR = config['training']['LR']
+    use_l1loss = config['training']['l1loss']
+    epochs = config['training']['epochs']
+    val_epoch = config['training']['val_epoch']
+    testing = config['training']['testing']
+    bs = config['training']['bs']
+    bs_val = config['training']['bs_val']
+    configuration = config['data']['configuration']
+    task = config['data']['task']
+
+    ico = config['resolution']['ico']
+    sub_ico = config['resolution']['sub_ico']
+
+    data_path = config['data']['data_path'].format(task,configuration)
+
+    folder_to_save_model = config['logging']['folder_to_save_model']
+
+    num_patches = config['sub_ico_{}'.format(sub_ico)]['num_patches']
+    num_vertices = config['sub_ico_{}'.format(sub_ico)]['num_vertices']
+
+    device = torch.device("cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu")
+
+    print('')
+    print('#'*30)
+    print('##### Config #####')
+    print('#'*30)
+    print('')
+
+    print(device)
+    print(data_path)
+
+    ##############################
+    ######     DATASET      ######
+    ##############################
+
+    print('')
+    print('#'*30)
+    print('##### Loading data #####')
+    print('#'*30)
+    print('')
+
+    print('LOADING DATA: ICO {} - sub-res ICO {}'.format(ico,sub_ico))
+
+    # loading already processed and patched cortical surfaces
+
+    train_data = np.load(os.path.join(data_path,'train_data.npy'))
+    train_label = np.load(os.path.join(data_path,'train_labels.npy'))
+
+    print('training data: {}'.format(train_data.shape))
+
+    val_data = np.load(os.path.join(data_path,'validation_data.npy'))
+    val_label = np.load(os.path.join(data_path,'validation_labels.npy'))
+
+    print('validation data: {}'.format(val_data.shape))
+
+    train_data_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data).float(),
+                                                        torch.from_numpy(train_label).float())
+
+    train_loader = torch.utils.data.DataLoader(train_data_dataset,
+                                               batch_size = bs,
+                                               shuffle=True,
+                                               num_workers=16)
+
+    val_data_dataset = torch.utils.data.TensorDataset(torch.from_numpy(val_data).float(),
+                                                      torch.from_numpy(val_label).float())
+
+
+    val_loader = torch.utils.data.DataLoader(val_data_dataset,
+                                             batch_size = bs_val,
+                                             shuffle=False,
+                                             num_workers=16)
+
+
+    if testing:
+        test_data = np.load(os.path.join(data_path,'test_data.npy'))
+        test_label = np.load(os.path.join(data_path,'test_labels.npy')).reshape(-1)
+
+        print('testing data: {}'.format(test_data.shape))
+        print('')
+
+        test_data_dataset = torch.utils.data.TensorDataset(torch.from_numpy(test_data).float(),
+                                                           torch.from_numpy(test_label).float())
+
+        test_loader = torch.utils.data.DataLoader(test_data_dataset,
+                                                  batch_size = bs_val,
+                                                  shuffle=False,
+                                                  num_workers=16)
+
+
+    ##############################
+    ######     LOGGING      ######
+    ##############################
+
+    # creating folders for logging.
+    try:
+        os.mkdir(folder_to_save_model)
+        print('Creating folder: {}'.format(folder_to_save_model))
+    except OSError:
+        print('folder already exists: {}'.format(folder_to_save_model))
+
+    date = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
+
+    # run folder named by date/time plus model and training tags
+    folder_to_save_model = os.path.join(folder_to_save_model,date)
+    print(folder_to_save_model)
+    if config['transformer']['dim'] == 192:
+        folder_to_save_model = folder_to_save_model + '-tiny'
+    elif config['transformer']['dim'] == 384:
+        folder_to_save_model = folder_to_save_model + '-small'
+    elif config['transformer']['dim'] == 768:
+        folder_to_save_model = folder_to_save_model + '-base'
+
+    if config['training']['load_weights_imagenet']:
+        folder_to_save_model = folder_to_save_model + '-imgnet'
+    if config['training']['load_weights_ssl']:
+        folder_to_save_model = folder_to_save_model + '-ssl'
+        if config['training']['dataset_ssl']=='hcp':
+            folder_to_save_model = folder_to_save_model + '-hcp'
+        elif config['training']['dataset_ssl']=='dhcp-hcp':
+            folder_to_save_model = folder_to_save_model + '-dhcp-hcp'
+        elif config['training']['dataset_ssl']=='dhcp':
+            folder_to_save_model = folder_to_save_model + '-dhcp'
+    if config['training']['finetuning']:
+        folder_to_save_model = folder_to_save_model + '-finetune'
+    else:
+        folder_to_save_model = folder_to_save_model + '-freeze'
+
+    try:
+        os.mkdir(folder_to_save_model)
+        print('Creating folder: {}'.format(folder_to_save_model))
+    except OSError:
+        print('folder already exists: {}'.format(folder_to_save_model))
+
+    writer = SummaryWriter(log_dir=folder_to_save_model)
+
+
+    ##############################
+    #######     MODEL      #######
+    ##############################
+
+    print('')
+    print('#'*30)
+    print('##### Init model #####')
+    print('#'*30)
+    print('')
+
+    if config['transformer']['model'] == 'SiT':
+
+        model = SiT(dim=config['transformer']['dim'],
+                    depth=config['transformer']['depth'],
+                    heads=config['transformer']['heads'],
+                    mlp_dim=config['transformer']['mlp_dim'],
+                    pool=config['transformer']['pool'],
+                    num_patches=num_patches,
+                    num_classes=config['transformer']['num_classes'],
+                    num_channels=config['transformer']['num_channels'],
+                    num_vertices=num_vertices,
+                    dim_head=config['transformer']['dim_head'],
+                    dropout=config['transformer']['dropout'],
+                    emb_dropout=config['transformer']['emb_dropout'])
+
+    if config['training']['load_weights_ssl']:
+
+        print('Loading weights from self-supervision training')
+        model.load_state_dict(torch.load(config['weights']['ssl_mpp'],map_location=device),strict=False)
+
+    if config['training']['load_weights_imagenet']:
+
+        print('Loading weights from imagenet pretraining')
+        model_trained = timm.create_model(config['weights']['imagenet'], pretrained=True)
+        new_state_dict = load_weights_imagenet(model.state_dict(),model_trained.state_dict(),config['transformer']['depth'])
+        model.load_state_dict(new_state_dict)
+
+
+    model.to(device)
+
+    print('Number of parameters encoder: {:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
+    print('')
+
+    ##################################################
+    #######     SELF-SUPERVISION PIPELINE      #######
+    ##################################################
+
+    if config['SSL'] == 'mpp':
+
+        print('Pretrain using Masked Patch Prediction')
+        ssl = masked_patch_pretraining(transformer=model,
+                                       dim_in=config['transformer']['dim'],
+                                       dim_out=num_vertices*config['transformer']['num_channels'],
+                                       device=device,
+                                       mask_prob=config['pretraining_mpp']['mask_prob'],
+                                       replace_prob=config['pretraining_mpp']['replace_prob'],
+                                       swap_prob=config['pretraining_mpp']['swap_prob'],
+                                       num_vertices=num_vertices,
+                                       channels=config['transformer']['num_channels'])
+    else:
+        raise NotImplementedError('not implemented yet')
+
+    ssl.to(device)
+
+    print('Number of parameters pretraining pipeline: {:,}'.format(sum(p.numel() for p in ssl.parameters() if p.requires_grad)))
+    print('')
+
+    #####################################
+    #######     OPTIMISATION      #######
+    #####################################
+
+    if config['optimisation']['optimiser']=='Adam':
+        print('using Adam optimiser')
+        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=config['Adam']['weight_decay'])
+    elif config['optimisation']['optimiser']=='SGD':
+        print('using SGD optimiser')
+        optimizer = optim.SGD(model.parameters(), lr=LR,
+                              weight_decay=config['SGD']['weight_decay'],
+                              momentum=config['SGD']['momentum'],
+                              nesterov=config['SGD']['nesterov'])
+    elif config['optimisation']['optimiser']=='AdamW':
+        print('using AdamW optimiser')
+        optimizer = optim.AdamW(model.parameters(),
+                                lr=LR,
+                                weight_decay=config['AdamW']['weight_decay'])
+    else:
+        raise NotImplementedError('not implemented yet')
+
+    ###################################
+    #######     SCHEDULING      #######
+    ###################################
+
+    it_per_epoch = np.ceil(len(train_loader))
+
+    ##################################
+    ######     PRE-TRAINING     ######
+    ##################################
+
+    print('')
+    print('#'*30)
+    print('### Starting pre-training ###')
+    print('#'*30)
+    print('')
+
+    best_val_loss = 100000000000
+    c_early_stop = 0
+
+    for epoch in range(epochs):
+
+        ssl.train()
+
+        running_loss = 0
+
+        for i, data in enumerate(train_loader):
+
+            inputs, _ = data[0].to(device), data[1].to(device)
+
+            optimizer.zero_grad()
+
+            if config['SSL'] == 'mpp':
+                mpp_loss, _ = ssl(inputs)
+
+            mpp_loss.backward()
+            optimizer.step()
+
+            running_loss += mpp_loss.item()
+
+            writer.add_scalar('loss/train_it', mpp_loss.item(), epoch*it_per_epoch + i + 1)
+
+        ##############################
+        #########   LOG IT   #########
+        ##############################
+
+        if (epoch+1)%5==0:
+
+            print('| Epoch - {} | It - {} | Loss - {:.4f} | LR - {}'.format(epoch+1, epoch*it_per_epoch + i + 1, running_loss / (i+1), optimizer.param_groups[0]['lr']))
+
+        loss_pretrain_epoch = running_loss / (i+1)
+
+        writer.add_scalar('loss/train', loss_pretrain_epoch, epoch+1)
+
+
+        ##############################
+        ######    VALIDATION    ######
+        ##############################
+
+        if (epoch+1)%val_epoch==0:
+
+            running_val_loss = 0
+            ssl.eval()
+
+            with torch.no_grad():
+
+                for i, data in enumerate(val_loader):
+
+                    inputs, _ = data[0].to(device), data[1].to(device)
+
+                    if config['SSL'] == 'mpp':
+                        mpp_loss, _ = ssl(inputs)
+
+                    running_val_loss += mpp_loss.item()
+
+            loss_pretrain_val_epoch = running_val_loss / (i+1)
+
+            writer.add_scalar('loss/val', loss_pretrain_val_epoch, epoch+1)
+
+            print('| Validation | Epoch - {} | Loss - {} | '.format(epoch+1, loss_pretrain_val_epoch))
+
+            if loss_pretrain_val_epoch < best_val_loss:
+                best_val_loss = loss_pretrain_val_epoch
+                best_epoch = epoch+1
+                c_early_stop = 0
+
+                config['results'] = {}
+                config['results']['best_epoch'] = best_epoch
+                config['results']['best_current_loss'] = loss_pretrain_epoch
+                config['results']['best_current_loss_validation'] = best_val_loss
+
+                with open(os.path.join(folder_to_save_model,'hparams.yml'), 'w') as yaml_file:
+                    yaml.dump(config, yaml_file)
+
+                print('saving model')
+                torch.save({ 'epoch':epoch+1,
+                             'model_state_dict': model.state_dict(),
+                             'optimizer_state_dict': optimizer.state_dict(),
+                             'loss':loss_pretrain_epoch,
+                             },
+                             os.path.join(folder_to_save_model, 'encoder-best.pt'))
+                torch.save({ 'epoch':epoch+1,
+                             'model_state_dict': ssl.state_dict(),
+                             'optimizer_state_dict': optimizer.state_dict(),
+                             'loss':loss_pretrain_epoch,
+                             },
+                             os.path.join(folder_to_save_model, 'encoder-decoder-best.pt'))
+
+    print('')
+    print('Final results: best model obtained at epoch {} - loss {}'.format(best_epoch,best_val_loss))
+
+    config['logging']['folder_model_saved'] = folder_to_save_model
+    config['results']['final_loss'] = loss_pretrain_epoch
+    config['results']['training_finished'] = True
+
+    with open(os.path.join(folder_to_save_model,'hparams.yml'), 'w') as yaml_file:
+        yaml.dump(config, yaml_file)
+
+
+    #####################################
+    ######    SAVING FINAL CKPT    ######
+    #####################################
+
+    torch.save({'epoch':epoch+1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss':loss_pretrain_epoch,
+                },
+                os.path.join(folder_to_save_model,'encoder-final.pt'))
+
+    torch.save({'epoch':epoch+1,
+                'model_state_dict': ssl.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss':loss_pretrain_epoch,
+                },
+                os.path.join(folder_to_save_model,'encoder-decoder-final.pt'))
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(description='ViT')
+
+    parser.add_argument(
+                        'config',
+                        type=str,
+                        default='./config/hparams.yml',
+                        help='path to the configuration file (.yml)')
+
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = yaml.safe_load(f)
+
+    # Call training
+    train(config)
diff --git a/tools/train.py b/tools/train.py
index a5376cd..b1c820a 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -39,14 +39,8 @@ import numpy as np
 import pandas as pd
 
-from torch.optim.lr_scheduler import StepLR
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-
-
 from models.sit import SiT
 
-from warmup_scheduler import GradualWarmupScheduler
-
 from utils.utils import load_weights_imagenet
 
 from torch.utils.tensorboard import SummaryWriter
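Usage note: with the repository laid out as above, pre-training would be launched with something like python tools/pretrain.py ../config/SiT/pretraining/mpp.yml (the config path is a positional argument); checkpoints (encoder-best.pt, encoder-decoder-best.pt, encoder-final.pt, encoder-decoder-final.pt) and a copy of the resolved config are written to a dated run folder under logging.folder_to_save_model. Two details are worth flagging for the fine-tuning hand-off: (1) the optimiser in tools/pretrain.py is built over model.parameters(), so the reconstruction head (to_original) and mask token added by the MPP wrapper keep their initial values during pre-training; (2) the checkpoints are dictionaries with 'epoch', 'model_state_dict', 'optimizer_state_dict' and 'loss' keys, while the load_weights_ssl branch calls model.load_state_dict(torch.load(path), strict=False) on the file directly, which would presumably match no parameter names and load nothing. A sketch of the intended loading, under those assumptions (the run-folder name below is a placeholder):

import torch
from models.sit import SiT

# placeholder path: replace with the actual run folder created by tools/pretrain.py
ckpt_path = '../logs/SiT/<date>-tiny-finetune/encoder-best.pt'

model = SiT(dim=192, depth=12, heads=3, mlp_dim=768, pool='cls',
            num_patches=320, num_classes=1, num_channels=4,
            num_vertices=153, dim_head=64, dropout=0.0, emb_dropout=0.0)

checkpoint = torch.load(ckpt_path, map_location='cpu')               # dict saved by pretrain.py
model.load_state_dict(checkpoint['model_state_dict'], strict=False)  # load the encoder weights only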