-
Notifications
You must be signed in to change notification settings - Fork 0
/
realnvp.py
129 lines (108 loc) · 5.38 KB
/
realnvp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.nn as nn
from flows import (
AffineCoupling, LogitTransform, Dequantize,
SqueezeFlow, SplitFlow, CompositeFlow,
)
from network import ResNet
class RealNVP(nn.Module):
    """Normalizing-flow density model following the RealNVP paper:
    https://arxiv.org/pdf/1605.08803.pdf

    The layer stack mirrors the architecture the authors used for the
    CIFAR-10 dataset. The paper omits the exact configuration; it is
    taken from the reference code at:
    https://github.com/tensorflow/models/tree/1345ec9bb110ec1173f7558d3d700a5a42ae2b2f/research/real_nvp
    """
    def __init__(self, in_shape, n_colors):
        """Build the RealNVP flow.

        Args:
            in_shape: tuple(int)
                Shape of a single input image in channels-first layout,
                i.e. in_shape = (C, H, W).
            n_colors: int
                Number of discrete color levels per channel. Used for
                de-quantizing and quantizing the pixel data.
        """
        super().__init__()
        C, H, W = in_shape

        def coupling(channels, mask):
            # Each affine coupling predicts scale and translation with a
            # single ResNet; doubling the output channels lets both
            # quantities come from one forward pass.
            net = ResNet(in_channels=channels, out_channels=2 * channels)
            return AffineCoupling(net, mask)

        def checker_mask(height, width):
            # Spatial checkerboard pattern, boolean, shape (1, height, width).
            rows = torch.arange(height).unsqueeze(1)
            cols = torch.arange(width)
            return ((rows + cols) % 2 != 0).reshape(1, height, width)

        def channel_mask(channels):
            # Alternating channel pattern, boolean, shape (channels, 1, 1).
            return (torch.arange(channels) % 2 != 0).reshape(channels, 1, 1)

        # Standard normal base distribution for the latent space.
        prior = torch.distributions.Normal(0., 1.)

        # Couplings come in pairs with complementary (~) masks so that the
        # units a layer leaves untouched are always transformed by the
        # following layer.
        steps = [
            # Map the discrete pixel values to an unbounded continuous space.
            Dequantize(n_colors),
            LogitTransform(alpha=0.1),
            coupling(C, checker_mask(H, W)),
            coupling(C, ~checker_mask(H, W)),
            coupling(C, checker_mask(H, W)),
            SqueezeFlow(),
            # After the squeeze the 2x2 spatial blocks live in the channel
            # dimension, so masking channelwise here may look like plain
            # row-by-row coupling. It differs, though, in how the networks
            # perceive their input: the convolving kernels still combine
            # the channels.
            coupling(4 * C, channel_mask(4 * C)),
            coupling(4 * C, ~channel_mask(4 * C)),
            coupling(4 * C, channel_mask(4 * C)),
            # The reference implementation unsqueezes here and applies a
            # "factor_out", which is essentially a squeeze with a different
            # spatial arrangement. See:
            # https://github.com/tensorflow/models/blob/36101ab4095065a4196ff4f6437e94f0d91df4e9/research/real_nvp/real_nvp_multiscale_dataset.py#L734
            # Discussion: https://github.com/phlippe/uvadlc_notebooks/issues/78
            SplitFlow(prior),
            coupling(2 * C, checker_mask(H // 2, W // 2)),
            coupling(2 * C, ~checker_mask(H // 2, W // 2)),
            coupling(2 * C, checker_mask(H // 2, W // 2)),
            coupling(2 * C, ~checker_mask(H // 2, W // 2)),
        ]
        self.flow = CompositeFlow(*steps)

        # Shape of the latent code emitted by the full stack. Sampling
        # starts from a tensor of this shape and inverts the flow.
        self.out_shape = (2 * C, H // 2, W // 2)
        self.in_shape = in_shape
        self.n_colors = n_colors
        self.prior = prior

    def log_prob(self, x):
        """Compute per-image log likelihoods.

        Args:
            x: torch.Tensor
                Batch of raw pixel images of shape (B, C, H, W).

        Returns:
            log_prob: torch.Tensor
                Tensor of shape (B,) with the log probability of each
                image in the batch.
        """
        x = x.to(self.device).contiguous().float()
        z, log_det = self.flow(x)
        # Change of variables: base log density of the latent code plus
        # the accumulated log|det J| of the flow.
        per_dim = self.prior.log_prob(z)
        return per_dim.sum(dim=(1, 2, 3)) + log_det

    @torch.no_grad()
    def sample(self, n=1):
        """Draw image samples from the model.

        Args:
            n: int, optional
                How many samples to generate. Default: 1.

        Returns:
            samples: torch.Tensor
                Int tensor of shape (n, C, H, W) holding the sampled images.
        """
        # Sample latent codes from the base distribution, then run the
        # flow backwards to obtain images.
        z = self.prior.sample(sample_shape=(n,) + self.out_shape)
        z = z.to(self.device).contiguous().float()
        imgs, _ = self.flow(z, invert=True)
        return imgs.cpu().int()

    @property
    def device(self):
        """torch.device: the device the model's parameters live on."""
        params = self.parameters()
        return next(params).device
#