22 changes: 17 additions & 5 deletions EVA_ViT/Readme.md
@@ -5,16 +5,28 @@ conda activate swift
 
 ### Train ViT from scratch
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16
 ```
 
 
-### Finetuning Pretrained ViT
+### Finetuning Pretrained ViT (w/ freeze encoder)
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --use_pretrained_weight
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --freeze_encoder
 ```
 
-### Finetuning Pretrained ViT with additional projector layers
+### Finetuning Pretrained ViT (w/o freeze encoder)
 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --use_pretrained_weight --use_projector
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight
 ```
+
+
+### Finetuning Pretrained ViT (w/o freeze encoder) with additional projector layers
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --use_projector
+```
+
+### Finetuning Pretrained ViT (w/ freeze encoder) with additional projector layers
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=1 main_EVA_ViT.py --model evavit_giant_patch16_3D --optim AdamW --lr 1e-4 --epoch 200 --exp_name test --batch_size 16 --accumulation_steps 16 --study_sample GARD_T1 --cat_target sex --img_size 96 96 96 --patch_size 16 --use_pretrained_weight --use_projector --freeze_encoder
+```
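
All of the updated commands pass `--img_size 96 96 96` together with `--patch_size 16`. Since the patch grid is derived with integer division downstream, a quick sanity check for choosing compatible values might look like the following (a hypothetical standalone helper, not part of this repo):

```python
def patch_grid(img_size=(96, 96, 96), patch_size=16):
    """Return the per-axis patch counts for an isotropic patch size."""
    for dim in img_size:
        # Non-divisible sizes would silently drop trailing voxels under `//`.
        assert dim % patch_size == 0, f"{dim} is not divisible by {patch_size}"
    return tuple(dim // patch_size for dim in img_size)

print(patch_grid())  # (6, 6, 6) for the README settings
```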

6 changes: 3 additions & 3 deletions EVA_ViT/dataloaders/dataloaders.py
@@ -203,9 +203,9 @@ def partition_dataset_finetuning(imageFiles_labels, args):
        images.append(image)
        labels.append(label)
 
-    ratio = 0.3
-    patch_size = (8, 8, 8)
-    num_patches = (args.img_size[0] // patch_size[0]) + (args.img_size[1] // patch_size[1]) + (args.img_size[2] // patch_size[2])
+    patch_size = args.patch_size
+    num_patches = (args.img_size[0] // patch_size) + (args.img_size[1] // patch_size) + (args.img_size[2] // patch_size)
+    #num_patches = (args.img_size[0] // patch_size[0]) + (args.img_size[1] // patch_size[1]) + (args.img_size[2] // patch_size[2])
 
    train_transform = Compose([AddChannel(),
                               Resize(tuple(args.img_size)),
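For the README settings (`--img_size 96 96 96 --patch_size 16`), the new `num_patches` expression sums the per-axis patch counts. A minimal sketch of the arithmetic as the code above computes it:

```python
img_size = (96, 96, 96)  # --img_size 96 96 96
patch_size = 16          # --patch_size 16

# Sum of per-axis patch counts, mirroring partition_dataset_finetuning above.
num_patches = sum(dim // patch_size for dim in img_size)
print(num_patches)  # 6 + 6 + 6 = 18
```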
50 changes: 34 additions & 16 deletions EVA_ViT/envs/finetuning_experiments.py
@@ -185,13 +185,16 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
            if '3d' in name:
                param.requires_grad = True
    else:
-        for name, param in net.named_parameters():
-            if 'blocks' in name:
-                param.requires_grad = False
-            elif 'cls_' in name:
-                param.requires_grad = False
+        if args.freeze_encoder:
+            for name, param in net.named_parameters():
+                if 'blocks' in name:
+                    param.requires_grad = False
+                elif 'cls_' in name:
+                    param.requires_grad = False
 
 
    # check which modules are learnable
+    for name, param in net.named_parameters():
+        print(f"{name} requires_grad is set to {param.requires_grad}")
 
 
    # setting optimizer
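A minimal standalone sketch of what `--freeze_encoder` does above (illustrative helper names, not repo code): only parameters whose names contain `blocks` or `cls_` stop receiving gradients, so the patch embedding, head, and any projector layers stay trainable.

```python
import torch.nn as nn

def freeze_encoder(net: nn.Module) -> None:
    # Mirror of the loop above: freeze attention blocks and cls tokens.
    for name, param in net.named_parameters():
        if 'blocks' in name or 'cls_' in name:
            param.requires_grad = False

def trainable_params(net: nn.Module):
    # Useful when constructing the optimizer so frozen weights are skipped.
    return [p for p in net.parameters() if p.requires_grad]
```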
@@ -255,12 +258,16 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
    val_losses = []
 
-    previous_performance = {}
-    previous_performance['ACC'] = [0.0]
-    previous_performance['abs_loss'] = [100000.0]
-    previous_performance['mse_loss'] = [100000.0]
-    if num_classes == 2:
+    previous_performance = {}
+    if num_classes >= 2:
+        previous_performance['ACC'] = [0.0]
         previous_performance['AUROC'] = [0.0]
+    elif num_classes == 1:
+        previous_performance['abs_loss'] = [100000.0]
+        previous_performance['mse_loss'] = [100000.0]
+        previous_performance['r_square'] = [-10000.0]
+
+    best_checkpoint_dir = None
 
    # training
    for epoch in tqdm(range(last_epoch, last_epoch + args.epoch)):
@@ -278,26 +285,36 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
        # store result per epoch
        train_losses.append(train_loss)
        val_losses.append(val_loss)
+
        print('Epoch {}. Train Loss: {:2.2f}. Validation Loss: {:2.2f}. \n Training Prediction Performance: {}. \n Validation Prediction Performance: {}. \n Current learning rate {}. Took {:2.2f} sec'.format(epoch+1, train_loss, val_loss, train_performance, val_performance, optimizer.param_groups[0]['lr'], te-ts))
        if 'ACC' in val_performance or 'AUROC' in val_performance:
            if args.metric == 'ACC':
                previous_performance['ACC'].append(val_performance['ACC'])
                if val_performance['ACC'] > max(previous_performance['ACC'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
            elif args.metric == 'AUROC':
                previous_performance['AUROC'].append(val_performance['AUROC'])
                if val_performance['AUROC'] > max(previous_performance['AUROC'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
 
        if 'abs_loss' in val_performance or 'mse_loss' in val_performance:
            if args.metric == 'abs_loss':
                previous_performance['abs_loss'].append(val_performance['abs_loss'])
                if val_performance['abs_loss'] < min(previous_performance['abs_loss'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
            elif args.metric == 'mse_loss':
                previous_performance['mse_loss'].append(val_performance['mse_loss'])
                if val_performance['mse_loss'] < min(previous_performance['mse_loss'][:-1]):
-                    checkpoint_save(net, optimizer, checkpoint_dir, epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
+            elif args.metric == 'r_square':
+                previous_performance['r_square'].append(val_performance['r_square'])
+                if val_performance['r_square'] > max(previous_performance['r_square'][:-1]):
+                    checkpoint_save(net, optimizer, checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"), epoch, scheduler, scaler, args, val_performance, mode='finetune')
+                    best_checkpoint_dir = copy.copy(checkpoint_dir.replace(".pt", f"_epoch{epoch}.pt"))
 
 
        torch.cuda.empty_cache()
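The best-checkpoint bookkeeping above repeats one pattern per metric. A condensed sketch (hypothetical helper, not repo code) of the selection rule it implements, where the seed values 0.0, 100000.0, and -10000.0 guarantee the history is never empty:

```python
HIGHER_IS_BETTER = {'ACC', 'AUROC', 'r_square'}  # maximize
LOWER_IS_BETTER = {'abs_loss', 'mse_loss'}       # minimize

def is_new_best(metric: str, value: float, history: list) -> bool:
    """Append this epoch's score and report whether it beats all previous ones."""
    history.append(value)
    previous = history[:-1]
    if metric in HIGHER_IS_BETTER:
        return value > max(previous)
    return value < min(previous)
```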
@@ -308,7 +325,8 @@ def train_experiment(partition, num_classes, save_dir, args): #in_channels,out_d
    result['train_losses'] = train_losses
    result['validation_losses'] = val_losses
 
-    return vars(args), result, checkpoint_dir
+    #return vars(args), result, checkpoint_dir
+    return vars(args), result, best_checkpoint_dir


## ==================================== ##
3 changes: 3 additions & 0 deletions EVA_ViT/main_EVA_ViT.py
@@ -71,6 +71,8 @@
 parser.set_defaults(use_pretrained_weight=False)
 parser.add_argument('--use_projector', action='store_true', help='Using adapter layers after the backbone network')
 parser.set_defaults(use_projector=False)
+parser.add_argument('--freeze_encoder', action='store_true', help='Freeze the encoder including attention blocks and cls tokens')
+parser.set_defaults(freeze_encoder=False)
 
 
 #########################
@@ -98,6 +100,7 @@
 parser.add_argument("--attention_drop",default=0.5,type=float,required=False,help='dropout rate of encoder attention layer')
 parser.add_argument("--projection_drop",default=0.5,type=float,required=False,help='dropout rate of encoder projection layer')
 parser.add_argument("--path_drop",default=0.1,type=float,required=False,help='dropout rate of encoder attention block')
+parser.add_argument("--patch_size",default=8,type=int,required=False,help='size of patchifying layer. Isotropic.')
 
 #parser.add_argument("--mask_ratio",required=False,default=0.75,type=float,help='the ratio of random masking')
 #parser.add_argument("--norm_pix_loss",action='store_true',help='Use (per-patch) normalized pixels as targets for computing loss')
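Note that `--patch_size` defaults to 8, while the README commands pass 16 explicitly. A minimal standalone sketch of how the two new flags parse:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--freeze_encoder', action='store_true',
                    help='Freeze the encoder including attention blocks and cls tokens')
parser.add_argument('--patch_size', default=8, type=int,
                    help='size of patchifying layer. Isotropic.')

args = parser.parse_args(['--freeze_encoder', '--patch_size', '16'])
print(args.freeze_encoder, args.patch_size)  # True 16
```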
15 changes: 8 additions & 7 deletions EVA_ViT/model/model_EvaViT.py
@@ -322,12 +322,13 @@ def __init__(self,
            init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=use_lora)
            for i in range(depth)])
        self.use_projector = use_projector
-        self.projectors = nn.ModuleList([
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
-                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=False)
-            for i in range(2)])
+        if self.use_projector:
+            self.projectors = nn.ModuleList([
+                Block(
+                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                    init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, use_lora=False)
+                for i in range(2)])
 
        if self.pos_embed_3d is not None:
            trunc_normal_(self.pos_embed_3d, std=.02)
@@ -488,4 +489,4 @@ def create_eva_vit_g(img_size=256,patch_size=16,drop_path_rate=0.4,use_checkpoin
 
 
 ## set recommended archs
-evavit_giant_patch16_3D = create_eva_vit_g
\ No newline at end of file
+evavit_giant_patch16_3D = create_eva_vit_g
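
A minimal sketch of the guarded construction above: with this change, the projector blocks are only registered (and only consume parameters and memory) when `use_projector` is set. The stand-in module below is illustrative, not the repo's `Block`:

```python
import torch.nn as nn

class BackboneWithOptionalProjector(nn.Module):
    def __init__(self, embed_dim: int = 64, use_projector: bool = False):
        super().__init__()
        self.use_projector = use_projector
        if self.use_projector:
            # Stand-in for the two transformer Blocks used as projectors.
            self.projectors = nn.ModuleList(
                [nn.Linear(embed_dim, embed_dim) for _ in range(2)]
            )

print(sum(p.numel() for p in BackboneWithOptionalProjector(use_projector=False).parameters()))  # 0
print(sum(p.numel() for p in BackboneWithOptionalProjector(use_projector=True).parameters()))   # 8320
```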
11 changes: 8 additions & 3 deletions EVA_ViT/util/loss_functions.py
@@ -86,17 +86,22 @@ def get_result(self):
            total, correct = total.sum().detach().cpu(), correct.sum().detach().cpu()
            result['total'], result['correct'] = total.item(), correct.item()
            result['ACC'] = 100 * correct.item() / total.item()
-            true = true.long()
-            pred = torch.softmax(pred, dim=1)
            if self.num_classes == 2:
+                true, pred = torch.cat(true), torch.cat(pred)
+                true = true.long()
                result['AUROC'] = roc_auc_score(true.detach().cpu(), pred[:, 1].detach().cpu())
+            else:
+                result['AUROC'] = roc_auc_score(true.detach().cpu(), pred.detach().cpu(), multi_class='ovr', average='macro')
        else:
            result['total'] = self.total.sum().item()
            result['correct'] = self.correct.sum().item()
            result['ACC'] = 100 * self.correct.sum().item() / self.total.sum().item()
-            self.true = self.true.long()
            self.pred = torch.softmax(self.pred, dim=1)
            if self.num_classes == 2:
+                self.true = self.true.long()
                result['AUROC'] = roc_auc_score(self.true.detach().cpu(), self.pred[:, 1].detach().cpu())
+            else:
+                result['AUROC'] = roc_auc_score(self.true.detach().cpu(), self.pred.detach().cpu(), multi_class='ovr', average='macro')
 
        elif self.num_classes == 1:
            if self.is_DDP:
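Both branches above implement the same dispatch. A standalone sketch (illustrative helper, not the repo's `get_result`) of the binary vs. multiclass AUROC computation:

```python
import torch
from sklearn.metrics import roc_auc_score

def auroc(true: torch.Tensor, logits: torch.Tensor, num_classes: int) -> float:
    pred = torch.softmax(logits, dim=1)
    true = true.long().cpu().numpy()
    if num_classes == 2:
        # Binary: the score is the probability of the positive class (column 1).
        return roc_auc_score(true, pred[:, 1].cpu().numpy())
    # Multiclass: one-vs-rest AUROC, macro-averaged over classes.
    return roc_auc_score(true, pred.cpu().numpy(), multi_class='ovr', average='macro')
```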