@@ -79,9 +79,15 @@ def adjusted_mclmc_find_L_and_step_size(
79
79
80
80
part1_key , part2_key = jax .random .split (rng_key , 2 )
81
81
82
+ total_num_tuning_integrator_steps = 0
82
83
for i in range (num_windows ):
83
84
window_key = jax .random .fold_in (part1_key , i )
84
- (state , params , eigenvector ) = adjusted_mclmc_make_L_step_size_adaptation (
85
+ (
86
+ state ,
87
+ params ,
88
+ eigenvector ,
89
+ num_tuning_integrator_steps ,
90
+ ) = adjusted_mclmc_make_L_step_size_adaptation (
85
91
kernel = mclmc_kernel ,
86
92
dim = dim ,
87
93
frac_tune1 = frac_tune1 ,
@@ -90,22 +96,38 @@ def adjusted_mclmc_find_L_and_step_size(
90
96
diagonal_preconditioning = diagonal_preconditioning ,
91
97
max = max ,
92
98
tuning_factor = tuning_factor ,
93
- )(state , params , num_steps , window_key )
99
+ )(
100
+ state , params , num_steps , window_key
101
+ )
102
+ total_num_tuning_integrator_steps += num_tuning_integrator_steps
94
103
95
104
if frac_tune3 != 0 :
96
105
for i in range (num_windows ):
97
106
part2_key = jax .random .fold_in (part2_key , i )
98
107
part2_key1 , part2_key2 = jax .random .split (part2_key , 2 )
99
108
100
- state , params = adjusted_mclmc_make_adaptation_L (
109
+ (
110
+ state ,
111
+ params ,
112
+ num_tuning_integrator_steps ,
113
+ ) = adjusted_mclmc_make_adaptation_L (
101
114
mclmc_kernel ,
102
115
frac = frac_tune3 ,
103
116
Lfactor = 0.5 ,
104
117
max = max ,
105
118
eigenvector = eigenvector ,
106
- )(state , params , num_steps , part2_key1 )
119
+ )(
120
+ state , params , num_steps , part2_key1
121
+ )
107
122
108
- (state , params , _ ) = adjusted_mclmc_make_L_step_size_adaptation (
123
+ total_num_tuning_integrator_steps += num_tuning_integrator_steps
124
+
125
+ (
126
+ state ,
127
+ params ,
128
+ _ ,
129
+ num_tuning_integrator_steps ,
130
+ ) = adjusted_mclmc_make_L_step_size_adaptation (
109
131
kernel = mclmc_kernel ,
110
132
dim = dim ,
111
133
frac_tune1 = frac_tune1 ,
@@ -115,9 +137,13 @@ def adjusted_mclmc_find_L_and_step_size(
115
137
diagonal_preconditioning = diagonal_preconditioning ,
116
138
max = max ,
117
139
tuning_factor = tuning_factor ,
118
- )(state , params , num_steps , part2_key2 )
140
+ )(
141
+ state , params , num_steps , part2_key2
142
+ )
119
143
120
- return state , params
144
+ total_num_tuning_integrator_steps += num_tuning_integrator_steps
145
+
146
+ return state , params , total_num_tuning_integrator_steps
121
147
122
148
123
149
def adjusted_mclmc_make_L_step_size_adaptation (
@@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
256
282
update_da = update_da ,
257
283
)
258
284
285
+ num_tuning_integrator_steps = info .num_integration_steps .sum ()
259
286
final_stepsize = final_da (dual_avg_state )
260
287
params = params ._replace (step_size = final_stepsize )
261
288
@@ -299,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
299
326
initial_da = initial_da ,
300
327
)
301
328
329
+ num_tuning_integrator_steps += info .num_integration_steps .sum ()
330
+
302
331
params = params ._replace (step_size = final_da (dual_avg_state ))
303
332
304
- return state , params , eigenvector
333
+ return state , params , eigenvector , num_tuning_integrator_steps
305
334
306
335
return L_step_size_adaptation
307
336
@@ -316,16 +345,16 @@ def adaptation_L(state, params, num_steps, key):
316
345
adaptation_L_keys = jax .random .split (key , num_steps )
317
346
318
347
def step (state , key ):
319
- next_state , _ = kernel (
348
+ next_state , info = kernel (
320
349
rng_key = key ,
321
350
state = state ,
322
351
step_size = params .step_size ,
323
352
avg_num_integration_steps = params .L / params .step_size ,
324
353
inverse_mass_matrix = params .inverse_mass_matrix ,
325
354
)
326
- return next_state , next_state .position
355
+ return next_state , ( next_state .position , info )
327
356
328
- state , samples = jax .lax .scan (
357
+ state , ( samples , info ) = jax .lax .scan (
329
358
f = step ,
330
359
init = state ,
331
360
xs = adaptation_L_keys ,
@@ -346,10 +375,14 @@ def step(state, key):
346
375
# number of effective samples per 1 actual sample
347
376
ess = contract (effective_sample_size (flat_samples [None , ...])) / num_steps
348
377
349
- return state , params ._replace (
350
- L = jnp .clip (
351
- Lfactor * params .L / jnp .mean (ess ), max = params .L * Lratio_upperbound
352
- )
378
+ return (
379
+ state ,
380
+ params ._replace (
381
+ L = jnp .clip (
382
+ Lfactor * params .L / jnp .mean (ess ), max = params .L * Lratio_upperbound
383
+ )
384
+ ),
385
+ info .num_integration_steps .sum (),
353
386
)
354
387
355
388
return adaptation_L
0 commit comments