-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
123 lines (110 loc) · 4.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from tqdm import tqdm
import numpy as np
import torch
# training
def train1Epoch(epoch_index, model, optimizer, loss_fn, training_loader, model_type):
    """Train ``model`` for one pass over ``training_loader``.

    The layout of each batch yielded by ``training_loader`` depends on
    ``model_type``:

    * ``"VanillaResNet"`` / ``"SegmentedResNet"``: ``(image, bnpp, _)``;
      the model is called as ``model(image)``.
    * ``"MultiChannelResNet"``: ``(image1, image2, image3, bnpp, data, _)``;
      the model is called as ``model(image1, image2, image3, data)``.
    * any other value: ``(image, bnpp, data, _)``; the model is called as
      ``model(image, data)``.

    Every batch tensor is cast to ``float32`` and moved to CUDA when it is
    available, otherwise to CPU. The model output is squeezed along dim 1
    before being passed to ``loss_fn`` together with ``bnpp``.

    Args:
        epoch_index: Not used by this function; kept for call-site
            compatibility.
        model: Module to train; it is switched to training mode here.
            NOTE(review): presumably already moved to the same device the
            batches are sent to — this function does not move it.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``(prediction, target)`` returning a scalar loss
            tensor.
        training_loader: Sized iterable of batches (``len()`` is used for
            the progress bar).
        model_type: Selects the batch layout and forward signature.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if the loader yields
        no batches.
    """
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_device(tensor):
        return tensor.to(device, dtype=torch.float32, non_blocking=True)

    def split_batch(batch):
        # Return (positional model inputs, target) for the configured layout.
        if model_type in ("VanillaResNet", "SegmentedResNet"):
            image, bnpp, _ = batch
            return (to_device(image),), to_device(bnpp)
        if model_type == "MultiChannelResNet":
            image1, image2, image3, bnpp, data, _ = batch
            inputs = (
                to_device(image1),
                to_device(image2),
                to_device(image3),
                to_device(data),
            )
            return inputs, to_device(bnpp)
        image, bnpp, data, _ = batch
        return (to_device(image), to_device(data)), to_device(bnpp)

    losses = []
    for batch in tqdm(training_loader, total=len(training_loader)):
        inputs, bnpp = split_batch(batch)
        pred = model(*inputs)
        loss = loss_fn(torch.squeeze(pred, 1), bnpp)
        # Drop gradients (grad = None) rather than zero-filling them.
        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    if not losses:
        # np.mean([]) would warn ("Mean of empty slice") before returning nan.
        return float("nan")
    return np.mean(losses)
def test1Epoch(epoch_index, model, loss_fn, valid_loader, model_type):
    """Evaluate ``model`` on ``valid_loader`` and return the mean batch loss.

    The model is put in eval mode and all work runs under
    ``torch.no_grad()``. The batch layout depends on ``model_type``:

    * ``"VanillaResNet"`` / ``"SegmentedResNet"``: ``(image, bnpp, _)``;
      the model is called as ``model(image)``.
    * ``"MultiChannelResNet"``: ``(image1, image2, image3, bnpp, data, _)``;
      the model is called as ``model(image1, image2, image3, data)``.
    * any other value: ``(image, bnpp, data, _)``; the model is called as
      ``model(image, data)``.

    Every batch tensor is cast to ``float32`` and moved to CUDA when it is
    available, otherwise to CPU. The model output is squeezed along dim 1
    before being passed to ``loss_fn`` together with ``bnpp``.

    Args:
        epoch_index: Not used by this function; kept for call-site
            compatibility.
        model: Module to evaluate; it is switched to eval mode here.
            NOTE(review): presumably already on the device the batches are
            sent to — this function does not move it.
        loss_fn: Callable ``(prediction, target)`` returning a scalar loss
            tensor.
        valid_loader: Sized iterable of batches (``len()`` is used for the
            progress bar).
        model_type: Selects the batch layout and forward signature.

    Returns:
        Mean of the per-batch loss values, or ``nan`` if the loader yields
        no batches.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_device(tensor):
        return tensor.to(device, dtype=torch.float32, non_blocking=True)

    def split_batch(batch):
        # Return (positional model inputs, target) for the configured layout.
        if model_type in ("VanillaResNet", "SegmentedResNet"):
            image, bnpp, _ = batch
            return (to_device(image),), to_device(bnpp)
        if model_type == "MultiChannelResNet":
            image1, image2, image3, bnpp, data, _ = batch
            inputs = (
                to_device(image1),
                to_device(image2),
                to_device(image3),
                to_device(data),
            )
            return inputs, to_device(bnpp)
        image, bnpp, data, _ = batch
        return (to_device(image), to_device(data)), to_device(bnpp)

    losses = []
    # no_grad means no autograd graph is built, so no .detach() calls are
    # needed on inputs or loss (Tensor.detach() is not in-place anyway).
    with torch.no_grad():
        for batch in tqdm(valid_loader, total=len(valid_loader)):
            inputs, bnpp = split_batch(batch)
            pred = model(*inputs)
            loss = loss_fn(torch.squeeze(pred, 1), bnpp)
            losses.append(loss.item())

    if not losses:
        # np.mean([]) would warn ("Mean of empty slice") before returning nan.
        return float("nan")
    return np.mean(losses)