Passing lambda functions into jit very slow #10868
-
Passing a lambda function into a `jit`-compiled function seems to be very slow:

```python
import jax.numpy as jnp
from jax import random, jit

def integrate_toy(rng, func, dim, num_pts):
    # Monte Carlo estimate of the integral of func over [a, b]^dim
    a, b = -1., 1.
    x = random.uniform(rng, shape=(num_pts, dim), minval=a, maxval=b)
    vol = jnp.power(b - a, dim)
    return jnp.mean(func(x) * vol)

jit_integrate_toy = jit(integrate_toy, static_argnums=(1, 2, 3))

n_samples, dim, key = 10, 1, random.PRNGKey(42)

%timeit -o integrate_toy(key, lambda x: jnp.exp(-x.sum()), 2*dim, n_samples)
# <TimeitResult : 123 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)>

key, rng = random.split(key, 2)
%timeit -o jit_integrate_toy(key, lambda x: jnp.exp(-x.sum()), 2*dim, n_samples)
# <TimeitResult : 158 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)>

def myfunc(x):
    return jnp.exp(-x.sum())

key, rng = random.split(key, 2)
%timeit -o integrate_toy(key, myfunc, 2*dim, n_samples)
# <TimeitResult : 124 µs ± 1.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)>

key, rng = random.split(key, 2)
%timeit -o jit_integrate_toy(key, myfunc, 2*dim, n_samples)
# <TimeitResult : 3.38 µs ± 15.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)>
```

It looks like using a lambda function as an argument instead of a regular function makes a major negative difference for the jitted version.
Replies: 2 comments
-
Yes, this is known. It's not an issue with anonymous functions per se; it's an issue with the fact that you're newly re-defining the function in each iteration of the `%timeit`. The JIT cache is keyed on the function's object identity; when you create a lambda in this manner, each call sees a brand-new function object, and so the JIT compilation cost is paid every time. You'll get better results if you define your lambda function once and then re-use it in the `%timeit` expression:

```python
key, rng = random.split(key, 2)
lambda_x = lambda x: jnp.exp(-x.sum())
%timeit -o jit_integrate_toy(key, lambda_x, 2*dim, n_samples)
```

We've discussed in the past the possibility of somehow caching based on the function's bytecode rather than its object ID to make this sort of usage faster, but we haven't found a robust and performant solution. So for now the recommendation is: if you're calling a function multiple times, be sure to define it only once.
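The cache-by-object-identity behaviour is easy to observe directly. Here is a small illustrative sketch (the `apply_fn` wrapper and the trace counter are made up for this demo, not part of the original snippet) that counts how often JAX re-traces:

```python
import jax
import jax.numpy as jnp
from functools import partial

trace_count = 0

@partial(jax.jit, static_argnums=0)
def apply_fn(f, x):
    global trace_count
    trace_count += 1  # this Python side effect runs only while JAX traces, not on cached calls
    return f(x)

x = jnp.arange(3.0)

# A fresh lambda per call: new object identity, new cache entry, full retrace each time.
for _ in range(3):
    apply_fn(lambda v: v * 2, x)
count_fresh = trace_count  # 3: retraced on every call

# The same function object reused: traced once, then served from the cache.
g = lambda v: v * 2
trace_count = 0
for _ in range(3):
    apply_fn(g, x)
count_reused = trace_count  # 1: traced once, then cached
```

Since lambdas hash by object identity, each freshly created lambda misses the cache even though it is lexically identical to the previous one.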
-
I've been using this custom `Partial` implementation (which also allows storing hashable arguments inside the lambda). I've been using it in production code for over a year now and have seen zero issues, but I wouldn't be surprised if some edge cases break it.

```python
import sys
from functools import partial

class HashablePartial(partial):
    """
    A class behaving like functools.partial, but that retains its hash
    if it's created with a lexically equivalent (the same) function and
    with the same partially applied arguments and keywords.
    It also stores the computed hash for faster hashing.
    """

    # TODO: remove when dropping support for Python < 3.10
    def __new__(cls, func, *args, **keywords):
        # In Python 3.10+, if func is itself a functools.partial instance,
        # functools.partial.__new__ merges the arguments of this HashablePartial
        # instance with the arguments of func.
        # Pre-3.10 this does not happen, so here we emulate that behaviour recursively.
        # This is necessary since functools.partial objects do not have a __code__
        # attribute, which we use for the hash.
        # For Python 3.10+ we still need to take care of merging with another HashablePartial.
        while isinstance(
            func, partial if sys.version_info < (3, 10) else HashablePartial
        ):
            original_func = func
            func = original_func.func
            args = original_func.args + args
            keywords = {**original_func.keywords, **keywords}
        return super(HashablePartial, cls).__new__(cls, func, *args, **keywords)

    def __init__(self, *args, **kwargs):
        self._hash = None

    def __eq__(self, other):
        return (
            type(other) is HashablePartial
            and self.func.__code__ == other.func.__code__
            and self.args == other.args
            and self.keywords == other.keywords
        )

    def __hash__(self):
        if self._hash is None:
            self._hash = hash(
                (self.func.__code__, self.args, frozenset(self.keywords.items()))
            )
        return self._hash

    def __repr__(self):
        return (
            f"<hashable partial {self.func.__name__} with "
            f"args={self.args} and kwargs={self.keywords}, hash={hash(self)}>"
        )
```