Recommended way to use functions producing tuples with pytrees #5881
What is the recommended way to apply a function that produces a tuple of outputs to a pytree? For example, an Adam update might take three pytrees holding the parameters and the first/second moment estimates and produce updated values for all three. If you use …

Edit: Below is a minimal example of what I'd like to do. Given a function that updates two pytrees in a way that depends on both, the best way I could figure out to implement it is to break the update function into two parts. This requires more code, and I think it may also be worse memory-wise, since references to all four pytrees are alive at some point (unless this goes away when you JIT everything). Is there a way to accomplish this without modifying the update function?

```python
import jax
import jax.numpy as jnp

a_pytree = (jnp.zeros(3), [jnp.ones(4), {'foo': jnp.zeros(5)}])
b_pytree = jax.tree_util.tree_map(jnp.ones_like, a_pytree)

def update_two_values(a, b):
    return a + b, a * b

# Would like to do this (older JAX versions spelled it jax.tree_multimap),
# but tree_map returns ONE pytree whose leaves are (a + b, a * b) tuples:
# new_a, new_b = jax.tree_util.tree_map(update_two_values, a_pytree, b_pytree)

# Instead I have to break the function up into two calls
def update_a(a, b):
    return a + b

def update_b(a, b):
    return a * b

new_a = jax.tree_util.tree_map(update_a, a_pytree, b_pytree)
new_b = jax.tree_util.tree_map(update_b, a_pytree, b_pytree)
a_pytree = new_a
b_pytree = new_b
```
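One way to avoid splitting `update_two_values` is to map it once, which yields a pytree of tuples, and then transpose that into a tuple of pytrees with `jax.tree_util.tree_transpose`. The sketch below is one possible approach, not the answer given in this thread. `tree_map_unzip` and `adam_leaf_update` are hypothetical names introduced for illustration, the caller is assumed to pass the number of outputs explicitly, and the Adam hyperparameters are illustrative defaults. `tree_map`, `tree_structure`, and `tree_transpose` are existing `jax.tree_util` functions.

```python
import functools

import jax
import jax.numpy as jnp


def tree_map_unzip(f, num_outputs, tree, *rest):
    """Hypothetical helper: map a tuple-returning leaf function over pytrees
    and return a tuple of pytrees. `f` must return exactly `num_outputs`
    values for every leaf."""
    # Ordinary tree_map: each leaf position now holds a tuple of outputs.
    tree_of_tuples = jax.tree_util.tree_map(f, tree, *rest)
    # Outer structure = the input pytree; inner structure = a flat tuple.
    outer_treedef = jax.tree_util.tree_structure(tree)
    inner_treedef = jax.tree_util.tree_structure((0,) * num_outputs)
    # Swap the two levels: "pytree of tuples" -> "tuple of pytrees".
    return jax.tree_util.tree_transpose(outer_treedef, inner_treedef, tree_of_tuples)


# Two-output case from the question, with update_two_values left unchanged.
a_pytree = (jnp.zeros(3), [jnp.ones(4), {"foo": jnp.zeros(5)}])
b_pytree = jax.tree_util.tree_map(jnp.ones_like, a_pytree)


def update_two_values(a, b):
    return a + b, a * b


new_a, new_b = tree_map_unzip(update_two_values, 2, a_pytree, b_pytree)


# Three-output case: a textbook Adam step written for a single leaf
# (hypothetical helper; hyperparameter defaults are illustrative).
def adam_leaf_update(p, g, m, v, *, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, t=1):
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    m_hat = m / (1 - b1**t)  # bias-corrected first moment
    v_hat = v / (1 - b2**t)  # bias-corrected second moment
    return p - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v


params = a_pytree
grads = b_pytree
zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
# Non-pytree arguments such as the step count are bound with functools.partial.
step = 1
new_params, new_m, new_v = tree_map_unzip(
    functools.partial(adam_leaf_update, t=step), 3, params, grads, zeros, zeros
)
```

The transpose only rearranges references to the arrays that `f` already produced; it does not copy them. Passing `num_outputs` explicitly keeps the helper simple and avoids guessing the tuple length from the outputs, which is ambiguous when the input pytree itself contains tuples (as `a_pytree` does).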
Thanks for the question! Could you edit the question to add a short snippet of code demonstrating what you have in mind?