forked from pyaf/DenseNet-MURA-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
55 lines (44 loc) · 1.86 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import copy
import os
import time

import pandas as pd
import torch
from torch.autograd import Variable

from densenet import densenet169
from utils import plot_training, n_p, get_count
from train import train_model, get_metrics
from pipeline import get_study_level_data, get_dataloaders
# #### Load the study-level data for a single study type
study_data = get_study_level_data(study_type='XR_WRIST')

# #### Create the dataloaders pipeline
data_cat = ['train', 'valid']  # phases used to index every per-phase dict below
dataloaders = get_dataloaders(study_data, batch_size=1)
dataset_sizes = {phase: len(study_data[phase]) for phase in data_cat}

# #### Per-phase class weights for the loss
# tai = total abnormal images, tni = total normal images
tai = {phase: get_count(study_data[phase], 'positive') for phase in data_cat}
tni = {phase: get_count(study_data[phase], 'negative') for phase in data_cat}
_totals = {phase: tni[phase] + tai[phase] for phase in data_cat}
# Wt1 is the fraction of normal images and Wt0 the fraction of abnormal
# images, each passed through n_p.
Wt1 = {phase: n_p(tni[phase] / _totals[phase]) for phase in data_cat}
Wt0 = {phase: n_p(tai[phase] / _totals[phase]) for phase in data_cat}

# Report the counts and the resulting weights.
print('tai:', tai)
print('tni:', tni, '\n')
print('Wt0 train:', Wt0['train'])
print('Wt0 valid:', Wt0['valid'])
print('Wt1 train:', Wt1['train'])
print('Wt1 valid:', Wt1['valid'])
class Loss(torch.nn.modules.Module):
    """Class-weighted binary cross-entropy whose weights are chosen per phase.

    Args:
        Wt1: mapping from phase name to the weight applied to the
            ``targets * log(inputs)`` term.
        Wt0: mapping from phase name to the weight applied to the
            ``(1 - targets) * log(1 - inputs)`` term.

    The loss is returned element-wise; no mean or sum reduction is applied.
    """

    # Lower bound for the log terms. log(0) is -inf, and 0 * -inf is NaN, so a
    # prediction saturated at exactly 0 or 1 would otherwise poison the loss
    # (and every gradient after it). torch.nn.BCELoss uses the same bound.
    _LOG_FLOOR = -100.0

    def __init__(self, Wt1, Wt0):
        super(Loss, self).__init__()
        self.Wt1 = Wt1
        self.Wt0 = Wt0

    def forward(self, inputs, targets, phase):
        """Return the weighted cross-entropy of ``inputs`` against ``targets``.

        Args:
            inputs: predictions; presumably probabilities in [0, 1] --
                confirm against the model's output layer.
            targets: labels with the same shape as ``inputs`` (or
                broadcastable to it).
            phase: key used to look up the weights in ``Wt1`` and ``Wt0``.
        """
        log_pos = torch.clamp(inputs.log(), min=self._LOG_FLOOR)
        log_neg = torch.clamp((1 - inputs).log(), min=self._LOG_FLOOR)
        loss = -(self.Wt1[phase] * targets * log_pos
                 + self.Wt0[phase] * (1 - targets) * log_neg)
        return loss
# #### Build model, loss, optimizer and LR scheduler
model = densenet169(pretrained=True)
model = model.cuda()
criterion = Loss(Wt1, Wt0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Lower the learning rate when the monitored quantity stops decreasing
# (mode='min') for more than one epoch (patience=1).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True)

# #### Train model
model = train_model(model, criterion, optimizer, dataloaders, scheduler, dataset_sizes, num_epochs=5)

# #### Save the trained weights
# torch.save does not create missing directories; make sure the target
# directory exists so a finished training run is not lost at save time.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/model.pth')

# #### Evaluate the trained model
get_metrics(model, criterion, dataloaders, dataset_sizes)