Multithreading parallel executions #11565
PatrickHuembeli asked this question in General
Replies: 1 comment 8 replies
IIUC, JAX's CPU backend is multithreaded by default.
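As far as I understand, that threading happens inside individual XLA operations. By default JAX still exposes the whole host as a single CPU device, so `jax.pmap` has only one device to map over. A minimal sketch of a common workaround, assuming the XLA flag `--xla_force_host_platform_device_count` (which makes XLA expose several virtual CPU devices) is supported by your JAX version; the toy `loss`, the data, and `NUM_DEVICES = 4` are hypothetical. It also shows the corrected form of the pmapped gradient from the question: `jax.grad` takes the function `loss` itself, and `pmap` compiles on its own, so no outer `jax.jit` is needed.

```python
# Assumption: these environment variables must be set BEFORE jax is imported.
import os

NUM_DEVICES = 4  # hypothetical: number of virtual CPU devices to expose
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES}"

import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical toy model: linear regression with squared error.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# pmap compiles the function itself, so wrapping it in jax.jit is unnecessary.
# in_axes=(0, None, None): weights are split across devices, data is broadcast.
per_device_grad = jax.pmap(jax.grad(loss), in_axes=(0, None, None))

x = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

# One independent weight vector per device, stacked along the leading axis.
weights = jax.random.normal(jax.random.PRNGKey(1), (NUM_DEVICES, 3))
grads = per_device_grad(weights, x, y)
print(len(jax.devices()), grads.shape)
```

The leading axis of `weights` cannot exceed the number of devices; with more runs than devices, batching the remainder with `jax.vmap` is the usual approach.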
Hi everyone,
We are running a fairly small neural network (NN) with JAX and jit, and we would like to repeat the training many times to evaluate its average performance. Jitting the NN and its gradients takes almost as long as the training itself. We would therefore like to jit the functions and gradients only once and then run the optimization on several CPUs in parallel. What would be the best way to do this?
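One way to pay the compilation cost only once, sketched under assumptions (the tiny MLP, the toy data, plain gradient descent, and the names `train_one_run` and `train_many` are hypothetical, not the poster's model): write a whole training run as a function of its random key, `jax.vmap` it over many keys, and `jax.jit` the result. XLA then compiles a single batched program that trains every run at once; whether that batch is spread over several cores is left to XLA's CPU backend.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    # Hypothetical tiny MLP: 2 inputs -> 8 hidden units -> 1 output.
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (2, 8)) * 0.5,
        "b1": jnp.zeros(8),
        "w2": jax.random.normal(k2, (8, 1)) * 0.5,
        "b2": jnp.zeros(1),
    }


def loss(params, x, y):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    pred = h @ params["w2"] + params["b2"]
    return jnp.mean((pred - y) ** 2)


def train_one_run(key, x, y, lr=0.1, num_steps=200):
    # One complete training run starting from a single random initialization.
    params = init_params(key)

    def step(params, _):
        grads = jax.grad(loss)(params, x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        return params, None

    params, _ = jax.lax.scan(step, params, None, length=num_steps)
    return loss(params, x, y)  # final training loss of this run


# Batched over keys, data broadcast; compiled once for all runs.
train_many = jax.jit(jax.vmap(train_one_run, in_axes=(0, None, None)))

x = jax.random.normal(jax.random.PRNGKey(42), (64, 2))
y = jnp.sin(x[:, :1]) + x[:, 1:] ** 2  # toy target, shape (64, 1)

keys = jax.random.split(jax.random.PRNGKey(0), 16)  # 16 independent runs
final_losses = train_many(keys, x, y)
print(final_losses.shape, final_losses.mean())
```

Calling `train_many` again with a fresh set of 16 keys reuses the compiled executable, because the input shapes and dtypes are unchanged.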
I have seen that there is a pmap function, so I was wondering: is it possible to pmap the gradient function,
jax.jit(jax.pmap(jax.grad(f)))
and then pass a batch of weights instead of a single set of weights? Does this still work with the Optax optimizers? Is there anything we need to be aware of, or is there a better solution to this problem?