forked from ducha-aiki/affnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch_sift.py
94 lines (92 loc) · 4.23 KB
/
pytorch_sift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import math
import torch.nn.init
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import numpy as np
import torch.nn.functional as F
class L2Norm(nn.Module):
    """Divide the input by its L2 norm taken along dim 1.

    A tiny epsilon is added to the squared norm before the square root, so an
    all-zero slice is mapped to zeros rather than NaN.
    """

    def __init__(self):
        super(L2Norm, self).__init__()
        # Guard against division by zero for all-zero inputs.
        self.eps = 1e-10

    def forward(self, x):
        squared_norm = (x * x).sum(dim=1).abs()
        denominator = (squared_norm + self.eps).sqrt().unsqueeze(1)
        return x / denominator.expand_as(x)
def getPoolingKernel(kernel_size = 25):
    """Return a square, separable, tent-shaped weight kernel.

    The 1-D profile rises in steps of ``1 / floor(kernel_size / 2)`` from half
    a step up to just below 1, is exactly 1 in the centre, and falls back
    symmetrically. The 2-D kernel is the outer product of that profile with
    itself. SIFTNet uses it as the weights of its spatial pooling convolution.

    ``kernel_size`` is expected to be odd: for an even value the returned
    kernel has ``kernel_size + 1`` rows and columns.

    Args:
        kernel_size: side length of the kernel.

    Returns:
        A float64 numpy array of shape ``(kernel_size, kernel_size)`` for odd
        ``kernel_size``.
    """
    if kernel_size == 1:
        # floor(1 / 2) == 0 would make the step below a division by zero; the
        # profile degenerates to just its centre value, 1.
        return np.ones((1, 1))
    step = 1. / float(np.floor(kernel_size / 2.))
    x_coef = np.arange(step / 2., 1., step)
    profile = np.hstack([x_coef, [1], x_coef[::-1]])
    # Every profile entry is positive, so the outer product needs no clamping.
    return np.outer(profile, profile)
def get_bin_weight_kernel_size_and_stride(patch_size, num_spatial_bins):
    """Return ``(kernel_size, stride)`` for the spatial-bin pooling convolution.

    The stride is ``patch_size`` rounded down to an even number, divided by
    ``num_spatial_bins + 1`` and rounded to the nearest integer. The kernel
    spans ``2 * stride - 1`` pixels, so neighbouring pooling windows overlap
    whenever the stride exceeds 1.
    """
    even_width = 2.0 * math.floor(patch_size / 2)
    stride = int(round(even_width / float(num_spatial_bins + 1)))
    kernel_size = int(2 * stride - 1)
    return kernel_size, stride
