Replies: 6 comments 2 replies
-
Thanks for the question! The issue is your use of nested for-loops within JIT. These loops are unrolled by the JIT compiler, so your program ends up sending a sequence of over 16000 individual instructions to XLA, which accounts for the slow compile times. Rather than nested for-loops, you should express your program's logic in terms of vectorized array computations (similar to how NumPy achieves fast performance) or use JAX-specific tools like vmap. I'd show an example based on your function, but it seems to be a bit over-simplified. Note also that if you call the jitted function a second time on similar input (so that compilation time is not included), execution will be very fast because XLA can optimize away these 16000 no-ops.
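As a sketch of what this looks like in practice (using a hypothetical loop-based function, since the original one wasn't posted in full): the loop version traces into one operation per element, while the vectorized version traces into a handful of XLA ops regardless of array size.

```python
import jax
import jax.numpy as jnp

@jax.jit
def cube_sum_loop(x):
    # Python loop: unrolled at trace time into one op per element
    res = 0.0
    for e in x:
        res += (3.5 * e) ** 3
    return res

@jax.jit
def cube_sum_vectorized(x):
    # Single vectorized expression: a handful of XLA ops in total
    return jnp.sum((3.5 * x) ** 3)

x = jnp.arange(10.0)
print(jnp.allclose(cube_sum_loop(x), cube_sum_vectorized(x)))  # True
```

Both produce the same result, but the vectorized version compiles in roughly constant time while the loop version's compile time grows with the length of `x`.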
-
I just don't see how to use vmap in my function foo, because I don't know JAX very well. Could you please help me in that case? Thank you.
-
There's no way to use vmap in your function:

def foo(a):
    return 0

I assume that …
-
Is the benchmark example trying to measure dispatch time? It might be useful to figure out what you're trying to measure here.
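For reference, a common pattern for separating compile time from steady-state execution time in JAX (a sketch, not code from the original thread) is to time the first call and a subsequent call separately, using block_until_ready to account for asynchronous dispatch:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum((3.5 * x) ** 3)

x = jnp.arange(1000.0)

# First call: includes tracing and XLA compilation
t0 = time.perf_counter()
f(x).block_until_ready()
compile_and_run = time.perf_counter() - t0

# Later calls: dispatch + execution only (cached compilation)
t0 = time.perf_counter()
f(x).block_until_ready()
steady_state = time.perf_counter() - t0
# compile_and_run is typically much larger than steady_state
```

Without block_until_ready, the timer may stop before the computation finishes, because JAX dispatches work asynchronously.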
-
Actually, I'd like to improve performance for the following functions:

import jax.numpy as jnp

def f1(x):
    res = 0
    for e in x:
        res += (3.5 * e) ** 3
    return res

def f2(x):
    res = jnp.zeros(sorties)  # `sorties` must be defined elsewhere (the output length)
    for i in range(len(res)):
        res = res.at[i].set((x * i) ** 3)  # .at[].set() returns a new array; rebind it
    return res
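For reference, both loops above can be written as single vectorized expressions, avoiding the per-iteration unrolling under jit (a sketch; `sorties` is assumed to be the desired output length, passed here as an explicit argument):

```python
import jax.numpy as jnp

def f1_vectorized(x):
    # Sum of cubes in one expression instead of a Python accumulation loop
    return jnp.sum((3.5 * x) ** 3)

def f2_vectorized(x, sorties):
    # (x * i)**3 for every index i at once, instead of index-by-index .at[].set()
    return (x * jnp.arange(sorties)) ** 3
```

Each of these traces to a fixed, small number of XLA operations no matter how long the input is.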
-
Thank you for answering me. I have another question: inside a function, I have a JAX tracer object like this: Traced<ConcreteArray([1. 2. 3. 4. 5.])>with<JVPTrace(level=2/0)>
-
Hello,
For my internship, I need to compare performance between JAX and Autograd.
But when I try to run this code (for example), I get worse performance with JAX than with Autograd:
The outputs I got are:
I don't know where this issue can come from. I work in Visual Studio Code and run my program on a GPU server running Ubuntu. The CUDA version is the following:
I also tried the code in a Jupyter notebook and got the same results.
Thanks in advance for your help.