-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathMLCA_changed.py
114 lines (87 loc) · 5.17 KB
/
MLCA_changed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MLCA_Changed_Option1(nn.Module):
    """Multi-Level Channel Attention (MLCA) variant with a learnable fusion weight.

    Runs an ECA-style 1-D convolution over channel descriptors at two
    granularities — a ``local_size x local_size`` average-pooled map and a
    global (1x1) pooled vector — then blends the two attention maps with a
    trainable scalar (option 1.1: adaptive weighted fusion) and rescales the
    input by the result.
    """

    def __init__(self, in_size, local_size=5, gamma=2, b=1, local_weight=0.5):
        super(MLCA_Changed_Option1, self).__init__()
        self.local_size = local_size
        self.gamma = gamma
        self.b = b
        # ECA rule: 1-D kernel size grows with log2 of the channel count; force odd.
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)
        k = t if t % 2 else t + 1
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        # Learnable blend factor between local and global attention
        # (option 1.1: adaptive weighted fusion instead of a fixed scalar).
        self.local_weight = nn.Parameter(torch.Tensor([local_weight]))
        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        self.global_arv_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        pooled_local = self.local_arv_pool(x)
        pooled_global = self.global_arv_pool(pooled_local)
        b, c, m, n = x.shape
        _, c_l, _, _ = pooled_local.shape
        ls = self.local_size
        # Flatten into 1-D sequences so the ECA convs can slide over them:
        # local:  (b,c,ls,ls) -> (b,c,ls*ls) -> (b,ls*ls,c) -> (b,1,ls*ls*c)
        # global: (b,c,1,1)   -> (b,c,1)     -> (b,1,c)
        seq_local = pooled_local.view(b, c_l, -1).transpose(-1, -2).reshape(b, 1, -1)
        seq_global = pooled_global.view(b, c, -1).transpose(-1, -2)
        att_seq_local = self.conv_local(seq_local)
        att_seq_global = self.conv(seq_global)
        # Undo the flattening: back to (b,c,ls,ls) and (b,c,1,1).
        att_map_local = att_seq_local.reshape(b, ls * ls, c).transpose(-1, -2).view(b, c, ls, ls)
        att_map_global = att_seq_global.view(b, -1).unsqueeze(-1).unsqueeze(-1)
        att_local = att_map_local.sigmoid()
        # Spread the 1x1 global attention over the local grid before fusing.
        att_global = F.adaptive_avg_pool2d(att_map_global.sigmoid(), [ls, ls])
        # Option 1.2: weighted-sum fusion, then un-pool to the input resolution.
        fused = att_global * (1 - self.local_weight) + (att_local * self.local_weight)
        att_all = F.adaptive_avg_pool2d(fused, [m, n])
        return x * att_all
class MLCA_Changed_Option2(nn.Module):
    """MLCA variant (option 2): broadcast the global attention with nearest
    upsampling instead of adaptive average pooling.

    Bug fix: the original forward fed the *upsampled* local map through
    ``self.conv``; after the view/transpose that input had
    ``local_size**2 * scale**2`` channels in the Conv1d channel dimension,
    but ``self.conv`` is ``Conv1d(1, 1, ...)`` and requires exactly 1 input
    channel, so every call raised a RuntimeError. The global descriptor is
    now pooled to 1x1 first (the ECA formulation, matching Option1) and the
    resulting global attention is expanded onto the local grid with
    nearest-neighbour interpolation, which preserves the stated intent of
    replacing adaptive pooling with an upsample step.
    """

    def __init__(self, in_size, local_size=5, gamma=2, b=1, local_weight=0.5):
        super(MLCA_Changed_Option2, self).__init__()
        self.local_size = local_size
        self.gamma = gamma
        self.b = b
        # ECA rule: 1-D kernel size grows with log2 of the channel count; force odd.
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)
        k = t if t % 2 else t + 1
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        # Fixed (non-learnable) fusion weight, unlike Option1's nn.Parameter.
        self.local_weight = local_weight
        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        # Kept for backward compatibility (parameter-free, so state_dict is
        # unaffected). Scale-factor upsampling only lines up when in_size is
        # the spatial size and divisible by local_size, so forward uses
        # size-based nearest interpolation instead.
        self.global_upsample = nn.Upsample(scale_factor=(in_size // local_size), mode='nearest')

    def forward(self, x):
        local_arv = self.local_arv_pool(x)
        # Global descriptor: pool the local map down to 1x1 (one value per channel).
        global_arv = F.adaptive_avg_pool2d(local_arv, 1)
        b, c, m, n = x.shape
        b_local, c_local, m_local, n_local = local_arv.shape
        # local:  (b,c,ls,ls) -> (b,1,ls*ls*c); global: (b,c,1,1) -> (b,1,c)
        # so each 1-channel Conv1d slides over a single flattened sequence.
        temp_local = local_arv.view(b, c_local, -1).transpose(-1, -2).reshape(b, 1, -1)
        temp_global = global_arv.view(b, c, -1).transpose(-1, -2)
        y_local = self.conv_local(temp_local)
        y_global = self.conv(temp_global)
        # Back to (b,c,ls,ls) and (b,c,1,1).
        y_local_transpose = y_local.reshape(b, self.local_size * self.local_size, c).transpose(-1, -2).view(
            b, c, self.local_size, self.local_size)
        y_global_transpose = y_global.view(b, -1).unsqueeze(-1).unsqueeze(-1)
        att_local = y_local_transpose.sigmoid()
        # Option 2: nearest "upsample module" broadcasts the 1x1 global
        # attention onto the local grid so the two maps can be fused.
        att_global = F.interpolate(y_global_transpose.sigmoid(),
                                   size=(self.local_size, self.local_size), mode='nearest')
        att_all = F.adaptive_avg_pool2d(att_global * (1 - self.local_weight) + (att_local * self.local_weight),
                                        [m, n])
        return x * att_all
if __name__ == "__main__":
    # Smoke test for the Option1 module.
    # Fix: the original referenced MLCA_Changed_1, which is not defined
    # anywhere in this file (NameError); the classes are
    # MLCA_Changed_Option1/MLCA_Changed_Option2.
    # NOTE(review): in_size feeds the ECA kernel-size formula, which in the
    # ECA paper is the channel count; the input below has 55 channels, so
    # in_size=55 may be the faithful setting — kept at 64 as in the original
    # to preserve the resulting kernel size. Confirm intent.
    attention = MLCA_Changed_Option1(in_size=64)
    inputs = torch.randn((2, 55, 16, 16))
    result = attention(inputs)
    print(result.shape)