Skip to content

Commit

Permalink
update jit results
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo committed Feb 9, 2025
1 parent 7865a16 commit 20e700d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
7 changes: 3 additions & 4 deletions benchmarks/stateful_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __call__(
):
return self.evaluate(t0, t1, left, use_levy), brownian_state

@eqx.filter_jit
def evaluate(
self,
t0,
Expand Down Expand Up @@ -270,9 +269,9 @@ def step(y, dW):
Results on A100 GPU:
VBT: 2.275057
Old UBP: 0.092015
New UBP: 0.125904
New UBP + Precompute: 0.108587
Old UBP: 0.112461
New UBP: 0.126370
New UBP + Precompute: 0.111837
Pure Jax: 0.261937
For small ndt (e.g. 100) the pure jax is faster, but the diffrax overhead
Expand Down
1 change: 1 addition & 0 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def init(
key = self.key
return key, noise, counter

@eqx.filter_jit
def __call__(
self,
t0: RealScalarLike,
Expand Down

0 comments on commit 20e700d

Please sign in to comment.