Skip to content

Commit 40402de

Browse files
andyElkingpatrick-kidger
authored andcommitted
Replaced abs with relu, and updated some comments
1 parent 00d7af2 commit 40402de

File tree

1 file changed

+11
-15
lines changed

1 file changed

+11
-15
lines changed

diffrax/_brownian/tree.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def _split_interval(
103103

104104

105105
class VirtualBrownianTree(AbstractBrownianPath):
106-
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`,
107-
and is piecewise quadratic at that discretisation.
106+
"""Brownian simulation that discretises the interval `[t0, t1]` to tolerance `tol`.
108107
109108
Can be initialised with `levy_area` set to `""`, or `"space-time"`.
110109
If `levy_area="space_time"`, then it also computes space-time Lévy area `H`.
@@ -267,16 +266,12 @@ def _evaluate_leaf(
267266
)
268267

269268
def _cond_fun(_state):
270-
# Slight adaptation on the version of the algorithm given in the
271-
# above-referenced thesis. There the returned value is snapped to one of
272-
# the dyadic grid points, so they just stop once
273-
# jnp.abs(τ - state.s) > self.tol
274-
# Here, because we use quadratic splines to get better samples, we always
275-
# iterate down to the level of the spline.
269+
"""Condition for the binary search for r."""
270+
# If true, continue splitting the interval and descending the tree.
276271
return 2.0 ** (-_state.level) > self.tol
277272

278273
def _body_fun(_state: _State):
279-
"""Single-step of binary search for r."""
274+
"""Single-step of the binary search for r."""
280275

281276
(
282277
_t,
@@ -318,15 +313,16 @@ def _body_fun(_state: _State):
318313
s = final_state.s
319314
su = 2.0**-final_state.level
320315

321-
sr = r - s
322-
ru = su - sr # make sure su = sr + ru regardless of cancellation error
316+
sr = jax.nn.relu(r - s)
317+
# make sure su = sr + ru regardless of cancellation error
318+
ru = jax.nn.relu(su - sr)
323319

324320
w_s, w_u, w_su = final_state.w_s_u_su
325321

326322
# BM only case
327323
if self.levy_area == "":
328324
z = jr.normal(final_state.key, shape, dtype)
329-
w_sr = sr / su * w_su + jnp.sqrt(jnp.abs(sr * ru / su)) * z
325+
w_sr = sr / su * w_su + jnp.sqrt(sr * ru / su) * z
330326
w_r = w_s + w_sr
331327
return LevyVal(dt=r, W=w_r, H=None, bar_H=None, K=None, bar_K=None)
332328

@@ -340,15 +336,15 @@ def _body_fun(_state: _State):
340336
x1 = jr.normal(key1, shape, dtype)
341337
x2 = jr.normal(key2, shape, dtype)
342338

343-
sr_ru_half = jnp.sqrt(jnp.abs(sr * ru))
344-
d = jnp.sqrt(jnp.abs(sr3 + ru3))
339+
sr_ru_half = jnp.sqrt(sr * ru)
340+
d = jnp.sqrt(sr3 + ru3)
345341
d_prime = 1 / (2 * su * d)
346342
a = d_prime * sr3 * sr_ru_half
347343
b = d_prime * ru3 * sr_ru_half
348344

349345
w_sr = sr / su * w_su + 6 * sr * ru / su3 * bhh_su + 2 * (a + b) / su * x1
350346
w_r = w_s + w_sr
351-
c = jnp.sqrt(jnp.abs(3 * sr3 * ru3)) / (6 * d)
347+
c = jnp.sqrt(3 * sr3 * ru3) / (6 * d)
352348
bhh_sr = sr3 / su3 * bhh_su - a * x1 + c * x2
353349
bhh_r = bhh_s + bhh_sr + 0.5 * (r * w_s - s * w_r)
354350

0 commit comments

Comments
 (0)