Skip to content

Commit

Permalink
Add slice sampling state
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Sep 25, 2024
1 parent e0df681 commit 6810c41
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pymc.model import modelcontext
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
from pymc.step_methods.compound import Competence
from pymc.step_methods.compound import Competence, StepMethodState
from pymc.util import get_value_vars_from_user_vars
from pymc.vartypes import continuous_types

Expand All @@ -30,6 +30,13 @@
LOOP_ERR_MSG = "max slicer iters %d exceeded"


class SliceState(StepMethodState):
w: np.ndarray
tune: bool
n_tunes: float
iter_limit: float


class Slice(ArrayStepShared):
"""
Univariate slice sampler step method.
Expand Down Expand Up @@ -57,6 +64,8 @@ class Slice(ArrayStepShared):
"nstep_in": (int, []),
}

_state_class = SliceState

def __init__(
self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
):
Expand Down

0 comments on commit 6810c41

Please sign in to comment.