
Custom JAX type? #6231

Answered by lukepfister
helange23 asked this question in Q&A
Mar 25, 2021 · 1 comment · 1 reply

I've done something similar and asked for advice here. You'll need to register a number of handlers (see mattj's reply in #4269).
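The handlers described in #4269 hook into JAX internals whose names have shifted between releases, so they aren't reproduced here. If the custom type is mainly a container around arrays, a lighter alternative (a different technique from those handlers, using only the public API) is to register it as a pytree with jax.tree_util.register_pytree_node_class. jax.jit and jax.grad can then flatten it into its array leaves and rebuild it on the way out. A minimal sketch; the Tagged class and its tag field are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Tagged:
    """Hypothetical wrapper: one array leaf plus static metadata (not a JAX API)."""

    def __init__(self, data, tag):
        self.data = data  # array leaf: traced by jit/grad
        self.tag = tag    # static auxiliary data: must be hashable

    def tree_flatten(self):
        # Return (children, aux_data): children are traced, aux_data is static.
        return (self.data,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        # Rebuild the wrapper from the static tag and the (possibly traced) leaf.
        return cls(children[0], tag)


@jax.jit
def scale(x):
    # jit flattens x into its leaf and rebuilds a Tagged for the output.
    return Tagged(x.data * 2.0, x.tag)


def loss(x):
    return jnp.sum(x.data ** 2)


x = Tagged(jnp.arange(3.0), "meters")
y = scale(x)                # a Tagged, with tag preserved
g = jax.grad(loss)(x)       # the gradient has the same pytree structure as x
print(type(y).__name__, y.tag, y.data)
print(type(g).__name__, g.data)
```

This only teaches JAX transformations how to see inside the wrapper; it doesn't make jax.numpy functions accept Tagged directly.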

I also found this example very helpful: https://github.com/google/jax/blob/63c06ef77e84bb5b3582fe23b17d8dfd2f5ecd0c/tests/custom_object_test.py.

Unfortunately, I didn't find a good way to make my custom array class play nicely with jax.numpy in the way I'd like, because JAX doesn't implement __array_wrap__. See also: #4725
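To illustrate the limitation: without an __array_wrap__ hook, applying a jax.numpy function to the wrapped array hands back a plain jax.Array, and the wrapper type is lost. One workaround, assuming the wrapper is registered as a pytree, is to apply the function leaf-wise with jax.tree_util.tree_map, which rebuilds the container around the results. A sketch; Tagged is a hypothetical class, defined here so the snippet stands alone:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Tagged:
    """Hypothetical array wrapper with static metadata (not a JAX API)."""

    def __init__(self, data, tag):
        self.data = data
        self.tag = tag

    def tree_flatten(self):
        return (self.data,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(children[0], tag)


x = Tagged(jnp.array([0.0, 1.0]), "radians")

# jax.numpy works on the underlying array, but the result is a plain
# jax.Array: nothing restores the wrapper type afterwards.
plain = jnp.sin(x.data)

# tree_map applies jnp.sin to every leaf and rebuilds a Tagged around the
# results, carrying the static tag along.
wrapped = tree_map(jnp.sin, x)

print(type(plain).__name__, type(wrapped).__name__, wrapped.tag)
```

The catch is that every call site has to go through tree_map explicitly, which is exactly the ergonomics __array_wrap__ would otherwise provide.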

Answer selected by helange23