-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_coral_depth.py
108 lines (85 loc) · 3.32 KB
/
train_coral_depth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import sys
import torch
import visdom
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
from torch.utils import data
from ptsemseg.models import get_model
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.loss import cross_entropy2d
from ptsemseg.metrics import scores
from lr_scheduling import *
# Hyperparameters for the coral_depth training run. When this file is run as a
# script, they are wrapped in a Namespace and handed to train() as `args`.
config = {
    "img_size": 768,        # forwarded to the dataset loader as img_size
    "n_epoch": 300,         # number of passes over the training set
    "batch_size": 2,        # DataLoader batch size
    "learning_rate": 5e-5,  # base LR for Adam and for poly_lr_scheduler
    "feature_scale": 1,     # in train() this only appears in checkpoint file names
}
def train(args):
    """Train a SegNet model on the ``coral_depth`` dataset.

    After every minibatch the loss is appended to a Visdom line plot. Every
    20 epochs the last minibatch loss is printed and the whole model is
    pickled with ``torch.save`` into the ``training/`` directory.

    Args:
        args: attribute-style config object. Reads ``img_size``,
            ``batch_size``, ``n_epoch``, ``learning_rate`` and
            ``feature_scale``. ``feature_scale`` is only used to build the
            checkpoint file name here.
    """
    import os

    # Setup Dataloader
    print("Setting up dataloader...")
    data_loader = get_loader("coral_depth")
    data_path = get_data_path("coral_depth")
    loader = data_loader(data_path, img_size=args.img_size)
    n_classes = loader.n_classes
    n_channels = loader.n_channels
    trainloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=8, shuffle=True)
    print("Finished creating dataloader.")

    # Setup visdom for visualization (presumably needs a reachable Visdom server).
    vis = visdom.Visdom()
    loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                           Y=torch.zeros((1,)).cpu(),
                           opts=dict(xlabel='minibatches',
                                     ylabel='Loss',
                                     title='Training Loss',
                                     legend=['Loss']))

    # Setup Model ("coralnet" was an earlier alternative architecture name).
    print("Setting up model...")
    model = get_model("segnet", n_classes, in_channels=n_channels)

    # Decide the device once instead of re-checking CUDA on every batch.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        print("CUDA Error.")
        device = torch.device("cpu")
    model = model.to(device)

    print("Setting up optimizer...")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Checkpoints are written under training/; make sure it exists so the
    # first torch.save does not fail on a fresh checkout.
    os.makedirs("training", exist_ok=True)

    print("Starting training...")
    steps_per_epoch = len(trainloader)
    for epoch in range(args.n_epoch):
        # Stays NaN if the loader yields no batches, so the report below
        # never touches an undefined `loss`.
        loss_value = float("nan")
        for i, (images, labels) in enumerate(trainloader):
            images = images.to(device)
            labels = labels.to(device)

            # Global minibatch counter across epochs: passed to the LR
            # scheduler and used as the plot's x coordinate, so successive
            # epochs extend the curve instead of overwriting the same x range.
            step = steps_per_epoch * epoch + i
            poly_lr_scheduler(optimizer, args.learning_rate, step)

            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy2d(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() returns the scalar as a Python float; indexing a 0-dim
            # loss (loss.data[0]) raises IndexError on PyTorch >= 0.4.
            loss_value = loss.item()
            vis.line(
                X=torch.ones((1, 1)).cpu() * step,
                Y=torch.Tensor([[loss_value]]).cpu(),
                win=loss_window,
                update='append')

        if (epoch + 1) % 20 == 0:
            print("Epoch [%d/%d] Loss: %.4f" % (epoch + 1, args.n_epoch, loss_value))
            torch.save(model, "{}_{}_{}_{}.pkl".format("training/coralnet", "coral_depth", args.feature_scale, epoch))
class Namespace:
    """Lightweight attribute container.

    Each keyword argument passed to the constructor becomes an instance
    attribute of the same name, so ``Namespace(a=1).a == 1``.
    """

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
if __name__ == "__main__":
    # Script entry point: expose the `config` dict as attributes, which is
    # the form train() reads its settings in.
    args = Namespace(**config)
    train(args)