forked from qimaqi/Unet_family
-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
62 lines (53 loc) · 2.13 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from utils import init_weights
class unetConv2(nn.Module):
    """A stack of ``n`` convolution stages used as a U-Net building block.

    Each stage is ``Conv2d(ks, stride, padding)``, optionally followed by
    ``BatchNorm2d``, and always followed by ``ReLU(inplace=True)``. The first
    stage maps ``in_size`` -> ``out_size`` channels; every later stage maps
    ``out_size`` -> ``out_size``.

    Args:
        in_size: number of input channels of the first stage.
        out_size: number of output channels of every stage.
        is_batchnorm: if true, insert a ``BatchNorm2d`` after each convolution.
        n: number of stages (with ``n=0`` no stages are built and ``forward``
            returns its input unchanged).
        ks: convolution kernel size.
        stride: convolution stride. It is applied in *every* stage, so with
            ``stride > 1`` the spatial reduction compounds across stages.
        padding: convolution zero-padding, applied in every stage.
    """

    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super(unetConv2, self).__init__()
        self.n = n
        self.ks = ks
        self.stride = stride
        self.padding = padding

        for i in range(1, n + 1):
            layers = [nn.Conv2d(in_size, out_size, ks, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU(inplace=True))
            # Register as ``conv1`` .. ``conv{n}`` and keep the Conv/BN/ReLU
            # order inside each Sequential, so state_dict keys such as
            # ``conv1.0.weight`` / ``conv1.1.running_mean`` are unchanged.
            setattr(self, 'conv%d' % i, nn.Sequential(*layers))
            in_size = out_size

        # initialise the blocks
        for m in self.children():
            init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        """Apply stages ``conv1`` .. ``conv{n}`` to ``inputs`` in order."""
        x = inputs
        for i in range(1, self.n + 1):
            x = getattr(self, 'conv%d' % i)(x)
        return x
class unetUp(nn.Module):
    """U-Net decoder step: upsample, concatenate skip features, then convolve.

    Args:
        in_size: number of channels of ``high_feature`` (the input of ``self.up``).
        out_size: number of channels produced by ``self.up`` and by the final
            ``unetConv2`` (which is built without batch norm).
        is_deconv: if true, upsample with a learned
            ``ConvTranspose2d(kernel_size=2, stride=2)``. Otherwise, use bilinear
            x2 upsampling (``align_corners=True``) followed by a 1x1 ``Conv2d``.
        n_concat: controls the input width of the fused ``unetConv2``, which is
            ``in_size + (n_concat - 2) * out_size``. Since the upsampled map
            contributes ``out_size`` channels, the skip tensors passed to
            ``forward`` must together carry
            ``in_size + (n_concat - 3) * out_size`` channels. For example,
            with ``n_concat=2`` that is one skip of ``in_size - out_size``
            channels.
    """

    def __init__(self, in_size, out_size, is_deconv, n_concat=2):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size + (n_concat - 2) * out_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, padding=0)
        else:
            # nn.UpsamplingBilinear2d is deprecated; this is its exact
            # equivalent (bilinear, align_corners=True). It has no parameters,
            # so the state_dict keys of the 1x1 conv (``up.1.*``) are unchanged.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_size, out_size, 1))

        # initialise the blocks; self.conv is skipped because unetConv2 already
        # runs init_weights over its own stages in its constructor.
        for m in self.children():
            if isinstance(m, unetConv2):
                continue
            init_weights(m, init_type='kaiming')

    def forward(self, high_feature, *low_feature):
        """Upsample ``high_feature``, append ``low_feature`` tensors on dim 1, and fuse.

        The skip tensors are concatenated after the upsampled map, in the order
        given. ``torch.cat`` requires them to match it in every dimension except
        dim 1.
        """
        outputs0 = self.up(high_feature)
        if low_feature:
            # A single concatenation instead of repeated pairwise cats, which
            # re-copied the growing tensor once per skip feature.
            outputs0 = torch.cat([outputs0, *low_feature], 1)
        return self.conv(outputs0)