I'm trying to use JAX to sum over a very large array (roughly 1 trillion elements, which I can't fully materialize), and I'm not sure how to use jax.sharding.shard_map and jax.vmap to compute this sum efficiently, as I'm quite new to JAX.
My ideal solution would iterate as follows: create a subset of this very long array, compute all terms in the subset via jax.vmap, and distribute the subsets over multiple GPUs (or at least have 8 batches computed on 8 GPUs at the same time).
I've attached a minimal reproducible example below to better explain what I'm planning to do; any help would be greatly appreciated.
import jax
from jax import numpy as jnp
jax.config.update('jax_enable_x64', True)
from functools import partial
from tqdm import tqdm

N = 20             # number of samples is 2^N (ideally N=40, but N=20 is quick to test)
nchunks = 100_000  # 'chunk' the exponentially large array into 'manageable' subsets

indices = jnp.linspace(0, 2**N, nchunks + 1, dtype=jnp.int64)    # boundary indices delimiting each subset
bounds = jnp.stack(arrays=(indices[:-1], indices[1:]), axis=-1)  # start/stop of each chunk of `indices`
answer = jnp.zeros(shape=(), dtype=jnp.float64)                  # accumulator for intermediate results

@partial(jax.jit, static_argnums=(1,))
def single_term(index: jax.Array, N: int) -> jax.Array:
    return 0.5 * index**2 + N  # NOTE: simple example function

@partial(jax.jit, static_argnums=(1,))
def sum_over_subset(indices: jax.Array, N: int) -> jax.Array:
    outputs = jax.vmap(single_term, in_axes=(0, None))(indices, N)  # vectorize over the subset
    return jnp.sum(outputs)

for (min_idx, max_idx) in tqdm(desc='A very long loop: ', iterable=bounds):
    indices = jnp.arange(start=min_idx, stop=max_idx, dtype=jnp.int64)[:, None]
    _tmp = sum_over_subset(indices, N)  # sums over one subset, returns a scalar
    answer = answer + _tmp              # NOTE: add to the 'global' running sum
print('Answer: ', answer)  # returns: 1.9215330924432746e+17
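One way the per-chunk sum above could be distributed across devices is with jax.experimental.shard_map: each device receives a shard of a chunk's indices, sums its shard locally with jax.vmap, and the per-device partial sums are combined with jax.lax.psum. A minimal sketch under the assumption that the chunk length divides evenly by the number of devices (on a single device this degenerates to a plain vmapped sum; `sharded_sum` and `chunk` are illustrative names):

```python
import numpy as np
import jax
from jax import numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from functools import partial

jax.config.update('jax_enable_x64', True)

N = 20

def single_term(index, N):
    return 0.5 * index**2 + N  # same toy term as above

# 1-D mesh over all available devices (GPUs, or CPU when testing)
mesh = Mesh(np.array(jax.devices()), axis_names=('i',))

@jax.jit
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def sharded_sum(indices):
    # each device sums its local shard of `indices` ...
    local = jnp.sum(jax.vmap(single_term, in_axes=(0, None))(indices, N))
    # ... and the per-device partial sums are combined across the mesh
    return jax.lax.psum(local, axis_name='i')

chunk = jnp.arange(0, 1024, dtype=jnp.int64)  # length must divide evenly by the device count
print(sharded_sum(chunk))  # 178715392.0 for this chunk (0.5 * sum i^2 + 1024 * N)
```

The full loop over `bounds` would then call `sharded_sum` on each chunk and accumulate, exactly as the loop above does with `sum_over_subset`.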
I did read through https://jax.readthedocs.io/en/latest/sharded-computation.html, but if I'm not mistaken that approach requires fully materializing the input data; since my data comes from jax.numpy.arange, I hit a jax.errors.ConcretizationTypeError when I tried to jax.jit my function.
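The ConcretizationTypeError arises because jnp.arange needs concrete start/stop values under jit: they determine the output shape, which must be known at trace time. One common workaround is to fix the chunk length as a static argument and pass only the (traced) offset. A small sketch, with `chunk_indices` an illustrative name:

```python
import jax
from jax import numpy as jnp
from functools import partial

jax.config.update('jax_enable_x64', True)

@partial(jax.jit, static_argnums=(1,))
def chunk_indices(offset, chunk_len):
    # chunk_len is static, so the arange has a shape known at trace time;
    # the traced offset is then added element-wise without any concretization
    return jnp.arange(chunk_len, dtype=jnp.int64) + offset

idx = chunk_indices(jnp.int64(2048), 1024)  # indices 2048 .. 3071, no error under jit
```

Because every chunk has the same static length, this also retraces only once rather than per chunk.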