Skip to content

Commit

Permalink
Add slice sampling state
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 7, 2024
1 parent c417476 commit db32421
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from pymc.model import modelcontext
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
from pymc.step_methods.arraystep import ArrayStepShared
from pymc.step_methods.compound import Competence
from pymc.step_methods.compound import Competence, StepMethodState
from pymc.step_methods.state import dataclass_state
from pymc.util import get_value_vars_from_user_vars
from pymc.vartypes import continuous_types

Expand All @@ -30,6 +31,17 @@
LOOP_ERR_MSG = "max slicer iters %d exceeded"


dataclass_state


@dataclass_state
class SliceState(StepMethodState):
w: np.ndarray
tune: bool
n_tunes: float
iter_limit: float


class Slice(ArrayStepShared):
"""
Univariate slice sampler step method.
Expand Down Expand Up @@ -61,6 +73,8 @@ class Slice(ArrayStepShared):
"nstep_in": (int, []),
}

_state_class = SliceState

def __init__(
self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
):
Expand Down

0 comments on commit db32421

Please sign in to comment.