@@ -103,8 +103,7 @@ def _split_interval(
103
103
104
104
105
105
class VirtualBrownianTree (AbstractBrownianPath ):
106
- """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`,
107
- and is piecewise quadratic at that discretisation.
106
+ """Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
108
107
109
108
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
110
109
If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
@@ -267,16 +266,12 @@ def _evaluate_leaf(
267
266
)
268
267
269
268
def _cond_fun (_state ):
270
- # Slight adaptation on the version of the algorithm given in the
271
- # above-referenced thesis. There the returned value is snapped to one of
272
- # the dyadic grid points, so they just stop once
273
- # jnp.abs(τ - state.s) > self.tol
274
- # Here, because we use quadratic splines to get better samples, we always
275
- # iterate down to the level of the spline.
269
+ """Condition for the binary search for r."""
270
+ # If true, continue splitting the interval and descending the tree.
276
271
return 2.0 ** (- _state .level ) > self .tol
277
272
278
273
def _body_fun (_state : _State ):
279
- """Single-step of binary search for r."""
274
+ """Single-step of the binary search for r."""
280
275
281
276
(
282
277
_t ,
@@ -318,15 +313,16 @@ def _body_fun(_state: _State):
318
313
s = final_state .s
319
314
su = 2.0 ** - final_state .level
320
315
321
- sr = r - s
322
- ru = su - sr # make sure su = sr + ru regardless of cancellation error
316
+ sr = jax .nn .relu (r - s )
317
+ # make sure su = sr + ru regardless of cancellation error
318
+ ru = jax .nn .relu (su - sr )
323
319
324
320
w_s , w_u , w_su = final_state .w_s_u_su
325
321
326
322
# BM only case
327
323
if self .levy_area == "" :
328
324
z = jr .normal (final_state .key , shape , dtype )
329
- w_sr = sr / su * w_su + jnp .sqrt (jnp . abs ( sr * ru / su ) ) * z
325
+ w_sr = sr / su * w_su + jnp .sqrt (sr * ru / su ) * z
330
326
w_r = w_s + w_sr
331
327
return LevyVal (dt = r , W = w_r , H = None , bar_H = None , K = None , bar_K = None )
332
328
@@ -340,15 +336,15 @@ def _body_fun(_state: _State):
340
336
x1 = jr .normal (key1 , shape , dtype )
341
337
x2 = jr .normal (key2 , shape , dtype )
342
338
343
- sr_ru_half = jnp .sqrt (jnp . abs ( sr * ru ) )
344
- d = jnp .sqrt (jnp . abs ( sr3 + ru3 ) )
339
+ sr_ru_half = jnp .sqrt (sr * ru )
340
+ d = jnp .sqrt (sr3 + ru3 )
345
341
d_prime = 1 / (2 * su * d )
346
342
a = d_prime * sr3 * sr_ru_half
347
343
b = d_prime * ru3 * sr_ru_half
348
344
349
345
w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b ) / su * x1
350
346
w_r = w_s + w_sr
351
- c = jnp .sqrt (jnp . abs ( 3 * sr3 * ru3 ) ) / (6 * d )
347
+ c = jnp .sqrt (3 * sr3 * ru3 ) / (6 * d )
352
348
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
353
349
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r )
354
350
0 commit comments