Issue using DeviceArray in while_loop #6199

Answered by jakevdp
jodie-c asked this question in Q&A

The first and second arguments to lax.while_loop (cond_fun and body_fun) must be callables, not the results of functions that you've already called. If you write it like this, your function executes correctly:

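import jax.numpy as jnp
from jax import lax

def fun(x, s):
    val = 0
    # Pass the condition and the body as callables; body_fun is the one defined in the question.
    val = lax.while_loop(lambda val: jnp.isin(val, jnp.array([1, 2])), lambda val: body_fun(val, x, s), val)
    return val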
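
For anyone who wants to try the pattern in isolation, here is a minimal self-contained sketch. The body_fun and count_up below are hypothetical stand-ins (the original body_fun from the question is not reproduced here), and the loop condition is simplified to a plain comparison. The point it illustrates is only that both the condition and the body are handed to lax.while_loop as callables, with the extra arguments x and s captured by the lambdas.

```python
import jax.numpy as jnp
from jax import lax


def body_fun(val, x, s):
    # Hypothetical loop body (an assumption, not the one from the question):
    # advance the carried value by the step s.
    return val + s


def count_up(x, s):
    # cond_fun and body_fun are passed as callables that take the carried value.
    # x and s are captured by the lambdas rather than passed as loop state.
    return lax.while_loop(
        lambda val: val < x,              # cond_fun: returns a scalar boolean
        lambda val: body_fun(val, x, s),  # body_fun: returns the next carried value
        0,                                # init_val
    )


result = count_up(10, 3)
print(result)
```

Writing lax.while_loop(val < x, body_fun(val, x, s), 0) instead would evaluate both expressions once, up front, and pass their results rather than functions, which is the mistake described above.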

Answer selected by jodie-c