-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_b0.py
115 lines (89 loc) · 3.52 KB
/
train_b0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from torch import nn
from torch.optim import Adam
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from models import mlp
from models.mlp import eficientB0
from tqdm import tqdm
import os
from data.dataset import data_loader
from data import dataset
import argparse
import os
import json
import pickle
# import hydra
# from omegaconf import DictConfig, OmegaConf
# from hydra.utils import get_original_cwd, to_absolute_path
import logging
log = logging.getLogger(__name__)
# import wandb
import numpy as np
from tqdm import tqdm
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES']= '4'
import torch.nn as nn
#python mytest.py --vit_dir /home/thao.nguyen/AI701B/SEViT/models/X-ray/m_best_model.pth --root_dir /home/thao.nguyen/AI701B/SEViT/data/X-ray --output_dir /home/thao.nguyen/AI701B/SEViT/models/X-ray/B0
#thao.nguyen
# ---------------- Hyper-parameters & CLI ----------------
# Input resolution expected by the ViT / EfficientNet-B0 pipeline.
image_size = (224, 224)

parser = argparse.ArgumentParser(description='Training MLPs')
parser.add_argument('--vit_dir', type=str, help='pass the path of downloaded ViT')
parser.add_argument('--root_dir', type=str, help='pass the path of downloaded data')
parser.add_argument('--output_dir', type=str, help='pass the path of output')
# The following were hard-coded constants; defaults reproduce the original
# values exactly, so existing invocations behave identically.
parser.add_argument('--batch_size', type=int, default=32, help='mini-batch size')
parser.add_argument('--epochs', type=int, default=70, help='number of training epochs')
parser.add_argument('--lr', type=float, default=1e-5, help='Adam learning rate')
args = parser.parse_args()

# Module-level names kept unchanged — the rest of the script reads these.
batch_size = args.batch_size
epochs = args.epochs
lr = args.lr
root_dir = args.root_dir
vit_dir = args.vit_dir
output_dir = args.output_dir
# Load the pre-trained ViT and freeze it: it serves purely as a fixed
# feature extractor while the per-block B0 classifiers are trained.
# NOTE(review): torch.load unpickles arbitrary code — only load trusted checkpoints.
vit = torch.load(vit_dir).cuda()
vit.eval()
vit.requires_grad_(False)  # disables gradients on every ViT parameter

# Build the train/test DataLoaders. This rebinds the name of the imported
# data_loader helper to its result, which is a dict of loaders keyed by split.
data_loader, image_dataset = data_loader(
    root_dir=root_dir, batch_size=batch_size, image_size=image_size
)
# Train one EfficientNet-B0 classifier per ViT block (5 blocks total),
# keeping the checkpoint with the best test accuracy for each block.
for index in range(5):
    print(f'***Block {index+1}***')
    classifier = eficientB0(num_classes=2, vit=vit, block_num=index).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=lr, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=150, eta_min=0)
    best_accuracy = 0
    for epoch in range(epochs):
        # --- Training ---
        # BUG FIX: train mode must be restored every epoch. The original
        # called classifier.train() once before the epoch loop, but the
        # validation pass below switches to eval(), so every epoch after
        # the first was trained with BatchNorm/dropout frozen in eval mode.
        classifier.train()
        for images, labels in tqdm(data_loader['train'], desc='Epoch: {}/{}'.format(epoch + 1, epochs)):
            images = images.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs = classifier(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NOTE(review): the scheduler is stepped per *batch*; with
            # T_max=150 a per-epoch step is the more common pattern — confirm intent.
            scheduler.step()
        # --- Validation ---
        classifier.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in tqdm(data_loader['test']):
                images = images.cuda()
                labels = labels.cuda()
                outputs = classifier(images)
                # (removed: a no-op optimizer.zero_grad() and an unused
                # criterion() call that existed in the original val loop)
                prediction = torch.argmax(outputs, dim=-1)
                total += labels.size(0)
                correct += (prediction == labels).sum().item()
        accuracy = correct / total
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # Saves the whole module via pickle; loading it back requires the
            # project's `models` package to be importable.
            path = os.path.join(output_dir, f"m_best_model_b0_block_{index}.pth")
            torch.save(classifier, path)
            print(f"Best accuracy {index+1} is updated to ... {best_accuracy}")