Skip to content

Commit

Permalink
Fix rand function for DegenerateMvNormal distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
Aidan Gleich committed Aug 9, 2021
1 parent 1776c9c commit eb7ebd6
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/distributions_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,27 @@ Distributions.rand(d::DegenerateMvNormal; cc::T = 1.0) where T<:AbstractFloat
Generate a draw from `d` with variance optionally scaled by `cc^2`.
"""
function Distributions.rand(d::DegenerateMvNormal; cc::T = 1.0) where T<:AbstractFloat
return d.μ + cc*d.σ*randn(length(d))
# abusing notation slightly, if Y is a degen MV normal r.v. with covariance matrix Σ,
# and Σ = U Λ^2 Vt according to the svd, then given an standard MV normal r.v X with
# the same dimension as Y, Y = μ + UΛX.

# we need to ensure symmetry when computing SVD
U, λ_vals, Vt = svd((d.σ + d.σ')./2)

# set near-zero values to zero
λ_vals[λ_vals .< 10^(-6)] .= 0

# leave x as 0 where λ_vals equals 0 (b/c r.v. is fixed where λ_vals = 0)
λ_vals = abs.(λ_vals)
x = zeros(length(λ_vals))
for i in 1:length(λ_vals)
if λ_vals[i] == 0
x[i] = 0
else
x[i] = randn()
end
end
return d.μ + cc*U*diagm(sqrt.(λ_vals))*x
end

"""
Expand Down

0 comments on commit eb7ebd6

Please sign in to comment.