Hi all, I'm wondering if …

```python
import jax

x = jax.device_put(1, jax.devices()[0])                     # first default device (GPU 0)
y = jax.device_put(1, jax.local_devices(backend='cpu')[0])  # CPU 0
x + y  # ValueError: primitive arguments must be colocated on the same device, got gpu:0, cpu:0
```

However, a similar case with `pmap` confused me (there are 4 GPU devices):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import pmap

cpu_device = jax.local_devices(backend='cpu')[0]

# v is mapped over its leading axis (one row per GPU); W is broadcast to every device
@partial(pmap, axis_name="i", in_axes=(0, None))
def mymul(v, W):
    v_cpu = jax.device_put(v, cpu_device)  # intended to move v to the CPU
    return W.dot(v_cpu)

v = jnp.ones((4, 3))
W = jnp.ones((4, 3))
mymul(v, W)
```

No error is reported! So I doubt if …
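For the first snippet, the colocation error goes away once both operands live on the same device. A minimal sketch, assuming the fix is simply to transfer `y` to `x`'s device with another `device_put` before adding (on a CPU-only machine both devices are the CPU, so it runs there too):

```python
import jax

target = jax.devices()[0]                       # first default device (GPU 0 if present)
cpu = jax.local_devices(backend="cpu")[0]

x = jax.device_put(1, target)
y = jax.device_put(1, cpu)

# Move y onto x's device so both operands of `+` are colocated
y_moved = jax.device_put(y, target)
z = x + y_moved
```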
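Well, I'll explain it myself. The translation rule of `device_put` is actually `lambda c, x, device=None: x` in `jax/interpreters/xla.py#L1342`, i.e. it just returns its input unchanged. So once `jit` or `pmap` (which also applies `jit`) is applied to a function, a `device_put` inside it has no effect: `v_cpu` stays on the GPU that runs the mapped computation, both operands of `W.dot` are colocated, and no error is raised.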
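Since a `device_put` inside the compiled function is dropped, one way to control placement is to call `device_put` on the inputs *before* calling the `jit`-compiled function; the computation then runs on the device where its committed inputs live. A minimal single-device sketch, assuming a JAX version whose arrays expose `.devices()`; the function `f` is hypothetical, and newer JAX releases may treat `device_put` inside `jit` differently from the translation rule quoted above:

```python
import jax
import jax.numpy as jnp

cpu = jax.local_devices(backend="cpu")[0]

@jax.jit
def f(v, W):
    # Hypothetical example function: no device_put inside the compiled body
    return W.dot(v)

# Place the operands on the CPU outside jit, where device_put does take effect
v_cpu = jax.device_put(jnp.ones(3), cpu)
W_cpu = jax.device_put(jnp.ones((4, 3)), cpu)

out = f(v_cpu, W_cpu)  # runs on the device holding the committed inputs
```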