Commit fa6558f ("test CI: add static")
Parent: 0638785

2 files changed (+245, -0)

blackjax/__init__.py (+2)

@@ -12,6 +12,7 @@
 from .base import SamplingAlgorithm, VIAlgorithm
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
+from .mcmc import adjusted_mclmc as _adjusted_mclmc
 from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
 from .mcmc import barker
 from .mcmc import dynamic_hmc as _dynamic_hmc
@@ -114,6 +115,7 @@ def generate_top_level_api_from(module):
 
 mclmc = generate_top_level_api_from(_mclmc)
 adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
+adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
 elliptical_slice = generate_top_level_api_from(_elliptical_slice)
 ghmc = generate_top_level_api_from(_ghmc)
 barker_proposal = generate_top_level_api_from(barker)
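
Note: judging from how the other samplers in this file are registered, generate_top_level_api_from wraps a module's as_top_level_api, so after this change blackjax.adjusted_mclmc(...) should forward to blackjax.mcmc.adjusted_mclmc.as_top_level_api(...) from the new file below. A quick sanity check, assuming an environment with this commit installed:

import blackjax

# The new symbol should be exposed at the top level, like blackjax.mclmc
# and the other generated APIs.
assert hasattr(blackjax, "adjusted_mclmc")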

blackjax/mcmc/adjusted_mclmc.py (new file, +243)
@@ -0,0 +1,243 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the Metropolis-Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) kernel.

This kernel is closely related to the Microcanonical Langevin Monte Carlo (MCLMC)
kernel, which is an unadjusted method: it adds a Metropolis-Hastings correction to
the MCLMC kernel. It also only refreshes the momentum variable after each MH step,
rather than during the integration of the trajectory, hence "Hamiltonian" rather
than "Langevin".

NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of
this module, which is primarily intended for use in parallelized versions of the
algorithm.
"""
from typing import Callable, Union

import jax
import jax.numpy as jnp

import blackjax.mcmc.integrators as integrators
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.hmc import HMCInfo, HMCState
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_unit_vector

__all__ = ["init", "build_kernel", "as_top_level_api"]


def init(position: ArrayLikeTree, logdensity_fn: Callable):
    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
    return HMCState(position, logdensity, logdensity_grad)


def build_kernel(
    logdensity_fn: Callable,
    integrator: Callable = integrators.isokinetic_mclachlan,
    divergence_threshold: float = 1000,
    inverse_mass_matrix=1.0,
):
    """Build an MHMCHMC kernel with a fixed number of integration steps.

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    integrator
        The integrator to use to integrate the Hamiltonian dynamics.
    divergence_threshold
        Value of the difference in energy above which we consider that the
        transition is divergent.
    inverse_mass_matrix
        The inverse mass matrix used to precondition the dynamics.

    Returns
    -------
    A kernel that takes a rng_key and a Pytree that contains the current state
    of the chain and that returns a new state of the chain along with
    information about the transition.
    """

    def kernel(
        rng_key: PRNGKey,
        state: HMCState,
        step_size: float,
        num_integration_steps: int,
        L_proposal_factor: float = jnp.inf,
    ) -> tuple[HMCState, HMCInfo]:
        """Generate a new sample with the MHMCHMC kernel."""

        key_momentum, key_integrator = jax.random.split(rng_key, 2)
        momentum = generate_unit_vector(key_momentum, state.position)
        proposal, info, _ = adjusted_mclmc_proposal(
            integrator=integrators.with_isokinetic_maruyama(
                integrator(
                    logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix
                )
            ),
            step_size=step_size,
            L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size),
            num_integration_steps=num_integration_steps,
            divergence_threshold=divergence_threshold,
        )(
            key_integrator,
            integrators.IntegratorState(
                state.position, momentum, state.logdensity, state.logdensity_grad
            ),
        )

        return (
            HMCState(
                proposal.position,
                proposal.logdensity,
                proposal.logdensity_grad,
            ),
            info,
        )

    return kernel

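The momentum refresh in the kernel above uses generate_unit_vector, since the isokinetic dynamics keep the momentum on the unit sphere. A self-contained sketch of that standard construction for a flat array (the real helper in blackjax.util also handles pytrees; unit_vector_like is a hypothetical name used only here):

import jax
import jax.numpy as jnp

def unit_vector_like(rng_key, x):
    # Normalizing an isotropic Gaussian draw gives a direction distributed
    # uniformly on the unit sphere.
    z = jax.random.normal(rng_key, jnp.shape(x))
    return z / jnp.linalg.norm(z)
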
def as_top_level_api(
    logdensity_fn: Callable,
    step_size: float,
    L_proposal_factor: float = jnp.inf,
    inverse_mass_matrix=1.0,
    *,
    divergence_threshold: int = 1000,
    integrator: Callable = integrators.isokinetic_mclachlan,
    num_integration_steps,
) -> SamplingAlgorithm:
    """Implements the (basic) user interface for the MHMCHMC kernel.

    Parameters
    ----------
    logdensity_fn
        The log-density function we wish to draw samples from.
    step_size
        The value to use for the step size in the symplectic integrator.
    L_proposal_factor
        Factor multiplying the trajectory length ``num_integration_steps *
        step_size`` to set the scale of the partial momentum refreshment inside
        the trajectory; the default ``jnp.inf`` disables it.
    inverse_mass_matrix
        The inverse mass matrix used to precondition the dynamics.
    divergence_threshold
        The absolute value of the difference in energy between two states above
        which we say that the transition is divergent. The default value is
        commonly found in other libraries, and yet is arbitrary.
    integrator
        (algorithm parameter) The symplectic integrator to use to integrate the
        trajectory.
    num_integration_steps
        The number of integration steps to run at each call of the kernel.

    Returns
    -------
    A ``SamplingAlgorithm``.
    """

    kernel = build_kernel(
        logdensity_fn=logdensity_fn,
        integrator=integrator,
        inverse_mass_matrix=inverse_mass_matrix,
        divergence_threshold=divergence_threshold,
    )

    def init_fn(position: ArrayLikeTree, rng_key=None):
        del rng_key
        return init(position, logdensity_fn)

    def update_fn(rng_key: PRNGKey, state):
        return kernel(
            rng_key=rng_key,
            state=state,
            step_size=step_size,
            num_integration_steps=num_integration_steps,
            L_proposal_factor=L_proposal_factor,
        )

    return SamplingAlgorithm(init_fn, update_fn)  # type: ignore[arg-type]

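A minimal end-to-end sketch of the interface above; the standard-Gaussian target and the hyperparameter values are illustrative assumptions, not taken from the diff:

import jax
import jax.numpy as jnp
import blackjax

# Illustrative target: a standard Gaussian in 5 dimensions.
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

algorithm = blackjax.adjusted_mclmc(
    logdensity_fn,
    step_size=0.5,             # arbitrary; tune in practice
    num_integration_steps=10,  # arbitrary; tune in practice
)

rng_key = jax.random.PRNGKey(0)
state = algorithm.init(jnp.ones(5))
for _ in range(1_000):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = algorithm.step(step_key, state)
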
def adjusted_mclmc_proposal(
    integrator: Callable,
    step_size: Union[float, ArrayLikeTree],
    L_proposal_factor: float,
    num_integration_steps: int = 1,
    divergence_threshold: float = 1000,
    *,
    sample_proposal: Callable = static_binomial_sampling,
) -> Callable:
    """Vanilla MHMCHMC algorithm.

    The algorithm integrates the trajectory by applying the integrator
    `num_integration_steps` times in one direction to get a proposal and uses a
    Metropolis-Hastings acceptance step to either reject or accept this
    proposal. This is the structure people usually refer to when they talk
    about "the HMC algorithm".

    Parameters
    ----------
    integrator
        Integrator used to build the trajectory step by step.
    step_size
        Size of the integration step.
    L_proposal_factor
        Partial momentum refreshment scale passed to the integrator at each step.
    num_integration_steps
        Number of times we run the integrator to build the trajectory.
    divergence_threshold
        Threshold above which we say that there is a divergence.
    sample_proposal
        Function that performs the accept/reject step given the energy difference.

    Returns
    -------
    A kernel that generates a new chain state and information about the transition.
    """

    def step(i, vars):
        state, kinetic_energy, rng_key = vars
        rng_key, next_rng_key = jax.random.split(rng_key)
        next_state, next_kinetic_energy = integrator(
            state, step_size, L_proposal_factor, rng_key
        )

        return next_state, kinetic_energy + next_kinetic_energy, next_rng_key

    def build_trajectory(state, num_integration_steps, rng_key):
        # `0 * num_integration_steps` keeps the lower bound's dtype consistent
        # with the (possibly traced) upper bound.
        return jax.lax.fori_loop(
            0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key)
        )

    def generate(
        rng_key, state: integrators.IntegratorState
    ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
        """Generate a new chain state."""
        end_state, kinetic_energy, rng_key = build_trajectory(
            state, num_integration_steps, rng_key
        )

        new_energy = -end_state.logdensity
        delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy
        delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy)
        is_diverging = -delta_energy > divergence_threshold
        sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
        do_accept, p_accept, other_proposal_info = info

        info = HMCInfo(
            state.momentum,
            p_accept,
            do_accept,
            is_diverging,
            new_energy,
            end_state,
            num_integration_steps,
        )

        return sampled_state, info, other_proposal_info

    return generate
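For reference, the accept/reject decision delegated to sample_proposal (static_binomial_sampling by default) follows the standard Metropolis rule on the energy error computed in generate. A simplified stand-in for that rule, ignoring blackjax's extra proposal bookkeeping (mh_accept is a hypothetical name used only here):

import jax
import jax.numpy as jnp

def mh_accept(rng_key, delta_energy):
    # Accept with probability min(1, exp(delta_energy)); delta_energy was set
    # to -inf for NaN energies above, so divergent proposals are always rejected.
    p_accept = jnp.minimum(1.0, jnp.exp(delta_energy))
    do_accept = jax.random.bernoulli(rng_key, p_accept)
    return do_accept, p_accept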