Hi all, I have a problem that effectively amounts to constructing a sparse matrix, and I'd like to see if there is a more flexible solution that is JIT-compatible. I'm aware of the usual requirement that shapes be known at compile time, but I was curious whether there is a way to handle dynamic shapes with known maximums more effectively than I currently do.
The general problem is something along the lines of the following:
```python
import jax.numpy as jnp
from jax.typing import ArrayLike

# all four have shape (p,)
mins: ArrayLike
maxs: ArrayLike
vec: ArrayLike
index_map: ArrayLike

final = jnp.zeros((p, p))
for j in range(p):
    min_j, max_j = mins[j], maxs[j]
    intermediate = f(vec[j], min_j, max_j)  # shape (max_j - min_j + 1,)
    map_j = index_map[min_j : max_j + 1]
    final = final.at[j, map_j].add(intermediate)
```
Naively implementing the above with `lax.fori_loop` won't compile because of the dynamic shapes inside the loop. One workaround is to keep the shape fixed to the original `(p,)` and zero out the unused parts for that iteration.
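A minimal sketch of that fixed-`(p,)` workaround. The real `f` isn't shown, so this assumes a hypothetical `f_elem(v, offset)` such that `f(vec[j], min_j, max_j)` equals `f_elem(vec[j], jnp.arange(max_j - min_j + 1))`, i.e. `f` is elementwise in the position offset:

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_elem(v, offset):
    # Hypothetical stand-in for `f`, assumed elementwise in the offset
    # from min_j. Replace with the real computation.
    return v * (offset + 1.0)


@jax.jit
def build_masked_full(mins, maxs, vec, index_map):
    p = vec.shape[0]
    positions = jnp.arange(p)  # static shape (p,)

    def body(j, final):
        min_j, max_j = mins[j], maxs[j]
        in_range = (positions >= min_j) & (positions <= max_j)
        # Evaluate at all p positions, then zero out the ones outside [min_j, max_j].
        intermediate = jnp.where(in_range, f_elem(vec[j], positions - min_j), 0.0)
        # Position k maps to column index_map[k]; masked entries add 0.
        return final.at[j, index_map].add(intermediate)

    return lax.fori_loop(0, p, body, jnp.zeros((p, p)))
```

Every iteration touches all `p` columns, which is why this scales poorly when `p` is much larger than the per-row window.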
However, in my case `p >> max(maxs - mins)`, and this causes a dramatic slowdown. Rather than using the original shape and zeroing out the values outside the current iteration's bound, with a bit of work we can extend each iteration's bound to the maximum bound size, so that every iteration works on the same fixed shape. This isn't memory-optimal, but the maximum bound typically isn't much larger than the smallest bound.
Unfortunately, this won't work because of the dynamic lookup/shape, even though the window is "fixed" to `(bound,)`.
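The failure is probably coming from Python slice syntax. Assuming the window was taken as something like `index_map[min_j:min_j + bound]`, JAX rejects slice start/stop values that are tracers, even when the resulting length is constant. A sketch that keeps the window length static while letting the start be traced builds explicit indices `min_j + jnp.arange(bound)` and gathers with them. (`lax.dynamic_slice` also accepts a traced start, but it clamps the start so the whole window stays in bounds, which would shift the window near the end of `index_map`.) `f_elem` is a hypothetical elementwise stand-in for `f`, and `bound` is assumed to be `max(maxs - mins) + 1`:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def f_elem(v, offset):
    # Hypothetical stand-in for `f`, assumed elementwise in the offset.
    return v * (offset + 1.0)


@partial(jax.jit, static_argnames="bound")
def build_windowed(mins, maxs, vec, index_map, bound):
    p = vec.shape[0]
    window = jnp.arange(bound)  # static shape (bound,)

    def body(j, final):
        min_j, max_j = mins[j], maxs[j]
        # Traced start, static length: explicit indices instead of a slice.
        # Clip so the gather stays in bounds; clipped slots are masked below.
        idx = jnp.clip(min_j + window, 0, p - 1)
        valid = window <= (max_j - min_j)
        intermediate = jnp.where(valid, f_elem(vec[j], window), 0.0)
        return final.at[j, index_map[idx]].add(intermediate)

    return lax.fori_loop(0, p, body, jnp.zeros((p, p)))
```

Each iteration now does `O(bound)` work instead of `O(p)`.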
Are there any other ways to wrangle this into something JIT-compatible? Would defining another internal function with `static_argnames` for the bound work?
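On the `static_argnames` question: marking `bound` static lets it be used as a Python int in array shapes inside the jitted function, provided it is computed outside `jit` from concrete `mins`/`maxs`. Each new value of `bound` triggers a recompile. The sketch below assumes a hypothetical elementwise `f_elem` in place of `f`, and also swaps the sequential loop for `jax.vmap` over rows followed by a single scatter-add:

```python
from functools import partial

import jax
import jax.numpy as jnp


def f_elem(v, offset):
    # Hypothetical stand-in for `f`, assumed elementwise in the offset.
    return v * (offset + 1.0)


@partial(jax.jit, static_argnames="bound")
def build_vmapped(mins, maxs, vec, index_map, bound):
    p = vec.shape[0]
    window = jnp.arange(bound)  # static shape (bound,)

    def one_row(min_j, max_j, v):
        idx = jnp.clip(min_j + window, 0, p - 1)
        valid = window <= (max_j - min_j)
        vals = jnp.where(valid, f_elem(v, window), 0.0)
        return index_map[idx], vals  # column indices and values, each (bound,)

    cols, vals = jax.vmap(one_row)(mins, maxs, vec)  # each (p, bound)
    rows = jnp.arange(p)[:, None]  # broadcasts against cols
    # One scatter-add instead of p sequential updates.
    return jnp.zeros((p, p)).at[rows, cols].add(vals)


def build(mins, maxs, vec, index_map):
    # bound must be a concrete Python int, so compute it outside jit.
    bound = int(jnp.max(maxs - mins)) + 1
    return build_vmapped(mins, maxs, vec, index_map, bound=bound)
```

Because every row's window has the same static length, the rows are independent and can be batched. Duplicate `(row, col)` pairs are summed by `.add`, matching the original loop.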