Hi, I'm new to JAX, so I'm not sure whether what I'm trying to do is possible at all. Essentially, I'm trying to repeatedly update a variable until it becomes one of the values in a supplied array (in the code below, the supplied array is `jnp.array([1,2])`). The error I'm getting is `Exception has occurred: TypeError: JAX DeviceArray, like numpy.ndarray, is not hashable.` I think the issue comes from using a DeviceArray as the condition in a `while_loop`, but I've tried walking through the call stack and can't pinpoint where exactly it fails. A minimal example is below:

```python
from jax import random, lax
import jax.numpy as jnp

def body_fun(val, x, s):
    tmp = 2 * (val + 1)
    return jnp.where(x <= s, tmp - 1, tmp)

def fun(x, s):
    val = 0
    val = lax.while_loop(jnp.isin(val, jnp.array([1, 2])), body_fun(val, x, s), val)
    return val

fun(1, 4)
```

Is there any way to achieve this? Thanks in advance!
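For context, the JAX documentation describes `lax.while_loop(cond_fun, body_fun, init_val)` as having the same semantics as the pure-Python loop below: both `cond_fun` and `body_fun` are *called* on the loop carry at every iteration, so they must be functions rather than precomputed values. This sketch is plain Python (no tracing or compilation), and it assumes that "until it becomes one of the values" means the loop should stop once `val` is in the array. The names `while_loop_reference`, `fun_reference` and `TARGETS` are hypothetical, chosen for illustration only.

```python
# Pure-Python reference semantics of lax.while_loop(cond_fun, body_fun, init_val).
def while_loop_reference(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):      # cond_fun is called on the carry each iteration
        val = body_fun(val)   # body_fun is called on the carry each iteration
    return val


TARGETS = (1, 2)  # hypothetical stand-in for jnp.array([1, 2])


def body_fun(val, x, s):
    tmp = 2 * (val + 1)
    return tmp - 1 if x <= s else tmp


def fun_reference(x, s):
    # Assumed intent: keep looping while val is NOT yet one of the targets.
    return while_loop_reference(
        lambda val: val not in TARGETS,
        lambda val: body_fun(val, x, s),
        0,
    )


print(fun_reference(1, 4))  # 1
print(fun_reference(5, 4))  # 2
```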
The first and second arguments to `lax.while_loop` should be callables (functions), not the results of calling those functions. In your code, `jnp.isin(val, jnp.array([1,2]))` and `body_fun(val, x, s)` are evaluated once, before `while_loop` is even entered, so `while_loop` receives arrays where it expects functions. If you write it like this, your function runs without the error:

```python
def fun(x, s):
    val = 0
    val = lax.while_loop(lambda val: jnp.isin(val, jnp.array([1, 2])),
                         lambda val: body_fun(val, x, s),
                         val)
    return val
```
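One thing to watch: `lax.while_loop` keeps iterating *while* `cond_fun` returns `True`. With `val = 0`, the condition `jnp.isin(0, jnp.array([1, 2]))` is already `False`, so the loop above returns `0` without running the body. If the goal is to loop *until* `val` becomes one of the values, the condition has to be negated. The following is a minimal sketch under that assumption; it also wraps the function in `jax.jit` to show that the loop still works when `x` and `s` are traced. `TARGETS` and `fun_until` are hypothetical names, and the carry is started as `jnp.int32(0)` so its dtype matches what the body returns.

```python
import jax
import jax.numpy as jnp
from jax import lax

TARGETS = jnp.array([1, 2])  # hypothetical name for the supplied array


def body_fun(val, x, s):
    tmp = 2 * (val + 1)
    return jnp.where(x <= s, tmp - 1, tmp)


@jax.jit
def fun_until(x, s):
    def cond_fun(val):
        # Continue while val is NOT yet one of the target values.
        return ~jnp.isin(val, TARGETS)

    def body(val):
        # Close over x and s; while_loop only passes the carry.
        return body_fun(val, x, s)

    return lax.while_loop(cond_fun, body, jnp.int32(0))


print(fun_until(1, 4))  # 1
print(fun_until(5, 4))  # 2
```

Defining `cond_fun` and `body` as named inner functions is equivalent to the lambdas in the answer; either way, `while_loop` gets functions of the carry, with `x` and `s` captured by closure.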