Skip to content

Commit

Permalink
Allowing args into grad_f for ULD
Browse files Browse the repository at this point in the history
  • Loading branch information
ricor07 authored and patrick-kidger committed Jan 13, 2025
1 parent 134a40a commit cc0d4bc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
8 changes: 4 additions & 4 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,13 +925,13 @@ class UnderdampedLangevinDriftTerm(AbstractTerm):

gamma: PyTree[ArrayLike]
u: PyTree[ArrayLike]
grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX]
grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX]

def __init__(
self,
gamma: PyTree[ArrayLike],
u: PyTree[ArrayLike],
grad_f: Callable[[UnderdampedLangevinX], UnderdampedLangevinX],
grad_f: Callable[[UnderdampedLangevinX, Args], UnderdampedLangevinX],
):
r"""
**Arguments:**
Expand All @@ -942,7 +942,7 @@ def __init__(
a scalar or a PyTree of the same shape as the position vector $x$.
- `grad_f`: A callable representing the gradient of the potential function $f$.
This callable should take a PyTree of the same shape as $x$ and
return a PyTree of the same shape.
an `args` argument (always passed, may be `None`), returning a PyTree of the same shape.
"""
self.gamma = gamma
self.u = u
Expand All @@ -963,7 +963,7 @@ def fun(_gamma, _u, _v, _f_x):

vf_x = v
try:
f_x = self.grad_f(x)
f_x = self.grad_f(x, args) # Pass args to grad_f
vf_v = jtu.tree_map(fun, gamma, u, v, f_x)
except ValueError:
raise RuntimeError(
Expand Down
36 changes: 36 additions & 0 deletions test/test_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,39 @@ def test_weaklydiagonal_deprecate():
_ = diffrax.WeaklyDiagonalControlTerm(
lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0)
)


def test_underdamped_langevin_drift_term_args():
    """
    Check that UnderdampedLangevinDriftTerm forwards `args` through to `grad_f`.
    """

    # Gradient of the potential that deliberately consumes `args`, so the
    # test fails if `args` is dropped or passed incorrectly.
    def grad_with_args(position, extra):
        return jtu.tree_map(lambda p, a: p + a, position, extra)

    # Fixed inputs for the vector field evaluation.
    friction = jnp.array([0.1, 0.2, 0.3])   # gamma
    mass_like = jnp.array([0.4, 0.5, 0.6])  # u
    position = jnp.array([1.0, 2.0, 3.0])   # x
    velocity = jnp.array([0.1, 0.2, 0.3])   # v
    extra = jnp.array([0.7, 0.8, 0.9])      # args

    drift = diffrax.UnderdampedLangevinDriftTerm(
        gamma=friction, u=mass_like, grad_f=grad_with_args
    )

    # Evaluate the drift at (x, v) with args and unpack the two components.
    out_x, out_v = drift.vf(0.0, (position, velocity), extra)

    # vf_x is the velocity by definition; vf_v follows the Langevin drift
    # -gamma*v - u*grad_f(x, args), with grad_f(x, args) = x + args here.
    grad_value = position + extra
    assert jnp.allclose(out_x, velocity), "vf_x does not match expected results"
    assert jnp.allclose(
        out_v, -friction * velocity - mass_like * grad_value
    ), "vf_v does not match expected results"

0 comments on commit cc0d4bc

Please sign in to comment.