diff --git a/EVA_ViT/Readme.md b/EVA_ViT/Readme.md
index 248684d..a129914 100644
--- a/EVA_ViT/Readme.md
+++ b/EVA_ViT/Readme.md
@@ -5,16 +5,28 @@
 conda activate swift
 
 ### Train ViT from scratch
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16
 ```
 
-### Finetuning Pretrained ViT
+### Finetuning Pretrained ViT (w/ freeze encoder)
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --use_pretrained_weight
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --freeze_encoder
 ```
 
-### Finetuning Pretrained ViT with additional projector layers
+### Finetuning Pretrained ViT (w/o freeze encoder)
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --use_pretrained_weight --use_projector
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight
 ```
+
+
+### Finetuning Pretrained ViT (w/o freeze encoder) with additional projector layers
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --use_projector
+```
+
+### Finetuning Pretrained ViT (w/ freeze encoder) with additional projector layers
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --use_projector --freeze_encoder
+```
+
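The four fine-tuning recipes above differ only in two boolean switches, `--freeze_encoder` and `--use_projector`, layered on top of `--use_pretrained_weight`. A minimal sketch of how those flags combine (not the repository's parser; the real definitions are in the `main_EVA_ViT.py` hunks further down):

```python
# Minimal argparse sketch mirroring the new flags; defaults match the patch.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_pretrained_weight', action='store_true')  # load pretrained backbone
parser.add_argument('--freeze_encoder', action='store_true')         # freeze attention blocks / cls tokens
parser.add_argument('--use_projector', action='store_true')          # append two projector Blocks

# the "w/ freeze encoder" recipe from the README:
args = parser.parse_args(['--use_pretrained_weight', '--freeze_encoder'])
print(args.use_pretrained_weight, args.freeze_encoder, args.use_projector)  # True True False
```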
diff --git a/EVA_ViT/dataloaders/dataloaders.py b/EVA_ViT/dataloaders/dataloaders.py
index c8f664c..11a0b8b 100644
--- a/EVA_ViT/dataloaders/dataloaders.py
+++ b/EVA_ViT/dataloaders/dataloaders.py
@@ -203,9 +203,9 @@ def partition_dataset_finetuning(imageFiles_labels, args):
         images.append(image)
         labels.append(label)
 
-    ratio = 0.3
-    patch_size = (8, 8, 8)
-    num_patches = (args.img_size[0] // patch_size[0]) + (args.img_size[1] // patch_size[1]) + (args.img_size[2] // patch_size[2])
+    patch_size = args.patch_size
+    num_patches = (args.img_size[0] // patch_size) + (args.img_size[1] // patch_size) + (args.img_size[2] // patch_size)
+    #num_patches = (args.img_size[0] // patch_size[0]) + (args.img_size[1] // patch_size[1]) + (args.img_size[2] // patch_size[2])
 
     train_transform = Compose([AddChannel(),
                                Resize(tuple(args.img_size)),
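This hunk replaces the hard-coded `(8, 8, 8)` patch size with the new `--patch_size` argument. Note that the formula (carried over from the old code) sums the per-axis patch counts; if `num_patches` is meant to be the total number of non-overlapping 3D patches, the usual value is their product. A quick sanity check with the README settings (a standalone sketch, not repository code):

```python
# Patch-grid arithmetic for --img_size 96 96 96 --patch_size 16.
img_size = (96, 96, 96)
patch_size = 16

per_axis = [s // patch_size for s in img_size]     # [6, 6, 6]
sum_of_axes = sum(per_axis)                        # 18  <- what the hunk computes
product = per_axis[0] * per_axis[1] * per_axis[2]  # 216 <- patches in the 3D grid

print(per_axis, sum_of_axes, product)
```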
diff --git a/EVA_ViT/envs/finetuning_experiments.py b/EVA_ViT/envs/finetuning_experiments.py
index 4896664..a22b69d 100644
--- a/EVA_ViT/envs/finetuning_experiments.py
+++ b/EVA_ViT/envs/finetuning_experiments.py
@@ -185,13 +185,16 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
                 if '3d' in name:
                     param.requires_grad = True
         else:
-            for name, param in net.named_parameters():
-                if 'blocks' in name:
-                    param.requires_grad = False
-                elif 'cls_' in name:
-                    param.requires_grad = False
+            if args.freeze_encoder:
+                for name, param in net.named_parameters():
+                    if 'blocks' in name:
+                        param.requires_grad = False
+                    elif 'cls_' in name:
+                        param.requires_grad = False
-
+    # check which modules are learnable
+    for name, param in net.named_parameters():
+        print(f"{name} requires_grad is set to {param.requires_grad}")
 
 
     # setting optimizer
 
@@ -255,12 +258,16 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
     val_losses = []
 
     previous_performance = {}
-    previous_performance['ACC'] = [0.0]
-    previous_performance['abs_loss'] = [100000.0]
-    previous_performance['mse_loss'] = [100000.0]
-    if num_classes == 2:
-
+    previous_performance = {}
+    if num_classes >= 2:
+        previous_performance['ACC'] = [0.0]
         previous_performance['AUROC'] = [0.0]
+    elif num_classes == 1:
+        previous_performance['abs_loss'] = [100000.0]
+        previous_performance['mse_loss'] = [100000.0]
+        previous_performance['r_square'] = [-10000.0]
+
+    best_checkpoint_dir = None
 
     # training
     for epoch in tqdm(range(last_epoch, last_epoch + args.epoch)):
@@ -278,26 +285,36 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
         # store result per epoch
         train_losses.append(train_loss)
         val_losses.append(val_loss)
+        print('Epoch {}. Train Loss: {:2.2f}. Validation Loss: {:2.2f}. \n Training Prediction Performance: {}. \n Validation Prediction Performance: {}. \n Current learning rate {}. Took {:2.2f} sec'.format(epoch+1, train_loss, val_loss, train_performance, val_performance, optimizer.param_groups[0]['lr'],te-ts))
 
         if 'ACC' or 'AUROC' in val_performance.keys():
             if args.metric == 'ACC':
                 previous_performance['ACC'].append(val_performance['ACC'])
                 if val_performance['ACC'] > max(previous_performance['ACC'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
             elif args.metric == 'AUROC':
                 previous_performance['AUROC'].append(val_performance['AUROC'])
                 if val_performance['AUROC'] > max(previous_performance['AUROC'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
 
         if 'abs_loss' or 'mse_loss' in val_performance.keys():
             if args.metric == 'abs_loss':
                 previous_performance['abs_loss'].append(val_performance['abs_loss'])
                 if val_performance['abs_loss'] < min(previous_performance['abs_loss'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
             elif args.metric == 'mse_loss':
                 previous_performance['mse_loss'].append(val_performance['mse_loss'])
                 if val_performance['mse_loss'] < min(previous_performance['mse_loss'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
+            elif args.metric == 'r_square':
+                previous_performance['r_square'].append(val_performance['r_square'])
+                if val_performance['r_square'] > max(previous_performance['r_square'][:-1]):
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance,mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
 
         torch.cuda.empty_cache()
 
@@ -308,7 +325,8 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
     result['train_losses'] = train_losses
     result['validation_losses'] = val_losses
 
-    return vars(args), result, checkpoint_dir
+    #return vars(args), result, checkpoint_dir
+    return vars(args), result, best_checkpoint_dir
 
 
 ## ==================================== ##
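These hunks write each improved checkpoint to its own `*_epochN.pt` file and remember the winner in `best_checkpoint_dir`, which `train_experiment` now returns instead of the generic `checkpoint_dir`. Two things worth double-checking: the `copy.copy(...)` calls rely on `copy` being imported, and the surrounding `if 'ACC' or 'AUROC' in val_performance.keys():` context lines are always truthy in Python, so the inner `args.metric` checks are what actually gate the branches. The four near-identical branches could also share one helper; a minimal sketch with a hypothetical `is_better`, not the repository's code:

```python
# Hypothetical consolidation of the best-metric bookkeeping above.
HIGHER_IS_BETTER = {'ACC', 'AUROC', 'r_square'}  # abs_loss / mse_loss: lower is better

def is_better(metric: str, value: float, history: list) -> bool:
    """True if `value` beats every previous validation score for `metric`."""
    if metric in HIGHER_IS_BETTER:
        return value > max(history)
    return value < min(history)

# Inside the epoch loop (checkpoint_save as in finetuning_experiments.py):
# history = previous_performance[args.metric]
# if is_better(args.metric, val_performance[args.metric], history):
#     epoch_path = checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt")
#     checkpoint_save(net, optimizer, epoch_path, epoch, scheduler, scaler,
#                     args, val_performance, mode='finetune')
#     best_checkpoint_dir = epoch_path
# history.append(val_performance[args.metric])
```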
diff --git a/EVA_ViT/main_EVA_ViT.py b/EVA_ViT/main_EVA_ViT.py
index 01be65a..6f26779 100644
--- a/EVA_ViT/main_EVA_ViT.py
+++ b/EVA_ViT/main_EVA_ViT.py
@@ -71,6 +71,8 @@
 parser.set_defaults(use_pretrained_weight=False)
 parser.add_argument('--use_projector', action='store_true', help='Using adapter layers after the backbone network')
 parser.set_defaults(use_projector=False)
+parser.add_argument('--freeze_encoder', action='store_true', help='Freeze the encoder including attention blocks and cls tokens')
+parser.set_defaults(freeze_encoder=False)
 
 
 #########################
@@ -98,6 +100,7 @@
 parser.add_argument("--attention_drop",default=0.5,type=float,required=False,help='dropout rate of encoder attention layer')
 parser.add_argument("--projection_drop",default=0.5,type=float,required=False,help='dropout rate of encoder projection layer')
 parser.add_argument("--path_drop",default=0.1,type=float,required=False,help='dropout rate of encoder attention block')
+parser.add_argument("--patch_size",default=8,type=int,required=False,help='size of patchifying layer. Isotropic.')
 
 #parser.add_argument("--mask_ratio",required=False,default=0.75,type=float,help='the ratio of random masking')
 #parser.add_argument("--norm_pix_loss",action='store_true',help='Use (per-patch) normalized pixels as targets for computing loss')
diff --git a/EVA_ViT/model/model_EvaViT.py b/EVA_ViT/model/model_EvaViT.py
index ef499a5..822fa7c 100644
--- a/EVA_ViT/model/model_EvaViT.py
+++ b/EVA_ViT/model/model_EvaViT.py
@@ -322,12 +322,13 @@ def __init__(self,
                 init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=use_lora)
             for i in range(depth)])
         self.use_projector = use_projector
-        self.projectors = nn.ModuleList([
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
-                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=False)
-            for i in range(2)])
+        if self.use_projector:
+            self.projectors = nn.ModuleList([
+                Block(
+                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                    init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=False)
+                for i in range(2)])
 
         if self.pos_embed_3d is not None:
             trunc_normal_(self.pos_embed_3d, std=.02)
@@ -488,4 +489,4 @@ def create_eva_vit_g(img_size=256,patch_size=16,drop_path_rate=0.4,use_checkpoin
 
 
 ## set recommended archs
-evavit_giant_patch16_3D = create_eva_vit_g
\ No newline at end of file
+evavit_giant_patch16_3D = create_eva_vit_g
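Guarding the projector construction behind `if self.use_projector:` means the two extra `Block`s are only instantiated, and only contribute parameters and `state_dict` keys, when `--use_projector` is set. The pattern in miniature (a sketch with toy modules, not the EVA-ViT classes):

```python
# Conditional submodule creation, as in the model_EvaViT.py hunk above.
import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self, dim: int = 32, use_projector: bool = False):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)
        self.use_projector = use_projector
        if self.use_projector:
            # built only on demand: no unused parameters or stray checkpoint keys
            self.projectors = nn.ModuleList([nn.Linear(dim, dim) for _ in range(2)])

    def forward(self, x):
        x = self.encoder(x)
        if self.use_projector:
            for proj in self.projectors:
                x = proj(x)
        return x

print(len(ToyNet().state_dict()), len(ToyNet(use_projector=True).state_dict()))  # 2 6
```

One side effect worth noting: checkpoints saved with the flag on cannot be strictly loaded into a model built with it off, since the `projectors.*` keys disappear.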
diff --git a/EVA_ViT/util/loss_functions.py b/EVA_ViT/util/loss_functions.py
index 1f54d65..5f6e06a 100644
--- a/EVA_ViT/util/loss_functions.py
+++ b/EVA_ViT/util/loss_functions.py
@@ -86,17 +86,22 @@
                 total, correct = total.sum().detach().cpu(), correct.sum().detach().cpu()
                 result['total'], result['correct'] = total.item(), correct.item()
                 result['ACC'] = 100 * correct.item() / total.item()
+                true = true.long()
+                pred = torch.softmax(pred, dim=1)
                 if self.num_classes == 2:
-                    true, pred = torch.cat(true), torch.cat(pred)
-                    true = true.long()
                     result['AUROC'] = roc_auc_score(true.detach().cpu(), pred[:, 1].detach().cpu())
+                else:
+                    result['AUROC'] = roc_auc_score(true.detach().cpu(), pred.detach().cpu(), multi_class='ovr', average='macro')
             else:
                 result['total'] = self.total.sum().item()
                 result['correct'] = self.correct.sum().item()
                 result['ACC'] = 100 * self.correct.sum().item() /self.total.sum().item()
+                self.true = self.true.long()
+                self.pred = torch.softmax(self.pred, dim=1)
                 if self.num_classes == 2:
-                    self.true = self.true.long()
                     result['AUROC'] = roc_auc_score(self.true.detach().cpu(), self.pred[:, 1].detach().cpu())
+                else:
+                    result['AUROC'] = roc_auc_score(self.true.detach().cpu(), self.pred.detach().cpu(), multi_class='ovr', average='macro')
 
         elif self.num_classes == 1:
             if self.is_DDP:
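This change extends AUROC reporting from binary to multi-class: logits are first pushed through a softmax, then handed to scikit-learn's one-vs-rest macro average. That is the right pairing, since `roc_auc_score(..., multi_class='ovr')` expects per-class probabilities whose rows sum to 1, which is exactly what the added softmax guarantees. One thing to double-check: the old binary branch concatenated `true`/`pred` with `torch.cat` before use, while the new code calls `.long()` and `softmax` directly, so it assumes the gathered predictions are already tensors at that point. A self-contained check of the new call (a hedged sketch; the real aggregation, including the DDP gather, lives in `loss_functions.py`):

```python
# Multi-class AUROC with softmax probabilities, mirroring the new else-branch.
import torch
from sklearn.metrics import roc_auc_score

logits = torch.randn(9, 3)                               # 9 samples, 3 classes
true = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]).long()  # every class represented
pred = torch.softmax(logits, dim=1)                      # rows sum to 1

auroc = roc_auc_score(true.cpu().numpy(), pred.cpu().numpy(),
                      multi_class='ovr', average='macro')
print(auroc)  # macro-averaged one-vs-rest AUROC in [0, 1]
```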