Skip to content

Commit

Permalink
Fix in positive definite matrix approximation
Browse files Browse the repository at this point in the history
  • Loading branch information
papamarkou committed Aug 27, 2020
1 parent 8bd0e07 commit ad8da0e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
2 changes: 1 addition & 1 deletion eeyore/chains/chain_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,6 @@ def summary(
g=g_multi_ess_summary, mc_cov_mat=mc_cov_mat, method=method, adjust=adjust
)
elif key == 'multi_rhat':
summaries[key], _, _ = self.multi_rhat(mc_cov_mat=mc_cov_mat, method=method, adjust=adjust)
summaries[key], _, _, _, _ = self.multi_rhat(mc_cov_mat=mc_cov_mat, method=method, adjust=adjust)

return summaries
1 change: 1 addition & 0 deletions eeyore/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .is_pos_def import is_pos_def
from .nearest_pd import nearest_pd
42 changes: 42 additions & 0 deletions eeyore/linalg/nearest_pd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Implementation taken from
# https://stackoverflow.com/questions/43238173/python-convert-matrix-to-positive-semi-definite/43244194#43244194

import numpy as np
import torch

from .is_pos_def import is_pos_def

def nearest_pd(A, f=np.spacing):
"""Find the nearest positive-definite matrix to input
A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1], which credits [2]
[1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
[2] https://doi.org/10.1016/0024-3795(88)90223-6
"""

B = (A + A.T) / 2
_, s, V = torch.svd(B)

# For a comparison with kanga, see the following:
# https://github.com/pytorch/pytorch/issues/16076#issuecomment-477755364
H = torch.matmul(V, torch.matmul(torch.diag(s), V.T))

A2 = (B + H) / 2

A3 = (A2 + A2.T) / 2

if is_pos_def(A3):
return A3

spacing = f(torch.norm(A).item())
I = torch.eye(A.shape[0])
k = 1
while not is_pos_def(A3):
eigenvals = torch.eig(A3, eigenvectors=False)[0][:, 0]
mineig = eigenvals.min().item()
A3 += I * (-mineig * k**2 + spacing)
k += 1

return A3
18 changes: 16 additions & 2 deletions eeyore/stats/multi_rhat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import torch

from eeyore.linalg import is_pos_def, nearest_pd

from .cov import cov
from .mc_cov import mc_cov

Expand All @@ -17,9 +19,21 @@ def multi_rhat(x, mc_cov_mat=None, method='inse', adjust=False):
w = w + mc_cov_mat[i]
w = w / num_chains

if not is_pos_def(w):
w = nearest_pd(w)
is_w_pd = False
else:
is_w_pd = True

b = cov(x.mean(1), rowvar=False)

rhat = max(torch.symeig(torch.matmul(torch.inverse(w), b))[0]).item()
if not is_pos_def(b):
b = nearest_pd(b)
is_b_pd = False
else:
is_b_pd = True

rhat = torch.eig(torch.matmul(torch.inverse(w), b), eigenvectors=False)[0][:, 0].max().item()
rhat = ((num_iters - 1) / num_iters) + ((num_chains + 1) / num_chains) * rhat

return rhat, w, b
return rhat, w, b, is_w_pd, is_b_pd

0 comments on commit ad8da0e

Please sign in to comment.