-
Notifications
You must be signed in to change notification settings - Fork 2
/
DaSiamRPN.py
56 lines (47 loc) · 2.28 KB
/
DaSiamRPN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch.nn as nn
import torch.nn.functional as F
class DaSiamRPN(nn.Module):
    """Convolutional network with two heads whose correlation kernels are set at runtime.

    Usage is two-step:

    1. ``temple(z)`` runs ``z`` through the shared feature extractor and
       reshapes the outputs of ``conv_r1`` / ``conv_cls1`` into convolution
       weights, stored on ``self.r1_kernel`` / ``self.cls1_kernel``.
    2. ``forward(x)`` runs ``x`` through the same feature extractor and
       convolves the ``conv_r2`` / ``conv_cls2`` outputs with those stored
       weights.

    Args:
        size: Multiplier applied to every channel width of the feature
            extractor except the 3-channel input.
        feature_out: Number of channels produced by ``conv_r2`` and
            ``conv_cls2`` (and the input-channel count of the dynamic
            kernels).
        anchor: The regression head outputs ``4 * anchor`` channels and the
            classification head ``2 * anchor`` channels.
    """

    def __init__(self, size=1, feature_out=256, anchor=5):
        # Channel widths of the five conv stages; the leading 3 (input
        # channels) is left unscaled, everything else is multiplied by `size`.
        base_channels = [3, 96, 256, 384, 384, 256]
        configs = [3 if c == 3 else c * size for c in base_channels]
        feat_in = configs[-1]

        super(DaSiamRPN, self).__init__()

        # NOTE: the layer order (including MaxPool before ReLU, and no ReLU
        # after the last BatchNorm) determines the state_dict key indices,
        # so it must not be rearranged.
        self.featureExtract = nn.Sequential(
            nn.Conv2d(configs[0], configs[1], kernel_size=11, stride=2),
            nn.BatchNorm2d(configs[1]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[1], configs[2], kernel_size=5),
            nn.BatchNorm2d(configs[2]),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[2], configs[3], kernel_size=3),
            nn.BatchNorm2d(configs[3]),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[3], configs[4], kernel_size=3),
            nn.BatchNorm2d(configs[4]),
            nn.ReLU(inplace=True),
            nn.Conv2d(configs[4], configs[5], kernel_size=3),
            nn.BatchNorm2d(configs[5]),
        )

        self.anchor = anchor
        self.feature_out = feature_out

        # *1 layers produce the dynamic kernels in temple();
        # *2 layers produce the feature maps those kernels are applied to in forward().
        self.conv_r1 = nn.Conv2d(feat_in, feature_out * 4 * anchor, 3)
        self.conv_r2 = nn.Conv2d(feat_in, feature_out, 3)
        self.conv_cls1 = nn.Conv2d(feat_in, feature_out * 2 * anchor, 3)
        self.conv_cls2 = nn.Conv2d(feat_in, feature_out, 3)
        self.regress_adjust = nn.Conv2d(4 * anchor, 4 * anchor, 1)

        # Placeholders until temple() is called; an empty list means "not set".
        self.r1_kernel = []
        self.cls1_kernel = []

        # Not read anywhere in this class; presumably consumed by external
        # code that drives the network -- verify against callers.
        self.cfg = {'lr': 0.30, 'window_influence': 0.40, 'penalty_k': 0.22, 'instance_size': 271, 'adaptive': False}

    def forward(self, x):
        """Apply the kernels stored by ``temple`` to the features of ``x``.

        Args:
            x: Input tensor accepted by ``featureExtract`` (3 channels).

        Returns:
            A 2-tuple ``(regression, classification)``: the ``conv_r2``
            features convolved with ``r1_kernel`` and passed through
            ``regress_adjust``, and the ``conv_cls2`` features convolved
            with ``cls1_kernel``.

        Raises:
            RuntimeError: If ``temple`` has not been called yet, i.e. the
                kernels are still the empty-list placeholders.
        """
        # Fail fast with a clear message instead of an obscure type error
        # from F.conv2d receiving a Python list as its weight.
        if isinstance(self.r1_kernel, list) or isinstance(self.cls1_kernel, list):
            raise RuntimeError(
                "DaSiamRPN.forward() called before temple(); "
                "call temple(z) first to initialise r1_kernel and cls1_kernel."
            )

        x_f = self.featureExtract(x)
        regression = self.regress_adjust(F.conv2d(self.conv_r2(x_f), self.r1_kernel))
        classification = F.conv2d(self.conv_cls2(x_f), self.cls1_kernel)
        return regression, classification

    def temple(self, z):
        """Compute and store the dynamic kernels from ``z``.

        The outputs of ``conv_r1`` and ``conv_cls1`` are reshaped to
        ``(anchor*4, feature_out, k, k)`` and ``(anchor*2, feature_out, k, k)``
        where ``k`` is the size of the last dimension of the ``conv_r1``
        output. The element counts only match for a batch of one with a
        square feature map; other inputs make ``view`` raise.

        Args:
            z: Input tensor accepted by ``featureExtract`` (3 channels).
        """
        z_f = self.featureExtract(z)
        r1_kernel_raw = self.conv_r1(z_f)
        cls1_kernel_raw = self.conv_cls1(z_f)

        kernel_size = r1_kernel_raw.shape[-1]
        self.r1_kernel = r1_kernel_raw.view(self.anchor * 4, self.feature_out, kernel_size, kernel_size)
        self.cls1_kernel = cls1_kernel_raw.view(self.anchor * 2, self.feature_out, kernel_size, kernel_size)