-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsam.py
54 lines (43 loc) · 1.71 KB
/
sam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
'''
Spatial Adaptation Modules (SAM)
Paper: `MASA-SR: Matching Acceleration and Spatial Adaptation for Reference-Based Image Super-Resolution`
'''
class SAM(nn.Module):
    """Spatial Adaptation Module (SAM).

    Re-modulates an instance-normalized reference feature map with a
    spatially varying affine transform ``(gamma, beta)`` so that its
    statistics match those of the low-resolution (LR) feature map.

    Paper: `MASA-SR: Matching Acceleration and Spatial Adaptation for
    Reference-Based Image Super-Resolution`.

    Args:
        nf (int): number of feature channels in both input maps.
        use_residual (bool): when learnable, add the LR per-channel
            std/mean onto the predicted gamma/beta (residual modulation).
        learnable (bool): predict gamma/beta with conv layers; when
            False, fall back to plain AdaIN-style statistic transfer.
    """

    def __init__(self, nf, use_residual=True, learnable=True):
        super(SAM, self).__init__()
        self.learnable = learnable
        # Assigned unconditionally so the attribute always exists,
        # even when learnable=False (forward only reads it in the
        # learnable branch, so behavior is unchanged).
        self.use_residual = use_residual
        self.norm_layer = nn.InstanceNorm2d(nf, affine=False)

        if self.learnable:
            self.conv_shared = nn.Sequential(
                nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True),
                nn.ReLU(inplace=True))
            self.conv_gamma = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
            self.conv_beta = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

            # Zero-init weights AND biases: at the start of training the
            # predicted gamma/beta are exactly zero, so (with residual)
            # modulation reduces to plain LR-statistic transfer.
            self.conv_gamma.weight.data.zero_()
            self.conv_beta.weight.data.zero_()
            self.conv_gamma.bias.data.zero_()
            self.conv_beta.bias.data.zero_()

    def forward(self, lr, ref):
        """Adapt ``ref`` to the statistics of ``lr``.

        Args:
            lr (Tensor): LR feature map, shape (b, c, h, w).
            ref (Tensor): reference feature map, same shape as ``lr``.

        Returns:
            Tensor: modulated reference features, shape (b, c, h, w).
        """
        ref_normed = self.norm_layer(ref)

        if self.learnable:
            style = self.conv_shared(torch.cat([lr, ref], dim=1))
            gamma = self.conv_gamma(style)
            beta = self.conv_beta(style)

        # Per-channel mean / (unbiased) std of the LR features, shaped
        # (b, c, 1, 1) for broadcasting. `reshape` (not `view`) also
        # accepts non-contiguous inputs.
        b, c, h, w = lr.size()
        lr_flat = lr.reshape(b, c, h * w)
        lr_mean = torch.mean(lr_flat, dim=-1, keepdim=True).unsqueeze(3)
        lr_std = torch.std(lr_flat, dim=-1, keepdim=True).unsqueeze(3)

        if self.learnable:
            if self.use_residual:
                # Predicted offsets on top of the LR statistics.
                gamma = gamma + lr_std
                beta = beta + lr_mean
            else:
                # Multiplicative identity plus a learned perturbation;
                # beta stays as predicted (matches the reference code).
                gamma = 1 + gamma
        else:
            # Non-learnable fallback: classic AdaIN statistic transfer.
            gamma = lr_std
            beta = lr_mean

        return ref_normed * gamma + beta