# -------------------------------------------------------------------#
# Author : 张杰
# Date : 2024-3-9 21:39
# LastEditTime : 2024-3-9 21:39
# -------------------------------------------------------------------#
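"""RFMDSConv: a receptive-field attention convolution. It combines SE channel
attention with a multi-scale strip-convolution spatial attention (MDS) applied to
unfolded receptive-field features, and finishes with a strided convolution that
collapses each k x k receptive field back to one output location. (Docstring added
as a summary of the code below.)"""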
import torch
from torch import nn
from einops import rearrange


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class MDS(nn.Module):
    """Multi-scale depthwise strip-convolution spatial attention."""

    def __init__(self, in_channels, channelAttention_reduce=4):
        super().__init__()
        self.C = in_channels
        self.O = in_channels
        # Depthwise 5x5 base convolution followed by three strip-convolution branches (7, 11, 21)
        self.dconv5_5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)
        self.dconv1_7 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 7), padding=(0, 3), groups=in_channels)
        self.dconv7_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(7, 1), padding=(3, 0), groups=in_channels)
        self.dconv1_11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 11), padding=(0, 5), groups=in_channels)
        self.dconv11_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(11, 1), padding=(5, 0), groups=in_channels)
        self.dconv1_21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 21), padding=(0, 10), groups=in_channels)
        self.dconv21_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(21, 1), padding=(10, 0), groups=in_channels)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1), padding=0)

    def forward(self, inputs):
        x_init = self.dconv5_5(inputs)
        x_1 = self.dconv1_7(x_init)
        x_1 = self.dconv7_1(x_1)
        x_2 = self.dconv1_11(x_init)
        x_2 = self.dconv11_1(x_2)
        x_3 = self.dconv1_21(x_init)
        x_3 = self.dconv21_1(x_3)
        x = x_1 + x_2 + x_3 + x_init
        spatial_att = self.conv(x)
        out = spatial_att * inputs
        out = self.conv(out)  # channel mixing at the tail of the spatial attention module via a 1x1 convolution
        return out


class SE(nn.Module):
    """Squeeze-and-Excitation channel attention; returns per-channel weights of shape (b, c, 1, 1)."""

    def __init__(self, in_channel, ratio=16):
        super(SE, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(in_channel, in_channel // ratio, bias=False),  # c -> c/r
            nn.ReLU(),
            nn.Linear(in_channel // ratio, in_channel, bias=False),  # c/r -> c
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c = x.shape[0:2]
        y = self.gap(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return y


class RFMDSConv(nn.Module):
    """Multi-dynamic-spatial receptive-field attention convolution."""

    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, dilation=1):
        super().__init__()
        assert kernel_size % 2 == 1, "the kernel_size must be odd."
        self.kernel_size = kernel_size
        # Generate receptive-field spatial features: b, c*kernel**2, h, w
        self.generate = nn.Sequential(
            nn.Conv2d(in_channel, in_channel * (kernel_size ** 2), kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel * (kernel_size ** 2)),
            nn.ReLU()
        )
        # Attention weights (defined but not used in forward)
        self.get_weight = nn.Sequential(nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False), nn.Sigmoid())
        self.se = SE(in_channel)
        self.sa = MDS(in_channel)
        self.conv = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size, stride=kernel_size),
                                  nn.BatchNorm2d(out_channel), nn.ReLU())

    def forward(self, x):
        b, c = x.shape[0:2]
        channel_attention = self.se(x)  # channel attention weights, shape (b, c, 1, 1)
        generate_feature = self.generate(x)  # receptive-field spatial features
        h, w = generate_feature.shape[2:]
        # Receptive-field spatial features: b, c, kernel**2, h, w
        generate_feature = generate_feature.view(b, c, self.kernel_size ** 2, h, w)
        generate_feature = rearrange(generate_feature, 'b c (n1 n2) h w -> b c (h n1) (w n2)', n1=self.kernel_size,
                                     n2=self.kernel_size)
        # Receptive-field spatial features weighted by channel attention
        unfold_feature = generate_feature * channel_attention
        # Spatial attention over the receptive-field features via the MDS module
        receptive_field_attention = self.sa(unfold_feature)
        # Weight by spatial attention
        conv_data = unfold_feature * receptive_field_attention
        # Convolution with stride = kernel_size collapses each receptive field to one output location
        return self.conv(conv_data)


if __name__ == "__main__":
    block = RFMDSConv(in_channel=64, out_channel=128)
    x = torch.rand(32, 64, 9, 9)
    output = block(x)
    print(output.size())
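    # With the 9x9 input above, this should print torch.Size([32, 128, 9, 9]):
    # the unfolded features are 27x27, and the final stride-3 convolution maps them back to 9x9.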