diff --git a/eeyore/chains/chain_lists.py b/eeyore/chains/chain_lists.py index b06e57f..4c3a965 100755 --- a/eeyore/chains/chain_lists.py +++ b/eeyore/chains/chain_lists.py @@ -150,6 +150,6 @@ def summary( g=g_multi_ess_summary, mc_cov_mat=mc_cov_mat, method=method, adjust=adjust ) elif key == 'multi_rhat': - summaries[key], _, _ = self.multi_rhat(mc_cov_mat=mc_cov_mat, method=method, adjust=adjust) + summaries[key], _, _, _, _ = self.multi_rhat(mc_cov_mat=mc_cov_mat, method=method, adjust=adjust) return summaries diff --git a/eeyore/linalg/__init__.py b/eeyore/linalg/__init__.py index 0e074e1..d745be6 100755 --- a/eeyore/linalg/__init__.py +++ b/eeyore/linalg/__init__.py @@ -1 +1,2 @@ from .is_pos_def import is_pos_def +from .nearest_pd import nearest_pd diff --git a/eeyore/linalg/nearest_pd.py b/eeyore/linalg/nearest_pd.py new file mode 100644 index 0000000..74a84aa --- /dev/null +++ b/eeyore/linalg/nearest_pd.py @@ -0,0 +1,42 @@ +# Implementation taken from +# https://stackoverflow.com/questions/43238173/python-convert-matrix-to-positive-semi-definite/43244194#43244194 + +import numpy as np +import torch + +from .is_pos_def import is_pos_def + +def nearest_pd(A, f=np.spacing): + """Find the nearest positive-definite matrix to input + + A Python/Numpy port of John D'Errico's `nearestSPD` MATLAB code [1], which credits [2] + + [1] https://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd + + [2] https://doi.org/10.1016/0024-3795(88)90223-6 + """ + + B = (A + A.T) / 2 + _, s, V = torch.svd(B) + + # For a comparison with kanga, see the following: + # https://github.com/pytorch/pytorch/issues/16076#issuecomment-477755364 + H = torch.matmul(V, torch.matmul(torch.diag(s), V.T)) + + A2 = (B + H) / 2 + + A3 = (A2 + A2.T) / 2 + + if is_pos_def(A3): + return A3 + + spacing = f(torch.norm(A).item()) + I = torch.eye(A.shape[0]) + k = 1 + while not is_pos_def(A3): + eigenvals = torch.eig(A3, eigenvectors=False)[0][:, 0] + mineig = eigenvals.min().item() + A3 += I * (-mineig * k**2 + spacing) + k += 1 + + return A3 diff --git a/eeyore/stats/multi_rhat.py b/eeyore/stats/multi_rhat.py index b47b702..4de76bc 100755 --- a/eeyore/stats/multi_rhat.py +++ b/eeyore/stats/multi_rhat.py @@ -2,6 +2,8 @@ import torch +from eeyore.linalg import is_pos_def, nearest_pd + from .cov import cov from .mc_cov import mc_cov @@ -17,9 +19,21 @@ def multi_rhat(x, mc_cov_mat=None, method='inse', adjust=False): w = w + mc_cov_mat[i] w = w / num_chains + if not is_pos_def(w): + w = nearest_pd(w) + is_w_pd = False + else: + is_w_pd = True + b = cov(x.mean(1), rowvar=False) - rhat = max(torch.symeig(torch.matmul(torch.inverse(w), b))[0]).item() + if not is_pos_def(b): + b = nearest_pd(b) + is_b_pd = False + else: + is_b_pd = True + + rhat = torch.eig(torch.matmul(torch.inverse(w), b), eigenvectors=False)[0][:, 0].max().item() rhat = ((num_iters - 1) / num_iters) + ((num_chains + 1) / num_chains) * rhat - return rhat, w, b + return rhat, w, b, is_w_pd, is_b_pd