How to implement a JIT-compatible groupby? #4849
-
I'd like to implement a generic group-by-then-aggregate in JAX with semantics similar to the function below:

```python
import jax
import jax.numpy as jnp

def groupby_agg(keys, vals, agg_fn):
    return jnp.stack([agg_fn(vals[keys == i]) for i in jnp.unique(keys)])

keys = jnp.array([0, 1, 0, 2, 1])
vals = jnp.array([1.1, 1.2, 2.1, 2.2, 3.1])
groupby_agg(keys, vals, agg_fn=jnp.mean)
# DeviceArray([1.5999999, 2.15, 2.2], dtype=float32)
```

JAX doesn't support dynamic shapes or boolean-mask indexing under JIT, so I thought something like this might work:

```python
def groupby_agg2(keys, vals, unique_keys, agg_fn):
    def helper(i):
        arr = jnp.where(keys == i, vals, jnp.inf).sort()
        return agg_fn(jax.lax.dynamic_slice(arr, [0], [arr.searchsorted(jnp.inf)]))
    return jax.vmap(helper)(unique_keys)

groupby_agg2(keys, vals, jnp.unique(keys), agg_fn=jnp.mean)
```

But this fails: the slice size passed to `jax.lax.dynamic_slice` must be static, while `arr.searchsorted(jnp.inf)` is a traced value. For a mean specifically, masking works:

```python
def groupby_mean(keys, vals, unique_keys):
    def helper(i):
        mask = keys == i
        return jnp.where(mask, vals, 0).sum() / mask.sum()
    return jax.vmap(helper)(unique_keys)

groupby_mean(keys, vals, jnp.unique(keys))
# DeviceArray([1.5999999, 2.15, 2.2], dtype=float32)
```

However, this is harder for more complicated aggregation functions, since a suitable "filler" value for the masked-out entries (like 0 for a sum) doesn't always exist. Are there any dynamic masking functions available in JAX? If not, is there another way to accomplish this? Thanks in advance!
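On the masking question: many `jax.numpy` reductions (`jnp.sum`, `jnp.mean`, `jnp.var`, `jnp.max`, ...) accept a `where=` boolean mask, which sidesteps choosing a filler value by hand. The sketch below assumes the keys are already integer codes in `[0, num_groups)` and that `num_groups` is known statically; `groupby_masked` and `num_groups` are hypothetical names, not an existing API. Reductions without an identity, such as `jnp.max`, need an explicit `initial` when `where` is given.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_groups")
def groupby_masked(keys, vals, num_groups):
    """Per-group mean, max and variance via masked reductions (hypothetical helper).

    Assumes `keys` holds integer codes in [0, num_groups). Because `num_groups`
    is static, every output has the fixed shape (num_groups,).
    """
    group_ids = jnp.arange(num_groups)

    def per_group(k):
        mask = keys == k  # boolean mask with the same static shape as vals
        mean = jnp.mean(vals, where=mask)  # nan if the group is empty
        vmax = jnp.max(vals, where=mask, initial=-jnp.inf)  # max has no identity, so pass initial
        var = jnp.var(vals, where=mask)
        return mean, vmax, var

    return jax.vmap(per_group)(group_ids)


keys = jnp.array([0, 1, 0, 2, 1])
vals = jnp.array([1.1, 1.2, 2.1, 2.2, 3.1])
mean, vmax, var = groupby_masked(keys, vals, num_groups=3)
```

Every group scans the full `vals` array here, so the work is O(num_groups × n); that is the price of keeping all shapes static.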
-
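I just noticed that some search results redirect to this issue. Here is another implementation, for a simple CTC greedy decoder, that keeps the output size fixed by marking dropped positions with `-1` instead of removing them:

```python
import jax
import jax.numpy as jnp

@jax.jit
def ctc_greedy_decoder(logits, blank=0):
    # logits: (time, num_classes) -> most likely label per time step
    labels = jnp.argmax(logits, axis=-1)
    # True where the label differs from the previous step (the first step always counts)
    changes = jnp.concatenate([jnp.array([True]), labels[1:] != labels[:-1]])
    # Keep new non-blank labels; mark the rest with -1 so the output shape stays static
    return jnp.where(changes & (labels != blank), labels, -1)
```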
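A small usage sketch of the decoder, assuming a `(time, num_classes)` logits array. If a left-packed result is needed, one option (an addition here, not part of the reply) is `jnp.nonzero` with a static `size`, followed by a gather that fills out-of-range indices; `compact` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def ctc_greedy_decoder(logits, blank=0):
    labels = jnp.argmax(logits, axis=-1)
    changes = jnp.concatenate([jnp.array([True]), labels[1:] != labels[:-1]])
    return jnp.where(changes & (labels != blank), labels, -1)


@jax.jit
def compact(decoded):
    """Move kept labels to the front and pad with -1 (hypothetical helper)."""
    n = decoded.shape[0]  # static under jit
    # Indices of kept entries, padded to length n with the out-of-range index n.
    (idx,) = jnp.nonzero(decoded >= 0, size=n, fill_value=n)
    # mode="fill" returns fill_value for out-of-range indices instead of clamping.
    return jnp.take(decoded, idx, mode="fill", fill_value=-1)


# Per-step labels 1, 1, 0 (blank), 2, 2, 1 encoded as one-hot "logits".
logits = jax.nn.one_hot(jnp.array([1, 1, 0, 2, 2, 1]), 3)
decoded = ctc_greedy_decoder(logits)
packed = compact(decoded)
```

Both outputs keep the input length, so the whole pipeline stays JIT-compatible; only the `-1` padding tells you where the real labels end.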
-
For a summing aggregator you can use `jnp.bincount` with the `weights` argument. You do need to pass the maximum number of unique keys (via `length`) and to encode the keys as integers. Does that do it?
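A minimal sketch of the `bincount` suggestion, assuming integer keys in `[0, num_groups)` with `num_groups` known statically (`groupby_sum_mean`, `groupby_minmax` and `num_groups` are hypothetical names). Under `jit`, `length` must be a static int because it fixes the output shape; dividing the weighted count by the plain count gives the mean. For min/max, `jax.ops.segment_min` and `jax.ops.segment_max` follow the same fixed-size pattern; that is a swapped-in technique, not something the reply mentions.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="num_groups")
def groupby_sum_mean(keys, vals, num_groups):
    """Per-group sum and mean via jnp.bincount (hypothetical helper).

    `length` must be static under jit; it fixes the output shape to (num_groups,).
    """
    sums = jnp.bincount(keys, weights=vals, length=num_groups)
    counts = jnp.bincount(keys, length=num_groups)
    return sums, sums / counts  # an empty group gives 0 / 0 = nan


@partial(jax.jit, static_argnames="num_groups")
def groupby_minmax(keys, vals, num_groups):
    """Per-group min and max via jax.ops segment reductions (hypothetical helper).

    Empty groups get the reduction's identity (+inf for min, -inf for max on floats).
    """
    mins = jax.ops.segment_min(vals, keys, num_segments=num_groups)
    maxes = jax.ops.segment_max(vals, keys, num_segments=num_groups)
    return mins, maxes


keys = jnp.array([0, 1, 0, 2, 1])
vals = jnp.array([1.1, 1.2, 2.1, 2.2, 3.1])
sums, means = groupby_sum_mean(keys, vals, num_groups=3)
mins, maxes = groupby_minmax(keys, vals, num_groups=3)
```

Unlike the masked-reduction approach, these scatter each element into its group once, so the work is O(n) regardless of how many groups there are.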