Skip to content

Commit 61d1c91

Browse files
Randlpatrick-kidger
authored andcommitted
Complex fixes in SDEs
1 parent d6d09dc commit 61d1c91

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

diffrax/_solver/milstein.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def step(
211211
leaves_ΔwΔw = []
212212
for i1, l1 in enumerate(leaves_Δw):
213213
for i2, l2 in enumerate(leaves_Δw):
214-
leaf = jnp.tensordot(l1[..., None], l2[None, ...], axes=1)
214+
leaf = jnp.tensordot(jnp.conj(l1[..., None]), l2[None, ...], axes=1)
215215
if i1 == i2:
216216
eye = jnp.eye(l1.size).reshape(l1.shape + l1.shape)
217217
with jax.numpy_dtype_promotion("standard"):
@@ -305,7 +305,7 @@ def _to_treemap(_Δw, _g0):
305305
def __dot(_v0, _ΔwΔw):
306306
# _v0 has structure (leaf(y0), leaf(Δw), leaf(Δw))
307307
# _ΔwΔw has structure (leaf(Δw), leaf(Δw))
308-
_out = jnp.tensordot(_v0, _ΔwΔw, axes=jnp.ndim(_ΔwΔw))
308+
_out = jnp.tensordot(jnp.conj(_v0), _ΔwΔw, axes=jnp.ndim(_ΔwΔw))
309309
# _out has structure (leaf(y0),)
310310
return _out
311311

diffrax/_solver/srk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def _comp_g(_t):
407407

408408
g0_g1 = _comp_g(jnp.array([t0, t1], dtype=complex_to_real_dtype(dtype)))
409409
g0 = jtu.tree_map(lambda g_leaf: g_leaf[0], g0_g1)
410-
# g_delta = 0.5 * g1 - g0
410+
# g_delta = 0.5 * (g1 - g0)
411411
g_delta = jtu.tree_map(lambda g_leaf: 0.5 * (g_leaf[1] - g_leaf[0]), g0_g1)
412412
w_kgs = diffusion.prod(g0, w)
413413
a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype)
@@ -456,7 +456,7 @@ def sum_prev_stages(_stage_out_buff, _a_j):
456456
)
457457
# Sum up the previous stages weighted by the coefficients in the tableau
458458
return jtu.tree_map(
459-
lambda lf: jnp.tensordot(_a_j, lf, axes=1), _stage_out_view
459+
lambda lf: jnp.tensordot(jnp.conj(_a_j), lf, axes=1), _stage_out_view
460460
)
461461

462462
def insert_jth_stage(results, k_j, j):

diffrax/_term.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def _callable_to_path(
257257
# control: Shaped[Array, "*control"]
258258
# return: Shaped[Array, "*state"]
259259
def _prod(vf, control):
260-
return jnp.tensordot(vf, control, axes=jnp.ndim(control))
260+
return jnp.tensordot(jnp.conj(vf), control, axes=jnp.ndim(control))
261261

262262

263263
# This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we

test/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def path_l2_dist(
100100
# and the length of saveat). Also sum all the PyTree leaves.
101101
def sum_square_diff(y1, y2):
102102
with jax.numpy_dtype_promotion("standard"):
103-
square_diff = jnp.square(y1 - y2)
103+
square_diff = jnp.square(jnp.abs(y1 - y2))
104104
# sum all but the first two axes
105105
axes = range(2, square_diff.ndim)
106106
out = jnp.sum(square_diff, axis=axes)

0 commit comments

Comments
 (0)