
linen.vmap gives error "TypeError: list indices must be integers or slices, not BatchTracer" #1148

Answered by marcvanzee
marcvanzee asked this question in Q&A

Answer by @levskaya:

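You're trying to index into a Python list with a tracer. Inside jax.vmap, idx is a BatchTracer, and a Python list only accepts integers or slices as indices. Put the values you want to index into a jnp.array(...) so that JAX performs the indexing: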

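import jax
import jax.numpy as jnp

def foo1(x):
  return 1*x

def foo2(x):
  return 2*x

@jax.vmap
def bar(x, idx):
  # Stack the candidates into a jnp.array so the traced idx can index it.
  y = jnp.array([foo1(x), foo2(x)])
  return y[idx]

x = jnp.array([0.,1.,2.])
idx = jnp.array([0,1,0])
bar(x, idx)  # values: [0., 2., 2.]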
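For comparison, here is a minimal sketch that reproduces the original failure with a plain Python list and shows another way to choose a function by a traced index: jax.lax.switch. lax.switch is not part of the answer above. It is an alternative that relies on the assumption that you only need one branch's result per element. When the index is batched under vmap, lax.switch is lowered to a select, so every branch is still evaluated, much like the jnp.array approach. The names bar_list and bar_switch are hypothetical. The exact error message depends on the JAX version, but it is a TypeError either way.

```python
import jax
import jax.numpy as jnp


def foo1(x):
  return 1 * x


def foo2(x):
  return 2 * x


def bar_list(x, idx):
  # Hypothetical failing version: indexes a plain Python list with idx,
  # which is a tracer under vmap, so Python raises a TypeError.
  y = [foo1(x), foo2(x)]
  return y[idx]


def bar_switch(x, idx):
  # Hypothetical alternative: lax.switch picks the branch by integer index.
  # With a batched idx under vmap, all branches run and the results are selected.
  return jax.lax.switch(idx, [foo1, foo2], x)


x = jnp.array([0., 1., 2.])
idx = jnp.array([0, 1, 0])

try:
  jax.vmap(bar_list)(x, idx)
except TypeError as e:
  print("TypeError:", e)

print(jax.vmap(bar_switch)(x, idx))  # same values as bar above: [0. 2. 2.]
```

Both fixes give the same result here. The jnp.array version is the simplest choice when the candidates are cheap values. lax.switch keeps the code closer to "call function number idx".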

Answer selected by marcvanzee