-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmodel.py
More file actions
90 lines (80 loc) · 2.58 KB
/
model.py
File metadata and controls
90 lines (80 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
class VGGEncoder(nn.Module):
    """Frozen feature extractor built from the first 12 layers of VGG-19.

    The pretrained ``vgg19(...).features`` stack is cut into three
    consecutive stages; ``forward`` pushes the input through all three and
    returns only the final stage's activations. All weights are frozen, so
    the encoder never receives gradient updates.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features
        # Three consecutive stages covering layers [0, 12) of the backbone.
        self.slice1 = backbone[:2]
        self.slice2 = backbone[2:7]
        self.slice3 = backbone[7:12]
        # The encoder is fixed; exclude every parameter from optimization.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, images):
        """Return the deepest stage's feature map for a batch of images."""
        h = self.slice1(images)
        h = self.slice2(h)
        return self.slice3(h)
#
# class CIR(nn.Module):
# def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1):
# super().__init__()
# self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad_size)
# self.instance_norm = nn.InstanceNorm2d(out_channels)
#
# def forward(self, x):
# h = self.conv(x)
# h = self.instance_norm(h)
# h = F.relu(h)
# return h
#
#
# class Decoder(nn.Module):
# def __init__(self):
# super().__init__()
# self.cir1 = CIR(256, 128, 3, 1)
# self.cir2 = CIR(128, 128, 3, 1)
# self.cir3 = CIR(128, 64, 3, 1)
# self.cir4 = CIR(64, 64, 3, 1)
# self.out_conv = nn.Conv2d(64, 3, 3, padding=1)
#
# def forward(self, features):
# h = self.cir1(features)
# h = F.interpolate(h, scale_factor=2)
# h = self.cir2(h)
# h = self.cir3(h)
# h = F.interpolate(h, scale_factor=2)
# h = self.cir4(h)
# h = self.out_conv(h)
# return h
class RC(nn.Module):
    """Reflection-padded convolution, optionally followed by a ReLU.

    Pads the input by ``pad_size`` on every side with reflection padding,
    applies a ``kernel_size`` convolution, and — when ``activated`` is
    true — a ReLU. With the defaults (3x3 kernel, pad 1) spatial size is
    preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pad_size=1, activated=True):
        super().__init__()
        # Same pad amount on left/right/top/bottom.
        self.pad = nn.ReflectionPad2d((pad_size,) * 4)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.activated = activated

    def forward(self, x):
        out = self.conv(self.pad(x))
        return F.relu(out) if self.activated else out
class Decoder(nn.Module):
    """Upsampling decoder that maps 256-channel features back to RGB.

    Five reflection-padded convolutions narrow 256 channels down to 3,
    with nearest-neighbor 2x upsampling after the first and third layers,
    so the output is 4x the input's spatial resolution. The last layer is
    not activated, leaving the RGB output unbounded.
    """

    def __init__(self):
        super().__init__()
        self.rc1 = RC(256, 128, 3, 1)
        self.rc2 = RC(128, 128, 3, 1)
        self.rc3 = RC(128, 64, 3, 1)
        self.rc4 = RC(64, 64, 3, 1)
        # Final projection to 3 channels, no ReLU.
        self.rc5 = RC(64, 3, 3, 1, False)

    def forward(self, features):
        # (layer, upsample-afterwards) schedule matching the original
        # rc1 -> x2 -> rc2 -> rc3 -> x2 -> rc4 -> rc5 pipeline.
        schedule = (
            (self.rc1, True),
            (self.rc2, False),
            (self.rc3, True),
            (self.rc4, False),
            (self.rc5, False),
        )
        h = features
        for layer, upsample in schedule:
            h = layer(h)
            if upsample:
                h = F.interpolate(h, scale_factor=2)
        return h