forked from Thvnvtos/Lung_Segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
89 lines (78 loc) · 3.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import pickle, json
import torch
from torch.utils import data
import torch.optim as optim
import torch.nn as nn
from torch.nn import functional as F
import model
from data import *
# Let cuDNN auto-tune convolution algorithms (pays off when input sizes stay
# the same across iterations, which is presumably the case here).
torch.backends.cudnn.benchmark = True

# JSON is UTF-8 by specification; be explicit so the platform default
# encoding (e.g. cp1252 on Windows) is never used to decode the config.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Use the first GPU when one is available; otherwise fall back to the CPU
# instead of failing later at the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pickled sequence of scan path strings (split on '/' below).
# NOTE(security): pickle.load can execute arbitrary code — only point
# config["path"]["labelled_list"] at files you produced yourself.
with open(config["path"]["labelled_list"], "rb") as f:
    list_scans = pickle.load(f)

# Keep the second '/'-separated component of each path as the scan id.
# NOTE(review): presumably paths look like "<dir>/<scan_id>/..." — confirm
# against whatever wrote the labelled list.
st_scans = [s.split('/')[1] for s in list_scans]
# Mode-specific setup: datasets, model, loss, optimizer and loop hyper-parameters.
if config["mode"] == "3d":
    train_cfg = config["train3d"]
    # The first `train_size` scan ids are used for training; the rest are held
    # out for validation.
    train_scans = st_scans[:train_cfg["train_size"]]
    val_scans = st_scans[train_cfg["train_size"]:]
    train_data = dataset.Dataset(
        train_scans, config["path"]["scans"], config["path"]["masks"],
        mode="3d", scan_size=train_cfg["scan_size"], n_classes=train_cfg["n_classes"],
    )
    # n_classes was previously omitted here. Pass the same value as for the
    # training set so validation labels are built with the same class count
    # as the training labels and the model's output channels.
    val_data = dataset.Dataset(
        val_scans, config["path"]["scans"], config["path"]["masks"],
        mode="3d", scan_size=train_cfg["scan_size"], n_classes=train_cfg["n_classes"],
    )
    # 1 input channel, n_classes output channels, start_filters from config.
    unet = model.UNet(1, train_cfg["n_classes"], train_cfg["start_filters"], bilinear=False).to(device)
    criterion = utils.dice_loss
    optimizer = optim.Adam(unet.parameters(), lr=train_cfg["lr"])
    batch_size = train_cfg["batch_size"]
    epochs = train_cfg["epochs"]
    val_steps = train_cfg["validation_steps"]
    val_size = train_cfg["validation_size"]
else:
    # The training loop only supports the 3d setup. It needs train_data,
    # val_data, val_steps and val_size, none of which the former 2d branch
    # defined. That branch also rebound the imported `dataset` module and read
    # slices_per_batch / neg_examples_per_batch without ever using them, so
    # the loop would fail with a NameError. Fail fast with an explicit
    # message instead.
    raise NotImplementedError(
        "train.py only implements 3d training; set config['mode'] to '3d' "
        "(got {!r})".format(config["mode"])
    )
def _load_batch(ds, start, stop):
    """Return ``(inputs, labels)`` tensors on ``device`` for items ``start .. stop-1`` of *ds*.

    Each item is read exactly once and split into its index-0 part (model
    input) and index-1 part (label). The previous code called ``__getitem__``
    twice per item. That doubled the loading work and could pair an input
    with a label from a different draw if the dataset samples or augments
    randomly.

    The arrays are built as float32 directly. The former float16 round-trip
    lost precision before ``torch.Tensor`` widened the data to float32 anyway.
    """
    items = [ds[j] for j in range(start, stop)]
    inputs = np.array([item[0] for item in items], dtype=np.float32)
    targets = np.array([item[1] for item in items], dtype=np.float32)
    return torch.as_tensor(inputs, device=device), torch.as_tensor(targets, device=device)


def _validation_loss():
    """Return the criterion averaged over ``val_size`` random validation windows.

    Each window holds up to ``batch_size`` consecutive items starting at a
    random index. The window is clamped at the end of ``val_data``; the old
    code could index past it.

    The model runs in eval mode under ``torch.no_grad()``. No autograd graph
    is built, and mode-dependent layers (batch-norm/dropout, if the UNet has
    any) neither behave as in training nor update their statistics from
    validation data. Training mode is restored before returning.
    """
    n_val = len(val_data)
    unet.eval()
    total = 0.0
    try:
        with torch.no_grad():
            for start in np.random.randint(0, n_val, val_size):
                start = int(start)
                stop = min(start + batch_size, n_val)
                inputs, targets = _load_batch(val_data, start, stop)
                total += criterion(unet(inputs), targets).item()
    finally:
        unet.train()
    # NOTE(review): unlike the training print, this is averaged per window,
    # not divided by the batch size.
    return total / val_size


# Lowest validation loss seen so far; the weights are saved whenever it improves.
best_val_loss = float("inf")
n_train = len(train_data)
batch_starts = range(0, n_train, batch_size)
for epoch in range(epochs):
    epoch_loss = 0
    for step, start in enumerate(batch_starts):
        # Clamp so the final, possibly partial, batch stays within the dataset.
        stop = min(start + batch_size, n_train)
        # Inputs and labels are data, not trainable parameters, so they do
        # not need requires_grad.
        batch, labels = _load_batch(train_data, start, stop)
        optimizer.zero_grad()
        logits = unet(batch)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Normalise by the number of items actually in this batch. This equals
        # batch_size (what the original divided by) for every batch except
        # possibly the last.
        batch_loss = loss.item() / (stop - start)
        print("Epoch {} ==> Batch {} mean loss : {}".format(epoch + 1, step + 1, batch_loss))
        epoch_loss += batch_loss
        del batch, labels, logits, loss
        torch.cuda.empty_cache()
        # Validate every `val_steps` batches. The old check used the sample
        # offset (i + 1), which never hit a multiple of val_steps when
        # batch_size and val_steps share a common factor.
        if (step + 1) % val_steps == 0:
            print("===================> Calculating validation loss ... ")
            val_loss = _validation_loss()
            print("\n # Validation Loss : ", val_loss)
            if val_loss < best_val_loss:
                print("\nSaving Better Model... ")
                torch.save(unet.state_dict(), "./model")
                best_val_loss = val_loss
            print("\n")
    print("Epoch {} mean batch loss : {}".format(epoch + 1, epoch_loss / max(len(batch_starts), 1)))