-
This is the same question as #11240, but I am asking it again because the accepted answer to that question is now deprecated. Is there a non-experimental way to run one portion of a jitted function on the CPU and the rest on the GPU? More specifically, I have a decorator that runs a function on the CPU while the global device is the GPU, like this:

import functools
import jax

def execute_on_cpu(func):
    """Decorator to set default device to CPU for a function.

    Parameters
    ----------
    func : callable
        Function to decorate

    Returns
    -------
    wrapper : callable
        Decorated function that will always run on CPU even if
        there are available GPUs.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with jax.default_device(jax.devices("cpu")[0]):
            return func(*args, **kwargs)
    return wrapper

However, if I call the decorated function inside a jitted function, it either throws an error or doesn't run. Basically, I would like to do this:

@execute_on_cpu
def fun_cpu(*args):
    # ... some stuff that is fast on CPU ...
    return something

@jax.jit
def main(*args):
    # ... do something
    x = fun_cpu(*args)
    # ... do more ....
    return something_else
Replies: 2 comments
-
Have you tried using pure_callback? https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html
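For anyone skimming, here is a minimal sketch of what that could look like. It assumes the CPU-side work can be written as plain NumPy; `host_cumsum` and `main` are hypothetical names made up for illustration, not something from the question.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_cumsum(x):
    # Hypothetical host-side function: ordinary NumPy code that JAX does not trace.
    x = np.asarray(x)
    return np.cumsum(x).astype(x.dtype)


@jax.jit
def main(x):
    y = x * 2.0  # traced and compiled for the default device
    # Tell JAX the shape and dtype the callback will return.
    out_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    z = jax.pure_callback(host_cumsum, out_spec, y)
    return z + 1.0  # traced and compiled for the default device again


result = main(jnp.arange(4, dtype=jnp.float32))
```

Note that `pure_callback` expects the callback to be pure (no side effects), since the compiler may skip it or call it more than once.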
-
I guess it is exactly what I want! Thank you! For future reference, an example usage could be:

import jax
import jax.numpy as jnp
import numpy as np
from jax.lax import fori_loop

# execute_on_cpu is the decorator defined in the question above.
@execute_on_cpu
@jax.jit
def loop_for_cpu(x):
    def add_for(i, x):
        return x.at[i].add(5)
    x = fori_loop(0, x.size, add_for, x)
    return x

@jax.jit
def loop_for_gpu(x):
    def add_for(i, x):
        return x.at[i].add(5)
    x = fori_loop(0, x.size, add_for, x)
    return x

@jax.jit
def loop_for_pure_callback(x):
    # x is passed twice: first as the result shape/dtype spec
    # (the output matches the input), then as the callback argument.
    return jax.pure_callback(loop_for_cpu, x, x)

a = jnp.ones(100000)

# compile the function once
x = loop_for_cpu(a)
y = loop_for_gpu(a)
z = loop_for_pure_callback(a)
assert np.allclose(x, y)
assert np.allclose(x, z)

# timing (IPython/Jupyter magics)
%timeit _ = loop_for_cpu(a).block_until_ready()
%timeit _ = loop_for_gpu(a).block_until_ready()
%timeit _ = loop_for_pure_callback(a).block_until_ready()