Passing Lambda functions into jit very slow #10868

Answered by jakevdp
Justin-Tan asked this question in General

Yes, this is known. It's not an issue with anonymous functions per se; it's an issue with the fact that you're re-defining the function on every iteration of the %timeit. The JIT cache is keyed on the function's object identity, so each time you create a lambda this way you get a brand-new function object, and you pay the JIT compilation cost every time. You'll get better results if you define your lambda function once and then reuse it in the %timeit expression:

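from jax import random
import jax.numpy as jnp

key, rng = random.split(key, 2)
lambda_x = lambda x: jnp.exp(-x.sum())  # defined once, outside %timeit, so jit can reuse its cache
%timeit -o jit_integrate_toy(key, lambda_x, 2*dim, n_samples)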
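The cache-miss behavior can be observed directly with a small sketch. Since the definition of `jit_integrate_toy` isn't shown here, `apply_fun` below is a hypothetical stand-in that assumes the function is passed as a static argument (`static_argnums=1`). Compilations are counted via a Python side effect in the function body, which only runs while JAX traces the function, i.e. on a cache miss.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Counts how often Python traces the function body. Tracing only happens
# on a jit cache miss, so this doubles as a compilation counter.
trace_counter = {"n": 0}


# Hypothetical stand-in for jit_integrate_toy: `fun` is static, so jit keys
# its cache on the function object (hash and equality, i.e. identity).
@partial(jax.jit, static_argnums=1)
def apply_fun(xs, fun):
    trace_counter["n"] += 1  # Python side effect: runs at trace time only
    return fun(xs)


xs = jnp.ones(3)

# A new lambda on every call: each one is a distinct object, so every call
# misses the cache and is retraced and recompiled.
for _ in range(3):
    apply_fun(xs, lambda x: jnp.exp(-x.sum()))
fresh_traces = trace_counter["n"]

# One lambda defined up front and reused: only the first call compiles.
trace_counter["n"] = 0
lambda_x = lambda x: jnp.exp(-x.sum())
for _ in range(3):
    apply_fun(xs, lambda_x)
reused_traces = trace_counter["n"]

print(fresh_traces, reused_traces)
```

The same reasoning applies to `jax.jit(lambda ...)` created inside a loop: any fresh function object means a fresh cache entry.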

We've discussed in the past the possibility of somehow caching based on the fun…

Answer selected by Justin-Tan