From d65dcc67f27f98236f6acd1ac38995a290ef3686 Mon Sep 17 00:00:00 2001 From: Theodore Papamarkou Date: Mon, 14 Dec 2020 22:00:45 +0200 Subject: [PATCH] Fix in reset of proposal kernels --- eeyore/__init__.py | 2 +- eeyore/samplers/mala.py | 4 ++++ eeyore/samplers/metropolis_hastings.py | 4 ++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/eeyore/__init__.py b/eeyore/__init__.py index 4bfdfc6..8ecd910 100644 --- a/eeyore/__init__.py +++ b/eeyore/__init__.py @@ -1 +1 @@ -__version__ = '0.0.14' +__version__ = '0.0.15' diff --git a/eeyore/samplers/mala.py b/eeyore/samplers/mala.py index a85b2eb..81bd1d7 100644 --- a/eeyore/samplers/mala.py +++ b/eeyore/samplers/mala.py @@ -30,6 +30,10 @@ def set_current(self, theta, data=None): self.current['target_val'], self.current['grad_val'] = \ self.model.upto_grad_log_target(self.current['sample'].clone().detach(), x, y) + def reset(self, theta, data=None, reset_counter=True, reset_chain=True): + super().reset(theta, data=data, reset_counter=reset_counter, reset_chain=reset_chain) + self.set_kernel(self.current) + def kernel_mean(self, state): return state['sample'] + 0.5 * self.step * state['grad_val'] diff --git a/eeyore/samplers/metropolis_hastings.py b/eeyore/samplers/metropolis_hastings.py index faa6bcd..a8e0865 100644 --- a/eeyore/samplers/metropolis_hastings.py +++ b/eeyore/samplers/metropolis_hastings.py @@ -41,6 +41,10 @@ def set_current(self, theta, data=None): x, y = super().set_current(theta, data=data) self.current['target_val'] = self.model.log_target(self.current['sample'].clone().detach(), x, y) + def reset(self, theta, data=None, reset_counter=True, reset_chain=True): + super().reset(theta, data=data, reset_counter=reset_counter, reset_chain=reset_chain) + self.set_kernel(self.current) + def set_kernel(self, state, scale=None, scale_tril=None): self.kernel.set_density_params(state['sample'].clone().detach())