class SIFTNet(nn.Module):
    """SIFT-style patch descriptor built from fixed-weight convolutions.

    Pipeline: finite-difference image gradients, a circular Gaussian window
    over the patch, soft assignment of gradient orientation into
    ``num_ang_bins`` bins, triangular-weighted spatial pooling through a
    strided convolution, then L2 normalisation, clipping at ``clipval`` and
    re-normalisation.
    """

    def CircularGaussKernel(self, kernlen=21):
        """Return a ``kernlen`` x ``kernlen`` Gaussian window cut to a disc.

        With ``h = kernlen / 2``, an entry whose squared distance ``d2`` from
        ``(h, h)`` is below ``h ** 2`` equals ``exp(-d2 / (0.9 * h ** 2))``;
        every other entry is 0.

        NOTE(review): under Python 3 ``kernlen / 2`` is a float (32.5 for 65),
        so the window's centre sits half a pixel away from the middle pixel of
        an odd-sized patch, whereas Python 2 integer division centred it on a
        pixel. Kept as-is to preserve existing descriptor values; confirm which
        behaviour is intended.

        Returns:
            A float64 numpy array of shape ``(kernlen, kernlen)``.
        """
        half_size = kernlen / 2
        r2 = float(half_size * half_size)
        sigma2 = 0.9 * r2
        offsets = np.arange(kernlen) - half_size
        sq = offsets * offsets
        # disq[y, x] = (y - h)^2 + (x - h)^2 for row y, column x.
        disq = sq[:, None] + sq[None, :]
        return np.where(disq < r2, np.exp(-disq / sigma2), 0.)

    def __init__(self, patch_size = 65, num_ang_bins = 8, num_spatial_bins = 4, clipval = 0.2):
        """Build the fixed convolution stack.

        Args:
            patch_size: side length of the (square) input patches.
            num_ang_bins: number of orientation histogram bins.
            num_spatial_bins: controls the pooling kernel size and stride via
                ``get_bin_weight_kernel_size_and_stride``.
            clipval: upper clamp applied to the first-normalised descriptor.
        """
        super(SIFTNet, self).__init__()
        gk = torch.from_numpy(self.CircularGaussKernel(kernlen=patch_size).astype(np.float32))
        self.bin_weight_kernel_size, self.bin_weight_stride = get_bin_weight_kernel_size_and_stride(
            patch_size, num_spatial_bins)
        # Plain tensor attribute (not a registered parameter or buffer): it is
        # not part of state_dict and is not moved by Module.cuda()/.cpu(), so
        # forward() relocates it to the input's device.
        self.gk = Variable(gk)
        self.num_ang_bins = num_ang_bins
        self.num_spatial_bins = num_spatial_bins
        self.clipval = clipval
        # The convolutions are wrapped in nn.Sequential so their state_dict
        # keys are "gx.0.weight", "gy.0.weight" and "pk.0.weight".
        self.gx = nn.Sequential(nn.Conv2d(1, 1, kernel_size=(1, 3), bias=False))
        self.gx[0].weight.data = torch.from_numpy(np.array([[[[-1, 0, 1]]]], dtype=np.float32))
        self.gy = nn.Sequential(nn.Conv2d(1, 1, kernel_size=(3, 1), bias=False))
        self.gy[0].weight.data = torch.from_numpy(np.array([[[[-1], [0], [1]]]], dtype=np.float32))
        self.pk = nn.Sequential(nn.Conv2d(1, 1,
                                          kernel_size=(self.bin_weight_kernel_size, self.bin_weight_kernel_size),
                                          stride=(self.bin_weight_stride, self.bin_weight_stride),
                                          bias=False))
        pooling = getPoolingKernel(kernel_size=self.bin_weight_kernel_size)
        pooling = pooling.reshape((1, 1, self.bin_weight_kernel_size, self.bin_weight_kernel_size))
        self.pk[0].weight.data = torch.from_numpy(pooling.astype(np.float32))

    def forward(self, x):
        """Compute descriptors for a batch of single-channel patches.

        Args:
            x: patch tensor. The 1-channel gradient convolutions and the
                ``patch_size`` x ``patch_size`` Gaussian window imply shape
                ``(B, 1, patch_size, patch_size)``.

        Returns:
            ``(B, D)`` tensor, ``D = num_ang_bins * S * S`` where ``S`` is the
            spatial output size of the pooling convolution (4 for the default
            65x65 patches with 4 spatial bins, giving D = 128).
        """
        # Central differences; replicate padding keeps the spatial size.
        gx = self.gx(F.pad(x, (1, 1, 0, 0), 'replicate'))
        gy = self.gy(F.pad(x, (0, 0, 1, 1), 'replicate'))
        mag = torch.sqrt(gx ** 2 + gy ** 2 + 1e-10)
        ori = torch.atan2(gy, gx + 1e-8)
        # Pass x's device index explicitly: a bare .cuda() targets the
        # *current* CUDA device, which on multi-GPU setups need not be the one
        # holding x, causing a device-mismatch error in the multiply below.
        # Both calls return the tensor itself when it is already in place.
        if x.is_cuda:
            self.gk = self.gk.cuda(x.get_device())
        else:
            self.gk = self.gk.cpu()
        mag = mag * self.gk.expand_as(mag)
        # Continuous orientation-bin coordinate; each magnitude is split
        # linearly between the two neighbouring (cyclic) angular bins.
        o_big = (ori + 2.0 * math.pi) / (2.0 * math.pi) * float(self.num_ang_bins)
        bo0_big = torch.floor(o_big)
        wo1_big = o_big - bo0_big
        bo0_big = bo0_big % self.num_ang_bins
        bo1_big = (bo0_big + 1) % self.num_ang_bins
        wo0_big = (1.0 - wo1_big) * mag
        wo1_big = wo1_big * mag
        ang_bins = []
        for i in range(self.num_ang_bins):
            bin_weight = (bo0_big == i).float() * wo0_big + (bo1_big == i).float() * wo1_big
            ang_bins.append(self.pk(bin_weight))
        ang_bins = torch.cat(ang_bins, 1)
        ang_bins = ang_bins.view(ang_bins.size(0), -1)
        ang_bins = L2Norm()(ang_bins)
        ang_bins = torch.clamp(ang_bins, 0., float(self.clipval))
        ang_bins = L2Norm()(ang_bins)
        return ang_bins