helpers.py
import torch
from torch import nn
from collections import defaultdict

def draw_gaussian_diag_samples(mu, logsigma):
    # Reparameterized sample from a diagonal Gaussian: z = mu + sigma * eps, eps ~ N(0, I).
    eps = torch.empty_like(mu).normal_(0., 1.)
    return torch.exp(logsigma) * eps + mu
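
# Usage sketch (illustrative shapes, not from the original file): the sampling above is the
# reparameterization trick, so gradients flow back into mu and logsigma.
#
#     mu = torch.zeros(4, 16, requires_grad=True)
#     logsigma = torch.zeros(4, 16, requires_grad=True)   # sigma = exp(0) = 1
#     z = draw_gaussian_diag_samples(mu, logsigma)         # z ~ N(mu, diag(sigma**2))
#     z.sum().backward()                                   # gradients reach mu and logsigma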

def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2):
    # Closed-form KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ) for diagonal Gaussians, element-wise:
    # log(sigma2 / sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2.
    return -0.5 + logsigma2 - logsigma1 + 0.5 * (logsigma1.exp() ** 2 + (mu1 - mu2) ** 2) / (logsigma2.exp() ** 2)
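
# Sanity-check sketch (an assumption, not from the original file): the closed form above should
# match PyTorch's analytic KL for independent Normals, element-wise.
#
#     import torch.distributions as td
#     mu1, mu2 = torch.randn(3, 5), torch.randn(3, 5)
#     ls1, ls2 = torch.randn(3, 5), torch.randn(3, 5)
#     ref = td.kl_divergence(td.Normal(mu1, ls1.exp()), td.Normal(mu2, ls2.exp()))
#     assert torch.allclose(gaussian_analytical_kl(mu1, mu2, ls1, ls2), ref, atol=1e-4)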

def get_conv(in_dim, out_dim, kernel_size, stride, padding, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    # Conv2d constructor with optional zero-initialization of the bias (default) and/or weights.
    # Note: `scaled` is accepted for signature compatibility but is unused here.
    c = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, groups=groups)
    if zero_bias:
        c.bias.data *= 0.0
    if zero_weights:
        c.weight.data *= 0.0
    return c

def get_3x3(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    # 3x3 convolution, stride 1, padding 1 (preserves spatial size).
    return get_conv(in_dim, out_dim, 3, 1, 1, zero_bias, zero_weights, groups=groups, scaled=scaled)

def get_1x1(in_dim, out_dim, zero_bias=True, zero_weights=False, groups=1, scaled=False):
    # 1x1 convolution, stride 1, no padding (channel mixing only).
    return get_conv(in_dim, out_dim, 1, 1, 0, zero_bias, zero_weights, groups=groups, scaled=scaled)
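
# Usage sketch (illustrative, not part of the original API): these constructors are typically
# composed into a bottleneck-style block.
#
#     block = nn.Sequential(
#         get_1x1(256, 64),
#         get_3x3(64, 64),
#         get_1x1(64, 256, zero_weights=True),   # zero-init output conv, useful for residual branches
#     )
#     y = block(torch.randn(1, 256, 32, 32))     # spatial shape preserved: (1, 256, 32, 32)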

def const_max(t, constant):
    # Element-wise maximum of a tensor and a scalar constant.
    other = torch.ones_like(t) * constant
    return torch.max(t, other)

def const_min(t, constant):
    # Element-wise minimum of a tensor and a scalar constant.
    other = torch.ones_like(t) * constant
    return torch.min(t, other)

def log_prob_from_logits(x):
    """Numerically stable log-softmax over the last axis; subtracting the max prevents overflow."""
    axis = len(x.shape) - 1
    m = x.max(dim=axis, keepdim=True)[0]
    return x - m - torch.log(torch.exp(x - m).sum(dim=axis, keepdim=True))
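
# Sanity-check sketch (an assumption, not from the original file): this should agree with
# PyTorch's built-in log-softmax on the last dimension.
#
#     x = torch.randn(2, 10)
#     assert torch.allclose(log_prob_from_logits(x), torch.log_softmax(x, dim=-1), atol=1e-6)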

def pad_channels(t, width):
    # Zero-pad an NCHW tensor along the channel dimension up to `width` channels.
    d1, d2, d3, d4 = t.shape
    empty = torch.zeros(d1, width, d3, d4, device=t.device)
    empty[:, :d2, :, :] = t
    return empty
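
# Example (illustrative shapes):
#
#     t = torch.randn(2, 3, 8, 8)
#     padded = pad_channels(t, 16)   # shape (2, 16, 8, 8); channels 3..15 are zeros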

def get_width_settings(width, s):
    # Parse a per-key width override string of the form "k1:v1,k2:v2" into a defaultdict
    # mapping int(k) -> int(v); any key not listed falls back to the default `width`.
    mapping = defaultdict(lambda: width)
    if s:
        s = s.split(',')
        for ss in s:
            k, v = ss.split(':')
            mapping[int(k)] = int(v)
    return mapping
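
# Example (hypothetical override string):
#
#     widths = get_width_settings(384, "16:512,32:256")
#     widths[16]   # 512 (overridden)
#     widths[64]   # 384 (falls back to the default width)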

def parse_layer_string(s):
    # Parse a comma-separated block spec into (resolution, special) pairs:
    #   "RxN" -> N entries (R, None), e.g. "32x5" yields five (32, None) blocks
    #   "RdD" -> one entry (R, D) with down-rate D, e.g. "32d2" -> (32, 2)
    #   "RmM" -> one entry (R, M) with mixin M, e.g. "8m4" -> (8, 4)
    #   "R"   -> one entry (R, None)
    layers = []
    for ss in s.split(','):
        if 'x' in ss:
            res, num = ss.split('x')
            count = int(num)
            layers += [(int(res), None) for _ in range(count)]
        elif 'm' in ss:
            res, mixin = [int(a) for a in ss.split('m')]
            layers.append((res, mixin))
        elif 'd' in ss:
            res, down_rate = [int(a) for a in ss.split('d')]
            layers.append((res, down_rate))
        else:
            res = int(ss)
            layers.append((res, None))
    return layers
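
# Example (hypothetical spec string):
#
#     parse_layer_string("32x2,32d2,16m32")
#     # -> [(32, None), (32, None), (32, 2), (16, 32)]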

def linear_warmup(warmup_iters):
    # Returns f(iteration): a multiplier that ramps linearly from 0 to 1 over `warmup_iters`
    # iterations and stays at 1.0 afterwards.
    def f(iteration):
        return 1.0 if iteration > warmup_iters else iteration / warmup_iters
    return f
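
# Usage sketch (an assumed pairing, not from the original file): the returned function is a
# drop-in lr_lambda for torch.optim.lr_scheduler.LambdaLR.
#
#     model = nn.Linear(8, 8)
#     opt = torch.optim.Adam(model.parameters(), lr=3e-4)
#     sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=linear_warmup(100))
#     for step in range(200):
#         opt.step()
#         sched.step()   # lr is roughly 3e-4 * min(step / 100, 1.0)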