Skip to content

Commit

Permalink
Add sampler state methods for metropolis steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Sep 18, 2024
1 parent 9a6f775 commit c6c9686
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 2 deletions.
71 changes: 70 additions & 1 deletion pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
import numpy.random as nr
Expand Down Expand Up @@ -40,7 +41,7 @@
StatsType,
metrop_select,
)
from pymc.step_methods.compound import Competence
from pymc.step_methods.compound import Competence, StepMethodState

__all__ = [
"Metropolis",
Expand Down Expand Up @@ -111,6 +112,16 @@ def __call__(self, num_draws=None, rng: np.random.Generator | None = None):
return np.dot(self.chol, b)


@dataclass
class MetropolisState(StepMethodState):
tune: bool
steps_until_tune: float
tune_interval: float
accepted_sum: np.ndarray
accept_rate_iter: np.ndarray
accepted_iter: np.ndarray


class Metropolis(ArrayStepShared):
"""Metropolis-Hastings sampling step"""

Expand All @@ -124,6 +135,8 @@ class Metropolis(ArrayStepShared):
"scaling": (np.float64, []),
}

_step_method_state_class = MetropolisState

def __init__(
self,
vars=None,
Expand Down Expand Up @@ -342,6 +355,15 @@ def tune(scale, acc_rate):
)


@dataclass
class BinaryMetropolisState(StepMethodState):
tune: bool
accepted: int
scaling: float
tune_interval: int
steps_until_tune: int


class BinaryMetropolis(ArrayStep):
"""Metropolis-Hastings optimized for binary variables
Expand All @@ -368,6 +390,8 @@ class BinaryMetropolis(ArrayStep):
"p_jump": (np.float64, []),
}

_step_method_state_class = BinaryMetropolisState

def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
model = pm.modelcontext(model)

Expand Down Expand Up @@ -438,6 +462,14 @@ def competence(var):
return Competence.INCOMPATIBLE


@dataclass
class BinaryGibbsMetropolisState(StepMethodState):
tune: bool
transit_p: int
shuffle_dims: bool
order: list


class BinaryGibbsMetropolis(ArrayStep):
"""A Metropolis-within-Gibbs step method optimized for binary variables
Expand All @@ -462,6 +494,8 @@ class BinaryGibbsMetropolis(ArrayStep):
"tune": (bool, []),
}

_step_method_state_class = BinaryGibbsMetropolisState

def __init__(self, vars, order="random", transit_p=0.8, model=None):
model = pm.modelcontext(model)

Expand Down Expand Up @@ -545,6 +579,13 @@ def competence(var):
return Competence.INCOMPATIBLE


@dataclass
class CategoricalGibbsMetropolisState(StepMethodState):
shuffle_dims: bool
dimcats: list[tuple]
tune: bool


class CategoricalGibbsMetropolis(ArrayStep):
"""A Metropolis-within-Gibbs step method optimized for categorical variables.
Expand All @@ -561,6 +602,8 @@ class CategoricalGibbsMetropolis(ArrayStep):
"tune": (bool, []),
}

_step_method_state_class = CategoricalGibbsMetropolisState

def __init__(self, vars, proposal="uniform", order="random", model=None):
model = pm.modelcontext(model)

Expand Down Expand Up @@ -714,6 +757,16 @@ def competence(var):
return Competence.INCOMPATIBLE


@dataclass
class DEMetropolisState(StepMethodState):
scaling: np.ndarray
lamb: float
tune: str | None
tune_interval: int
steps_until_tune: int
accepted: int


class DEMetropolis(PopulationArrayStepShared):
"""
Differential Evolution Metropolis sampling step.
Expand Down Expand Up @@ -760,6 +813,8 @@ class DEMetropolis(PopulationArrayStepShared):
"lambda": (np.float64, []),
}

_step_method_state_class = DEMetropolisState

def __init__(
self,
vars=None,
Expand Down Expand Up @@ -854,6 +909,18 @@ def competence(var, has_grad):
return Competence.COMPATIBLE


@dataclass
class DEMetropolisZState(StepMethodState):
scaling: np.ndarray
lamb: float
tune: bool
tune_target: str | None
tune_interval: int
steps_until_tune: int
accepted: int
_history: list


class DEMetropolisZ(ArrayStepShared):
"""
Adaptive Differential Evolution Metropolis sampling step that uses the past to inform jumps.
Expand Down Expand Up @@ -903,6 +970,8 @@ class DEMetropolisZ(ArrayStepShared):
"lambda": (np.float64, []),
}

_step_method_state_class = DEMetropolisZState

def __init__(
self,
vars=None,
Expand Down
11 changes: 11 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,14 @@ def simple_normal(bounded_prior=False):
pm.Normal("X_obs", mu=mu_i, sigma=sigma, observed=x0)

return model.initial_point(), model, None


def simple_binary():
p1 = 0.5
p2 = 0.5

with pm.Model() as model:
pm.Bernoulli("d1", p=p1)
pm.Bernoulli("d2", p=p2)

return model.initial_point(), model, (p1, p2)
44 changes: 43 additions & 1 deletion tests/step_methods/test_metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import warnings

from copy import deepcopy

import arviz as az
import numpy as np
import numpy.testing as npt
Expand All @@ -24,6 +26,7 @@

from pymc.step_methods.metropolis import (
BinaryGibbsMetropolis,
BinaryMetropolis,
CategoricalGibbsMetropolis,
DEMetropolis,
DEMetropolisZ,
Expand All @@ -34,7 +37,13 @@
from pymc.testing import fast_unstable_sampling_mode
from tests import sampler_fixtures as sf
from tests.helpers import RVsAssignmentStepsTester, StepMethodTester
from tests.models import mv_simple, mv_simple_discrete, simple_categorical
from tests.models import (
mv_simple,
mv_simple_discrete,
simple_binary,
simple_categorical,
simple_model,
)


class TestMetropolisUniform(sf.MetropolisFixture, sf.UniformFixture):
Expand Down Expand Up @@ -364,3 +373,36 @@ def test_discrete_steps(self, step, step_kwargs):
)
def test_continuous_steps(self, step, step_kwargs):
self.continuous_steps(step, step_kwargs)


@pytest.mark.parametrize(
["step_method", "model_fn"],
[
[Metropolis, simple_model],
[BinaryMetropolis, simple_binary],
[BinaryGibbsMetropolis, simple_binary],
[CategoricalGibbsMetropolis, simple_categorical],
[DEMetropolis, simple_model],
[DEMetropolisZ, simple_model],
],
)
def test_sampling_state(step_method, model_fn):
with pytensor.config.change_flags(mode=fast_unstable_sampling_mode):
initial_point, model, _ = model_fn()
with model:
sampler = step_method(model.value_vars)
if hasattr(sampler, "link_population"):
sampler.link_population([initial_point] * 100, 0)
sampler_orig = deepcopy(sampler)

sampler.step(initial_point)
sampler.tune = False

state = sampler.sampling_state
state_orig = sampler_orig.sampling_state

sampler_orig.sampling_state = state

assert state != state_orig
assert state == sampler_orig.sampling_state
assert state is not sampler_orig.sampling_state

0 comments on commit c6c9686

Please sign in to comment.