Skip to content

Commit f65a4b2

Browse files
committed
test CI: old tests with addition of num tuning steps
1 parent 85fe088 commit f65a4b2

File tree

3 files changed

+56
-23
lines changed

3 files changed

+56
-23
lines changed

blackjax/adaptation/adjusted_mclmc_adaptation.py

+48-15
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,15 @@ def adjusted_mclmc_find_L_and_step_size(
7979

8080
part1_key, part2_key = jax.random.split(rng_key, 2)
8181

82+
total_num_tuning_integrator_steps = 0
8283
for i in range(num_windows):
8384
window_key = jax.random.fold_in(part1_key, i)
84-
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
85+
(
86+
state,
87+
params,
88+
eigenvector,
89+
num_tuning_integrator_steps,
90+
) = adjusted_mclmc_make_L_step_size_adaptation(
8591
kernel=mclmc_kernel,
8692
dim=dim,
8793
frac_tune1=frac_tune1,
@@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size(
9096
diagonal_preconditioning=diagonal_preconditioning,
9197
max=max,
9298
tuning_factor=tuning_factor,
93-
)(state, params, num_steps, window_key)
99+
)(
100+
state, params, num_steps, window_key
101+
)
102+
total_num_tuning_integrator_steps += num_tuning_integrator_steps
94103

95104
if frac_tune3 != 0:
96105
for i in range(num_windows):
97106
part2_key = jax.random.fold_in(part2_key, i)
98107
part2_key1, part2_key2 = jax.random.split(part2_key, 2)
99108

100-
state, params = adjusted_mclmc_make_adaptation_L(
109+
(
110+
state,
111+
params,
112+
num_tuning_integrator_steps,
113+
) = adjusted_mclmc_make_adaptation_L(
101114
mclmc_kernel,
102115
frac=frac_tune3,
103116
Lfactor=0.5,
104117
max=max,
105118
eigenvector=eigenvector,
106-
)(state, params, num_steps, part2_key1)
119+
)(
120+
state, params, num_steps, part2_key1
121+
)
107122

108-
(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
123+
total_num_tuning_integrator_steps += num_tuning_integrator_steps
124+
125+
(
126+
state,
127+
params,
128+
_,
129+
num_tuning_integrator_steps,
130+
) = adjusted_mclmc_make_L_step_size_adaptation(
109131
kernel=mclmc_kernel,
110132
dim=dim,
111133
frac_tune1=frac_tune1,
@@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size(
115137
diagonal_preconditioning=diagonal_preconditioning,
116138
max=max,
117139
tuning_factor=tuning_factor,
118-
)(state, params, num_steps, part2_key2)
140+
)(
141+
state, params, num_steps, part2_key2
142+
)
119143

120-
return state, params
144+
total_num_tuning_integrator_steps += num_tuning_integrator_steps
145+
146+
return state, params, total_num_tuning_integrator_steps
121147

122148

123149
def adjusted_mclmc_make_L_step_size_adaptation(
@@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
256282
update_da=update_da,
257283
)
258284

285+
num_tuning_integrator_steps = info.num_integration_steps.sum()
259286
final_stepsize = final_da(dual_avg_state)
260287
params = params._replace(step_size=final_stepsize)
261288

@@ -299,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
299326
initial_da=initial_da,
300327
)
301328

329+
num_tuning_integrator_steps += info.num_integration_steps.sum()
330+
302331
params = params._replace(step_size=final_da(dual_avg_state))
303332

304-
return state, params, eigenvector
333+
return state, params, eigenvector, num_tuning_integrator_steps
305334

306335
return L_step_size_adaptation
307336

@@ -316,16 +345,16 @@ def adaptation_L(state, params, num_steps, key):
316345
adaptation_L_keys = jax.random.split(key, num_steps)
317346

318347
def step(state, key):
319-
next_state, _ = kernel(
348+
next_state, info = kernel(
320349
rng_key=key,
321350
state=state,
322351
step_size=params.step_size,
323352
avg_num_integration_steps=params.L / params.step_size,
324353
inverse_mass_matrix=params.inverse_mass_matrix,
325354
)
326-
return next_state, next_state.position
355+
return next_state, (next_state.position, info)
327356

328-
state, samples = jax.lax.scan(
357+
state, (samples, info) = jax.lax.scan(
329358
f=step,
330359
init=state,
331360
xs=adaptation_L_keys,
@@ -346,10 +375,14 @@ def step(state, key):
346375
# number of effective samples per 1 actual sample
347376
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps
348377

349-
return state, params._replace(
350-
L=jnp.clip(
351-
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
352-
)
378+
return (
379+
state,
380+
params._replace(
381+
L=jnp.clip(
382+
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
383+
)
384+
),
385+
info.num_integration_steps.sum(),
353386
)
354387

355388
return adaptation_L

blackjax/adaptation/mclmc_adaptation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def mclmc_find_L_and_step_size(
126126
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
127127
)(state, params, num_steps, part2_key)
128128

129-
return state, params
129+
return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)
130130

131131

132132
def make_L_step_size_adaptation(
@@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor):
274274
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""
275275

276276
def adaptation_L(state, params, num_steps, key):
277-
num_steps = int(num_steps * frac)
278-
adaptation_L_keys = jax.random.split(key, num_steps)
277+
num_steps_3 = int(num_steps * frac)
278+
adaptation_L_keys = jax.random.split(key, num_steps_3)
279279

280280
def step(state, key):
281281
next_state, _ = kernel(
@@ -297,7 +297,7 @@ def step(state, key):
297297
ess = effective_sample_size(flat_samples[None, ...])
298298

299299
return state, params._replace(
300-
L=Lfactor * params.step_size * jnp.mean(num_steps / ess)
300+
L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess)
301301
)
302302

303303
return adaptation_L

tests/mcmc/test_sampling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def run_mclmc(
122122
(
123123
blackjax_state_after_tuning,
124124
blackjax_mclmc_sampler_params,
125+
_,
125126
) = blackjax.mclmc_find_L_and_step_size(
126127
mclmc_kernel=kernel,
127128
num_steps=num_steps,
@@ -183,6 +184,7 @@ def run_adjusted_mclmc(
183184
(
184185
blackjax_state_after_tuning,
185186
blackjax_mclmc_sampler_params,
187+
_,
186188
) = blackjax.adjusted_mclmc_find_L_and_step_size(
187189
mclmc_kernel=kernel,
188190
num_steps=num_steps,
@@ -252,6 +254,7 @@ def run_adjusted_mclmc_static(
252254
(
253255
blackjax_state_after_tuning,
254256
blackjax_mclmc_sampler_params,
257+
_,
255258
) = blackjax.adjusted_mclmc_find_L_and_step_size(
256259
mclmc_kernel=kernel,
257260
num_steps=num_steps,
@@ -509,10 +512,7 @@ def get_inverse_mass_matrix():
509512
inverse_mass_matrix=inverse_mass_matrix,
510513
)
511514

512-
(
513-
_,
514-
blackjax_mclmc_sampler_params,
515-
) = blackjax.mclmc_find_L_and_step_size(
515+
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
516516
mclmc_kernel=kernel,
517517
num_steps=num_steps,
518518
state=initial_state,

0 commit comments

Comments
 (0)