linen.vmap gives error "TypeError: list indices must be integers or slices, not BatchTracer" #1148
marcvanzee asked this question in Q&A
The input is a tensor of indexes.
Answered by marcvanzee on Mar 17, 2021
Answer by @levskaya: You're trying to index into a Python list with a tracer. You should make sure your vmapped input is in a `jnp.array(...)`:

```python
import jax
import jax.numpy as jnp

def foo1(x):
  return 1*x

def foo2(x):
  return 2*x

@jax.vmap
def bar(x, idx):
  # Build a JAX array (not a Python list) so it can be indexed by the traced `idx`.
  y = jnp.array([foo1(x), foo2(x)])
  return y[idx]

x = jnp.array([0.,1.,2.])
idx = jnp.array([0,1,0])
bar(x, idx)
```
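To make the difference concrete, here is a minimal sketch, assuming a recent JAX release, that reproduces the failure and contrasts it with the fix from the answer above. The names `bar_list`, `bar_array`, and `bar_switch` are hypothetical, chosen for illustration. The exact exception class and message depend on the JAX version (older releases report the `list indices must be integers or slices, not BatchTracer` message from the title), but it should surface as a `TypeError` or a subclass of it. `bar_array` uses `jnp.stack`, which for a list of scalars builds the same array as the answer's `jnp.array([...])`. `bar_switch` is an alternative that is not part of the original answer: `jax.lax.switch` picks the function by index directly.

```python
import jax
import jax.numpy as jnp


def foo1(x):
    return 1 * x


def foo2(x):
    return 2 * x


@jax.vmap
def bar_list(x, idx):
    # Fails: `y` is a plain Python list and, under vmap, `idx` is a tracer,
    # so Python's list indexing cannot turn it into an int.
    y = [foo1(x), foo2(x)]
    return y[idx]


@jax.vmap
def bar_array(x, idx):
    # Works: stacking the candidates gives a JAX array, which supports
    # indexing with a traced integer.
    y = jnp.stack([foo1(x), foo2(x)])
    return y[idx]


@jax.vmap
def bar_switch(x, idx):
    # Alternative (not part of the original answer): let lax.switch pick the
    # function by index. With a batched `idx`, vmap evaluates both branches
    # and selects per element.
    return jax.lax.switch(idx, [foo1, foo2], x)


x = jnp.array([0.0, 1.0, 2.0])
idx = jnp.array([0, 1, 0])

list_failed = False
try:
    bar_list(x, idx)
except TypeError as err:  # assumed: JAX's tracer conversion errors subclass TypeError
    list_failed = True
    print("list version failed:", type(err).__name__)

fixed = bar_array(x, idx)
switched = bar_switch(x, idx)
print(fixed)     # values 0, 2, 2
print(switched)  # same values as `fixed`
```

Because `idx` is batched in `bar_switch`, `vmap` turns the switch into an elementwise select, so every branch is still computed; for two cheap functions this costs about the same as stacking.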