mathisgerdes/jax-autovmap


Automatically broadcast via JAX vmap

Automatically broadcast a function that takes inputs of specific ranks so that it accepts arbitrary batch dimensions, using jax.vmap. See the JAX documentation for more information about vmap.

This module defines a decorator which takes as input the base ranks of a function's arguments. The transformed function accepts any broadcastable combination of batched inputs and applies jax.vmap as required. One could equivalently apply vmap by hand; however, the in_axes argument has to be chosen differently depending on which inputs are batched. The decorator takes care of this automatically: if the base ranks are known, the batch dimensions can be inferred from the input shapes.

Example

Consider the following function, which takes numeric arguments of fixed and known rank:

import jax.numpy as jnp

def foo(s, v, m):  # s - scalar, v - vector, m - matrix
    return v @ m @ v + s * v.size

If we have inputs of appropriate rank, the function can be applied without problem:

s = jnp.array(2.0)    # scalar
v = jnp.ones(3)       # vector
m = jnp.ones((3, 3))  # matrix
print(foo(s, v, m))   # prints 15.0

Assume now, however, that we have 5 matrices and 5 vectors to which we want to apply the above function as a batch. Many numpy functions accept inputs with leading batch dimensions, but here we run into an issue because v @ m @ v requires m to be a single matching matrix.

s = jnp.array(2.0)
v = jnp.ones((5, 3))
m = jnp.ones((5, 3, 3))
print(foo(s, v, m))  # throws TypeError

There are several ways to solve this:

  • We could write the function more carefully so that it accepts both single and batched inputs, or we could always require the first dimension to be a batch index (as is often done in convolutional neural networks).
  • Given known inputs, we can transform the function by hand: jax.vmap(foo, in_axes=(None, 0, 0)). However, the axes change depending on which inputs are batched (see the sketch below). If we want to expose functions that accept batched inputs to users, we need a clear naming scheme to indicate which arguments are batched.
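
To make the second option concrete, here is how the in_axes choice differs between two batching scenarios, using jax.vmap directly on the foo defined above:

import jax

# Batch over v and m (the situation above); s stays unbatched:
foo_batched = jax.vmap(foo, in_axes=(None, 0, 0))

# If instead only m were batched, different axes would be needed:
foo_batched_m = jax.vmap(foo, in_axes=(None, None, 0))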

Sometimes one of the above is the best solution. This module provides another, more flexible option: if the base ranks the function expects are known, we can infer which arguments carry (leading) batch dimensions and apply jax.vmap accordingly. That is exactly what the auto_vmap wrapper does. Thanks to jax.jit, once the transformed function is JIT-compiled there is no price to pay for this extra flexibility, since the dispatch depends only on the statically known input shapes.
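
To illustrate the idea, here is a minimal sketch (not the library's actual implementation; it only handles a single leading batch dimension per argument, while auto_vmap also supports multiple broadcastable batch dimensions, as described above):

import jax

def infer_in_axes(args, base_ranks):
    # An argument whose rank exceeds its base rank has a leading batch dimension.
    return tuple(0 if a.ndim > r else None for a, r in zip(args, base_ranks))

in_axes = infer_in_axes((s, v, m), (0, 1, 2))  # (None, 0, 0) for the inputs above
print(jax.vmap(foo, in_axes=in_axes)(s, v, m))  # prints [15. 15. 15. 15. 15.]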

We can define the more flexible function as follows:

from jax_autovmap import auto_vmap

@auto_vmap(s=0, v=1, m=2)
def foo(s, v, m):
    return v @ m @ v + s * v.size

print(foo(s, v, m))  # prints [15. 15. 15. 15. 15.]

The ranks can be specified by keyword argument, as above, or positionally (in this case @auto_vmap(0, 1, 2)). Ranks do not have to be given for all arguments; they can either be omitted or, if the ranks are given positionally, specified as None.
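
For instance, assuming (per the description above) that an argument whose rank is given as None is passed through to the function unchanged, a positional specification might look like this (weighted_sum is a hypothetical example, not part of the library):

import jax.numpy as jnp
from jax_autovmap import auto_vmap

@auto_vmap(0, 1, None)  # no rank given for offset, so it is never batched
def weighted_sum(s, v, offset):
    # s and v may carry batch dimensions; offset is passed through as-is.
    return s * jnp.sum(v) + offset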

If an argument is a pytree (a Python structure of arrays) and its rank is given as a single integer, all of its constituents (leaves) are assumed to have that rank. Alternatively, the rank can be a matching pytree, just like in_axes in jax.vmap:

@auto_vmap({'s': 0, 'v': 1, 'm': 2})
def foo(inputs):
    return inputs['v'] @ inputs['m'] @ inputs['v']

print(foo(dict(s=s, v=v, m=m)))  # prints [9. 9. 9. 9. 9.]
