How to vmap? - additional and different dimensional arrays #5389

Answered by harsh306
harsh306 asked this question in Q&A

I think I misunderstood how to use vmap. A few things I learned:

  • The function you pass to grad must return a scalar value, not a vector, not even one of shape (1,).
  • You can pass all your parameters to the objective function of your choice as a single list. The arrays in the list can have arbitrary shapes and grad still works.
  • You can even provide multiple lists of parameters or JAX arrays as separate arguments (the example below does this).
  • The advantage of passing them as separate arguments is that you can choose which arguments to compute gradients for.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import jax
import functools
p = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0])
params = [p, b]

def func(p, b, inputs, outputs):
   …
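
Since the body of func is elided above, here is a minimal sketch of the same points with a stand-in objective. The linear model, the mean-squared-error loss, the sample inputs and outputs, and the names loss and loss_from_list are all assumptions made for illustration, not the original code. It shows a scalar-valued objective, selecting arguments with argnums, passing the parameters as one list, and (as one possible answer to the vmap question in the title) mapping grad over the batch axis with in_axes.

```python
import jax
import jax.numpy as jnp

p = jnp.array([1.0, 2.0, 3.0])   # weights, shape (3,)
b = jnp.array([1.0])             # bias, shape (1,)

# Hypothetical data, chosen only so the gradients are easy to check by hand.
inputs = jnp.array([[1.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0]])
outputs = jnp.array([1.0, 1.0])

def loss(p, b, inputs, outputs):
    # Stand-in objective (assumption): linear model with mean squared error.
    preds = inputs @ p + b[0]
    # jnp.mean reduces to shape (), which is what grad requires.
    return jnp.mean((preds - outputs) ** 2)

# Separate arguments: argnums selects which ones to differentiate.
grad_p = jax.grad(loss, argnums=0)(p, b, inputs, outputs)
grad_p_b = jax.grad(loss, argnums=(0, 1))(p, b, inputs, outputs)

# One list of differently shaped arrays: the gradient comes back
# as a list with the same structure.
def loss_from_list(params, inputs, outputs):
    p, b = params
    return loss(p, b, inputs, outputs)

grad_list = jax.grad(loss_from_list)([p, b], inputs, outputs)

# Per-example gradients with respect to p: parameters are shared (None),
# inputs and outputs are mapped over their leading axis (0).
per_example = jax.vmap(jax.grad(loss), in_axes=(None, None, 0, 0))(
    p, b, inputs, outputs
)
print(per_example.shape)  # (2, 3): one gradient row per example
```

With argnums=(0, 1) the result is a tuple with one gradient per selected argument, each shaped like the argument it belongs to. Averaging the per-example gradients over the batch axis recovers the full-batch gradient for this particular loss, because the loss is a mean over examples.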

Answer selected by harsh306