Is it possible to do two shard maps at the same time? #23015
Unanswered
jeffgortmaker asked this question in Q&A
Replies: 1 comment 7 replies
-
xref #21004, which is a similar question. Maybe @yashk2810 has some ideas here?
-
I have data indexed by two dimensions, and I'm trying to shard computation along both of them at the same time. My understanding is that the current recommendation is to use shard maps, but I'm struggling with how to use them along two dimensions.
Here's a minimal working example of what I'm trying to do. My goal in function `h` is to map a second function `g` over the first dimension `a` of a 2x2 array. Within `g`, I attempt to map a third function `f` over the array's second dimension `b`. Each map applies `shard_map` to `lax.map`.

With the current JAX version, 0.4.31, this raises an error. If I set `check_rep=False` in each `shard_map`, it instead raises a different error.

If I replace `h` with `jax.lax.map(g, x) + 3`, I get a 2x2 array of sixes as expected, but computation is only sharded over the second axis.

My guess is there's something terribly wrong with how I'm trying to do this. Any advice would be wonderful!