From 6810c413b2f552cd905a7c7f760af3440950e377 Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Thu, 19 Sep 2024 09:55:22 +0200 Subject: [PATCH] Add slice sampling state --- pymc/step_methods/slicer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 2396850ade..08bfae99b7 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -21,7 +21,7 @@ from pymc.model import modelcontext from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements from pymc.step_methods.arraystep import ArrayStepShared -from pymc.step_methods.compound import Competence +from pymc.step_methods.compound import Competence, StepMethodState from pymc.util import get_value_vars_from_user_vars from pymc.vartypes import continuous_types @@ -30,6 +30,13 @@ LOOP_ERR_MSG = "max slicer iters %d exceeded" +class SliceState(StepMethodState): + w: np.ndarray + tune: bool + n_tunes: float + iter_limit: float + + class Slice(ArrayStepShared): """ Univariate slice sampler step method. @@ -57,6 +64,8 @@ class Slice(ArrayStepShared): "nstep_in": (int, []), } + _state_class = SliceState + def __init__( self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs ):