Is it possible to accelerate autograd by formula simplification? #11915
-
Hi everyone, I really enjoy the convenience brought by JAX's autograd mechanism. However, I've run into a performance problem. Suppose we want to take the Laplacian of a simple 1D Gaussian:

```python
import jax
import jax.numpy as jnp

x = jnp.array(1.)

def f(x):
    return jnp.exp(-x**2)
```

Using autograd, it's simple to accomplish the task:

```python
jax.jit(jax.grad(jax.grad(f)))(x)
```

However, this takes quite a long time on my local laptop (CPU only): nearly 20 ms. A hand-simplified version runs far faster, taking only about 20 µs:

```python
def laps_manual(x):
    y = x**2
    return (-2 + 4*y) * jnp.exp(-y)
```

I think the main reason for the performance difference is that autograd doesn't do enough formula simplification. Is there some way to simplify the formula produced by autograd so it runs much faster? Maybe I'm asking for too much, but I think this is a common problem in scientific computing. It would be great if we could have both the convenience of autograd and the efficiency of hand-derived formulas.
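For reference, one way to see the chain of primitives that nested `jax.grad` actually produces (before XLA's compiler optimizes it) is `jax.make_jaxpr`. This is a sketch for inspection only; the exact jaxpr printed may vary across JAX versions:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x**2)

# Print the unoptimized trace of the second derivative. The jaxpr
# contains the raw autodiff primitives (exp, mul, neg, ...); it is
# XLA, invoked by jax.jit, that later fuses and simplifies them.
print(jax.make_jaxpr(jax.grad(jax.grad(f)))(1.0))
```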
Replies: 1 comment 1 reply
-
`jax.jit` already does this kind of acceleration via formula simplification automatically. I think your timings are probably being thrown off by JAX's asynchronous dispatch; for information about how to accurately assess microbenchmarks of JAX code, see https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code

When I benchmark these functions using the recommended approaches there, I find that the automatic and manual versions of the function both have comparable runtimes of ~4 µs on a Colab CPU:

```python
import jax
import jax.numpy as jnp

x = jnp.array(1.)

def f(x):
    return jnp.exp(-x**2)

fprime = jax.jit(jax.grad(jax.grad(f)))

%timeit fprime.lower(x).compile()
# 38.5 ms ± 629 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

_ = fprime(x).block_until_ready()
%timeit fprime(x).block_until_ready()
# 3.99 µs ± 104 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@jax.jit
def laps_manual(x):
    y = x**2
    return (-2 + 4*y) * jnp.exp(-y)

%timeit laps_manual.lower(x).compile()
# 24.5 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

_ = laps_manual(x).block_until_ready()
%timeit laps_manual(x).block_until_ready()
# 4.11 µs ± 711 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
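As a quick sanity check (a minimal sketch, not part of the original benchmark), the jitted double-grad and the hand-simplified formula compute the same second derivative, d²/dx² exp(−x²) = (4x² − 2)·exp(−x²):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(-x**2)

# Automatic second derivative via nested grad, compiled with jit.
fprime = jax.jit(jax.grad(jax.grad(f)))

# Hand-simplified closed form of the same second derivative.
@jax.jit
def laps_manual(x):
    y = x**2
    return (-2 + 4*y) * jnp.exp(-y)

x = jnp.array(1.)
# At x = 1 both evaluate to 2/e.
assert jnp.allclose(fprime(x), laps_manual(x))
```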