Hey guys, I love both Flax and JAX. I find them a breath of fresh air in the deep learning ecosystem, and I want to congratulate you all on a job well done. I have one major question that I can't find answered in either the Flax or the JAX docs: how can I compile functions with dynamically sized inputs faster? From the JAX documentation and my own experiments, XLA compiles a function once per array shape, and in my case I can't just give it a more abstract tensor, because the output array size also depends on the input size.
Yes, you're right: this is a fundamental design property of XLA, which assumes fully known shapes (and uses that information for compile-time optimizations).
The typical pattern we use is dynamic bucketing, where sequences within certain length ranges are batched together. For example, see https://github.com/google/flax/blob/master/examples/wmt/input_pipeline.py#L204 (in this case we're taking advantage of `tf.data.experimental.bucket_by_sequence_length`, which is not ideal). We'd love to find long-term flexible solutions, but they also need to be fast, so it's not clear that one can implement this in pure Python and still support very large models.
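The same bucketing idea can also be sketched directly in JAX without tf.data. Pad each input up to the nearest of a few fixed lengths, pass the true length as a traced (non-static) argument, mask the padding, and slice the output back to the true length outside of `jit`. The slicing step also covers the case from the question where the output size depends on the input size. The bucket boundaries, `model`, and the helper names here are hypothetical placeholders. Masking with `jnp.where` is one choice among several, not the approach used in the Flax WMT example.

```python
import bisect

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical bucket boundaries; in practice pick them from your length distribution.
BUCKETS = (8, 16, 32, 64)

# Records each input shape that jit traces (and therefore compiles) for.
traced_shapes = set()


def bucket_length(n):
    """Return the smallest bucket boundary >= n."""
    i = bisect.bisect_left(BUCKETS, n)
    if i == len(BUCKETS):
        raise ValueError(f"length {n} exceeds the largest bucket {BUCKETS[-1]}")
    return BUCKETS[i]


def pad_to_bucket(seq):
    """Zero-pad a 1-D sequence up to its bucket length; also return the true length."""
    n = len(seq)
    padded = np.zeros(bucket_length(n), dtype=np.float32)
    padded[:n] = seq
    return padded, n


@jax.jit
def model(x, length):
    # Hypothetical stand-in for a real model.
    traced_shapes.add(x.shape)  # runs only while tracing a new shape
    # `length` is a traced scalar, not static, so all lengths in a bucket share one compile.
    mask = jnp.arange(x.shape[0]) < length
    return jnp.where(mask, x * 2.0, 0.0)


def run(seq):
    padded, n = pad_to_bucket(seq)
    out = model(jnp.asarray(padded), n)
    # Slice back to the true length outside jit, where dynamic sizes are fine.
    return out[:n]
```

For example, lengths 3, 5, 9, 12 and 20 map to buckets 8, 8, 16, 16 and 32, so `model` is compiled three times instead of five. Fewer, coarser buckets mean fewer compilations but more wasted computation on padding.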