-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbasic_blocks.py
More file actions
126 lines (96 loc) · 2.82 KB
/
basic_blocks.py
File metadata and controls
126 lines (96 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch.nn as nn
def conv_block(act_fn, in_chan, out_chan, bn, act_after_bn, bn_momentum):
    """Create the basic FusionNet convolution block.

    The block starts with a 3x3, stride-1, reflect-padded ``nn.Conv2d``.
    Optionally, an activation and/or a batch-norm layer follows it.

    Args:
        act_fn (nn.Module or None): activation module to append; skipped when falsy
        in_chan (int): input channel depth
        out_chan (int): output channel depth
        bn (bool): whether to append ``nn.BatchNorm2d(out_chan)``
        act_after_bn (bool): when true, batch norm is placed before the
            activation; otherwise the activation comes first
        bn_momentum (float): momentum passed to ``nn.BatchNorm2d``
    Returns:
        (nn.Sequential()) Basic convolution block
    """
    conv = nn.Conv2d(
        in_chan,
        out_chan,
        kernel_size=3,
        stride=1,
        padding=1,
        padding_mode='reflect',
    )
    norm = nn.BatchNorm2d(out_chan, momentum=bn_momentum) if bn else None
    activation = act_fn if act_fn else None
    # Order of the optional trailing layers; absent ones are dropped below.
    tail = (norm, activation) if act_after_bn else (activation, norm)
    return nn.Sequential(conv, *(layer for layer in tail if layer is not None))
def res_block(act_fn, chan, bn, act_after_bn, bn_momentum):
    """Create the combined FusionNet triple convolution block.

    Three ``conv_block`` units with identical settings and constant channel
    depth are stacked in sequence. No skip connection is added here;
    presumably the caller adds the residual path — confirm.

    Note: the same ``act_fn`` object is passed to all three units, so a
    parametric activation would presumably share its weights — confirm intended.

    Args:
        act_fn (nn.Module or None): activation module forwarded to each unit
        chan (int): channel depth (input and output of every unit)
        bn (bool): whether each unit includes batch normalization
        act_after_bn (bool): whether batch norm precedes the activation in each unit
        bn_momentum (float): batch-norm momentum forwarded to each unit
    Returns:
        (nn.Sequential()) Combined triple convolution block
    """
    units = [
        conv_block(act_fn, chan, chan, bn, act_after_bn, bn_momentum)
        for _ in range(3)
    ]
    return nn.Sequential(*units)
def maxpool():
    """Create the basic FusionNet max-pooling layer.

    Returns:
        (nn.Module) ``nn.MaxPool2d`` with a 2x2 window, stride 2 and no
        padding, i.e. spatial downsampling by a factor of two.
    """
    return nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
def spatial_dropout(spat_drop_p):
    """Create the basic FusionNet spatial-dropout layer.

    Args:
        spat_drop_p (float): probability of zeroing an entire channel
    Returns:
        (nn.Module) ``nn.Dropout2d`` configured with ``p=spat_drop_p``
    """
    return nn.Dropout2d(p=spat_drop_p)
def conv_trans_block(act_fn, chan, act_a_trans, bn_a_trans, act_after_bn, bn_momentum):
    """Create the basic FusionNet upsampling block.

    The block starts with a stride-2 ``nn.ConvTranspose2d`` (kernel 3,
    padding 1, output_padding 1) that keeps the channel depth and doubles
    height and width. Optionally, an activation and/or batch normalization
    follows it.

    Args:
        act_fn (nn.Module or None): activation module; skipped when None
        chan (int): input and output channel depth
        act_a_trans (bool): whether to add ``act_fn`` after the transposed conv
        bn_a_trans (bool): whether to add ``nn.BatchNorm2d(chan)`` after the
            transposed conv
        act_after_bn (bool): when true, batch norm is placed before the
            activation; otherwise the activation comes first
        bn_momentum (float): momentum passed to ``nn.BatchNorm2d``
    Returns:
        (nn.Sequential()) Basic upsampling block
    """
    upsample = nn.ConvTranspose2d(
        chan,
        chan,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=1,
    )
    post = []
    # Skip a None activation: nn.Sequential accepts a None entry but then
    # raises TypeError when forward() tries to call it. conv_block guards
    # on act_fn the same way.
    if act_a_trans and act_fn is not None:
        post.append(act_fn)
    if bn_a_trans:
        post.append(nn.BatchNorm2d(chan, momentum=bn_momentum))
    # With both layers present this puts batch norm first; reversing a
    # 0- or 1-element list is a no-op.
    if act_after_bn:
        post.reverse()
    return nn.Sequential(upsample, *post)