jittable pytree flattening/unflattening #13473

Answered by PhilipVinc
Goodbrake asked this question in Ideas

I have exactly the same use case as you in a package I develop.
I ended up using a slightly different version of the ravel function; see it here.

You can quickly check whether it works for you by using nk.jax.tree_ravel instead of the JAX-provided function.
If it does, you can extract it from there (note: if you are not using complex dtypes, you can safely replace nkjax.vjp with jax.vjp).
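If pulling in NetKet is not an option, the code below is a minimal sketch of the same idea built only on jax.tree_util. It is not NetKet's implementation. The names tree_ravel_sketch and Unraveler are hypothetical. It assumes that "jittable" means both the ravel step and the unravel step can be traced under jax.jit, and that the unravel function can be passed to jax.jit as a static argument. It also assumes real (non-complex) leaves, because casting a complex value back to a real dtype would drop its imaginary part.

```python
import math
from dataclasses import dataclass
from functools import partial
from typing import Any

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Unraveler:
    """Hypothetical hashable unravel callable (a sketch, not NetKet's code).

    As a frozen dataclass it derives __eq__ and __hash__ from its fields,
    so it can be passed to jax.jit as a static argument.
    """
    treedef: Any      # jax.tree_util PyTreeDef describing the structure
    shapes: tuple     # static shape of each leaf
    dtypes: tuple     # original dtype of each leaf

    def __call__(self, flat):
        leaves, offset = [], 0
        for shape, dtype in zip(self.shapes, self.dtypes):
            size = math.prod(shape)  # static Python int, so slicing is jit-safe
            leaves.append(flat[offset:offset + size].reshape(shape).astype(dtype))
            offset += size
        return jax.tree_util.tree_unflatten(self.treedef, leaves)


def tree_ravel_sketch(tree):
    """Flatten a pytree of real arrays into one 1-D array plus an Unraveler."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    shapes = tuple(tuple(jnp.shape(x)) for x in leaves)
    dtypes = tuple(jnp.result_type(x) for x in leaves)
    if not leaves:
        return jnp.zeros((0,), jnp.float32), Unraveler(treedef, shapes, dtypes)
    common = jnp.result_type(*leaves)  # promote mixed dtypes to one common dtype
    flat = jnp.concatenate([jnp.ravel(x).astype(common) for x in leaves])
    return flat, Unraveler(treedef, shapes, dtypes)


# Usage: unravel inside a jitted function, with the Unraveler as a static arg.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
flat, unravel = tree_ravel_sketch(params)


@partial(jax.jit, static_argnums=1)
def scale_and_unravel(flat, unravel):
    return unravel(2.0 * flat)


out = scale_and_unravel(flat, unravel)

# The ravel step can also be jitted, as long as only the array is returned.
flat_jit = jax.jit(lambda t: tree_ravel_sketch(t)[0])(params)
```

Shapes and dtypes are stored as plain Python data rather than arrays. As a result, the slice offsets are compile-time constants, and two Unravelers built from identically structured trees compare equal, so jax.jit can reuse its compiled function.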

Answer selected by Goodbrake