-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpspnet.py
81 lines (64 loc) · 2.46 KB
/
pspnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
def initialize_weights(*models):
    """Re-initialise Conv2d/Linear and BatchNorm2d layers of the given modules in place.

    - ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights (with
      ``nn.init.kaiming_normal_``'s default arguments) and zero bias when a
      bias exists.
    - ``nn.BatchNorm2d``: weight filled with 1 and bias with 0; parameters that
      are ``None`` (``affine=False``) are skipped.

    Other module types are left untouched.

    Args:
        *models: Modules to walk; every entry of ``model.modules()`` (the
            module itself and all descendants) is visited.
    """
    for model in models:
        for module in model.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # The non-underscore ``kaiming_normal`` is deprecated; the
                # in-place variant is what it forwarded to, with the same defaults.
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                # A BatchNorm2d built with affine=False has weight/bias == None.
                if module.weight is not None:
                    nn.init.ones_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
class PyramidPool(nn.Module):
    """One branch of the PSPNet pyramid pooling module.

    Adaptive-average-pools the input to ``pool_size``, projects it with a
    1x1 conv (no bias) + BatchNorm + ReLU, then bilinearly upsamples the
    result back to the input's spatial size.

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels of the 1x1 projection.
        pool_size: Target output size passed to ``nn.AdaptiveAvgPool2d``.
    """

    def __init__(self, in_features, out_features, pool_size):
        super(PyramidPool, self).__init__()
        self.features = nn.Sequential(
            nn.AdaptiveAvgPool2d(pool_size),
            nn.Conv2d(in_features, out_features, 1, bias=False),
            nn.BatchNorm2d(out_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the pooled/projected features resized to ``x``'s spatial size.

        ``x`` must be 4-D (N, C, H, W), since bilinear interpolation requires it;
        the output has shape (N, out_features, H, W).
        """
        # F.upsample is deprecated. With mode='bilinear' and align_corners left
        # unspecified it resolves to align_corners=False, so passing that
        # explicitly keeps outputs identical and silences the warning.
        return F.interpolate(
            self.features(x),
            size=x.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
class PSPNet(nn.Module):
    """PSPNet-style segmentation head on top of a torchvision ResNet-50 backbone.

    ResNet-50's ``layer4`` output (2048 channels) is concatenated with four
    ``PyramidPool`` branches (pool sizes 1, 2, 3, 6; 512 channels each),
    classified by a small conv head, and bilinearly resized back to the
    input's spatial size.

    Args:
        pretrained: Forwarded unchanged to ``torchvision.models.resnet50``.
        num_classes: Number of output channels (per-pixel class scores).
    """

    def __init__(self, pretrained=True, num_classes=2):
        super(PSPNet, self).__init__()
        print("initializing model")
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        # of ``weights=``; kept as-is for compatibility with older torchvision.
        # NOTE(review): this backbone is not dilated (torchvision default),
        # unlike the dilated ResNet in the PSPNet paper.
        self.resnet = torchvision.models.resnet50(pretrained=pretrained)
        # Pyramid branches over layer4's 2048-channel feature map.
        self.layer5a = PyramidPool(2048, 512, 1)
        self.layer5b = PyramidPool(2048, 512, 2)
        self.layer5c = PyramidPool(2048, 512, 3)
        self.layer5d = PyramidPool(2048, 512, 6)
        # Input channels: 2048 (backbone) + 4 * 512 (pyramid) = 4096.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, num_classes, 1),
        )
        # Only the newly added head is re-initialised; the backbone keeps
        # whatever weights torchvision constructed.
        initialize_weights(self.layer5a, self.layer5b, self.layer5c, self.layer5d, self.final)

    def forward(self, x):
        """Return per-pixel class scores of shape (N, num_classes, H, W).

        ``x`` is expected to be a 4-D (N, 3, H, W) image batch, as consumed
        by the ResNet-50 stem.
        """
        size = x.size()
        # ResNet-50 trunk, skipping its avgpool/fc classifier.
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        x = self.final(torch.cat([
            x,
            self.layer5a(x),
            self.layer5b(x),
            self.layer5c(x),
            self.layer5d(x),
        ], 1))
        # F.upsample_bilinear is deprecated; it is documented as equivalent to
        # interpolate(mode='bilinear', align_corners=True), so that flag is kept
        # to preserve outputs (note: PyramidPool uses align_corners=False).
        return F.interpolate(x, size=size[2:], mode="bilinear", align_corners=True)