-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbbvi.jl
99 lines (82 loc) · 2.53 KB
/
bbvi.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
using Distributions
using Flux
using Plots
using LinearAlgebra
# Script-wide settings. Declared `const` because `gaussian_entropy` and
# `variational_objective` read them from inside function bodies; a non-const
# global would be `Any`-typed there and defeat type inference.
const D = 2             # dimensions of approximate posterior
const num_samples = 100 # Monte Carlo draws per objective evaluation
"""
A 2D non-gaussian log-density.
"""
function log_density(params)
mu, log_sigma = params
d1 = Normal(0, 1.35)
d2 = Normal(0, exp(log_sigma))
d1_density = logpdf(d1, log_sigma)
d2_density = logpdf(d2, mu)
return d1_density + d2_density
end
"""
Entropy of the Gaussian distribution.
"""
function gaussian_entropy(log_std)
H = 0.5 * D * (1.0 + log(2 * pi)) + sum(log_std)
return H
end
"""
Variational approximation to the non-gaussian density
"""
function variational_objective(parameters; D=2)
mu, log_std = parameters
samples = rand(Normal(), num_samples, D) .* exp.(log_std) .+ mu
log_px = mapslices(log_density, samples; dims=2) # eval log(target) for all samples of params (i.e. cols)
elbo = gaussian_entropy(log_std) + mean(log_px)
return -elbo
end
# Trainable variational parameters, stored as 1×2 rows so they broadcast across
# the (num_samples, D) sample matrix drawn in `variational_objective`.
# Float literals are used so the tracked parameters are floating-point
# regardless of how `param` treats integer input.
# NOTE: `sigma` holds the LOG standard deviation; it is exponentiated wherever
# it is used.
mu = Flux.Tracker.param([-1.0 -1.0])
sigma = Flux.Tracker.param([-5.0 -5.0])
parameters = Flux.Tracker.Params([mu, sigma])
steps = 200
# Objective value (negative ELBO) recorded once per training step by the
# training loop, which fills every entry 1:steps.
elbo = Array{Float32}(undef, steps)
### Plotting
# Evaluation grid for the contour plots. Row i of X/Y corresponds to y[i] and
# column j to x[j], i.e. matrices of size (length(y), length(x)).
x = -2:0.1:2
y = -4:0.1:2
X = [xv for yv in y, xv in x]
Y = [yv for yv in y, xv in x]

# Target density exp(log_density) evaluated at every grid point.
Z = map((a, b) -> exp(log_density([a, b])), X, Y)

# Initial variational approximation q = N(mu, Diagonal(exp.(2 .* log_std)))
# and its density on the same grid.
q = MultivariateNormal(mu.data[1, :], Diagonal(exp.(2 * sigma.data[1, :])))
Z_q = map((a, b) -> pdf(q, [a, b]), X, Y)
# Snapshot of the variational density q on the plotting grid after each step,
# indexed as [step, row, col] to match the layout of X / Y.
Z_q_trajectory = Array{Float32}(undef, steps, size(X)[1], size(X)[2])
opt = ADAM(0.1)
for step in 1:steps
    println(step)
    elbo_gradient = Flux.Tracker.gradient(() -> variational_objective(parameters), parameters)
    for p in (mu, sigma)
        # NOTE(review): this pair of calls (optimiser `update!` returning a step Δ,
        # then `Tracker.update!(p, -Δ)` applying it) matches the Flux 0.7-era
        # optimiser API. Confirm against the pinned Flux version; in later versions
        # `Flux.Optimise.update!` applies the step itself.
        Δ = Flux.Optimise.update!(opt, p, Flux.data(elbo_gradient[p]))
        Flux.Tracker.update!(p, -Δ)
    end
    # Record diagnostics once per step, after BOTH parameters have been updated.
    # Previously this ran inside the per-parameter loop, so it was computed twice
    # and the first result was discarded.
    elbo[step] = variational_objective([mu.data, sigma.data])
    q = MultivariateNormal(mu.data[1, :], Diagonal(exp.(2 * sigma.data[1, :])))
    for j in 1:size(X)[2], i in 1:size(X)[1]
        Z_q_trajectory[step, i, j] = pdf(q, [X[i, j], Y[i, j]])
    end
end
# Plotting
# One frame per training step: fixed target-density contours with the
# variational density recorded at that step drawn on top.
anim = @animate for frame in 1:steps
    plot(contour(x, y, Z))
    contour!(x, y, Z_q_trajectory[frame, :, :])
end
gif(anim, "mygif.gif"; fps=1)