
How to code an "if" statement when using "vmap" #4951

Answered by jeremiecoullon
shixinxing asked this question in Q&A

Check out the control flow section of "The Sharp Bits" in the documentation; it discusses this!
Issue #196 also explains this (in @mattjj's main reply).

A plain Python if statement cannot branch on a value that is being traced by vmap (or jit), so use lax.cond or lax.switch instead.
Hope that helps!
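Here is a minimal sketch of both suggestions under vmap, assuming a recent JAX release. The functions relu_like and piecewise are hypothetical examples, not taken from the linked docs. When the predicate or index is batched, vmap turns lax.cond and lax.switch into a select, so every branch runs for every element. Each branch must therefore be safe on all inputs and return the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical example: `if x > 0: return x else: return 0` written with lax.cond.
# A Python `if x > 0:` here would fail under vmap because x is an abstract tracer.
def relu_like(x):
    # lax.cond(pred, true_fun, false_fun, *operands)
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


# Hypothetical example: a multi-way branch chosen by an integer index.
def piecewise(i, x):
    # lax.switch(index, branches, *operands) applies branches[index] to operands;
    # an out-of-range index is clamped to the valid range.
    branches = [
        lambda v: v + 1.0,
        lambda v: v * 2.0,
        lambda v: -v,
    ]
    return lax.switch(i, branches, x)


xs = jnp.array([-2.0, 0.5, 3.0])
idx = jnp.array([0, 1, 2])

print(jax.vmap(relu_like)(xs))       # values: 0.0, 0.5, 3.0
print(jax.vmap(piecewise)(idx, xs))  # values: -1.0, 1.0, -3.0
```

If only elementwise selection is needed, the branches are cheap anyway, so the select behaviour under vmap costs little; for expensive branches, keep in mind that none of them is skipped.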

Answer selected by shixinxing