
Accurate way to time JIT-compiled functions #4432

Answered by jakevdp
leakec asked this question in Q&A

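JAX uses asynchronous dispatch: a call to a JAX function returns as soon as the computation has been queued on the backend, so the Python program keeps executing while the computation is still running. A timer wrapped around the bare call therefore measures only the dispatch, not the computation itself.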
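A minimal sketch of the effect, assuming a recent JAX install. The function `f`, the 1000×1000 matrix size and the use of `time.perf_counter` are illustrative choices, not part of the original answer. The first timestamp is taken right after `f(A)` returns, which typically happens before the result exists. The second is taken after explicitly waiting for the result with `block_until_ready()`.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(a):
    # Illustrative workload: a matrix product large enough to take measurable time
    return a @ a


A = jnp.ones((1000, 1000), dtype=jnp.float32)
f(A).block_until_ready()  # warm-up call so compilation is not part of the measurement

start = time.perf_counter()
y = f(A)  # typically returns once the work is queued, before it has finished
dispatch_time = time.perf_counter() - start

y.block_until_ready()  # wait for the backend to finish computing y
total_time = time.perf_counter() - start

print(f"after dispatch: {dispatch_time:.6f} s, after completion: {total_time:.6f} s")
```

How large the gap is depends on the backend and the workload. On an accelerator it is usually pronounced, while on CPU it can be small.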

For this sort of benchmark, the best approach is to call the `block_until_ready()` method on the result, which blocks until the computation has finished and the result is available:

```python
y = f(A).block_until_ready()
```
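A fuller timing harness built around `block_until_ready()` might look like the sketch below. It assumes the jitted function returns a single array, and both the helper `time_jitted` and the example function `f` are hypothetical names used only for illustration. The first call is kept outside the timed loop because it includes tracing and XLA compilation, which would otherwise inflate the measurement.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(a):
    # Hypothetical example workload; substitute your own jitted function
    return jnp.sin(a) @ a.T


def time_jitted(fn, *args, n_runs=10):
    """Return the mean wall-clock time in seconds per call of fn(*args).

    Assumes fn returns a single JAX array. The first call is made before
    timing starts so that compilation is excluded from the result.
    """
    fn(*args).block_until_ready()  # warm-up: trigger tracing and compilation
    start = time.perf_counter()
    for _ in range(n_runs):
        # Block on every result so the loop measures finished computations,
        # not just how quickly work can be queued.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_runs


A = jnp.ones((1000, 1000))
mean_s = time_jitted(f, A)
print(f"mean time per call: {mean_s * 1e3:.3f} ms")
```

Averaging over several runs smooths out timer noise. For functions that return a tuple or other pytree, each output array needs to be blocked on, not only the first one.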

Calling `print` on the result also blocks implicitly, because the value has to be computed before it can be displayed.

Answer selected by leakec