Custom JAX type? #6231
-
This might be a stupid question, but I was wondering how difficult it would be to write a custom JAX type. I am trying to make my code a bit more readable by creating something that behaves like a matrix but is slightly different: a data structure that allows non-integer indexing but is backed by a regular matrix. The simplest example would be a data structure that linearly interpolates between values when `__getitem__` is called with a float, i.e. something like:

```python
import jax
import jax.numpy as jnp
class myType(jax.interpreters.xla.DeviceArray):
    def __getitem__(self, idx):
        lo, hi = jnp.floor(idx).astype(int), jnp.ceil(idx).astype(int)  # integer neighbours of a float index
        return (super().__getitem__(lo) + super().__getitem__(hi)) / 2.
```

What I was trying to do was to override the …; except for the …, I probably should not be doing this, right? How bad would it be if I did it anyway, and how difficult would it be to actually achieve this?
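For comparison, here is a minimal sketch of the same idea built by composition instead of by subclassing `DeviceArray` (an internal class). A plain Python class holds a `jnp` array and is registered with `jax.tree_util.register_pytree_node_class`, so instances can be passed through `jax.jit`. The class name `InterpArray` is hypothetical, and it uses a weighted average of the two neighbouring entries (true linear interpolation) rather than the plain average above. This is an alternative under those assumptions, not the handler-registration approach discussed in the reply, and it does not turn the class into a real JAX array type.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class InterpArray:
    """Hypothetical wrapper: a regular array that accepts fractional indices."""

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        idx = jnp.asarray(idx, dtype=jnp.float32)
        lo = jnp.floor(idx).astype(jnp.int32)  # lower integer neighbour
        hi = jnp.ceil(idx).astype(jnp.int32)   # upper integer neighbour
        w = idx - lo                           # weight: 0 at lo, approaching 1 at hi
        return (1.0 - w) * self.data[lo] + w * self.data[hi]

    # Pytree protocol: the array is the only child, so jit traces it; no static aux data.
    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def sample(a, t):
    return a[t]


a = InterpArray(jnp.array([0.0, 10.0, 20.0]))
print(sample(a, 0.5))   # 5.0
print(sample(a, 1.25))  # 12.5
```

The trade-off is that `InterpArray` is not itself an array, so `jax.numpy` functions have to be applied to `a.data` rather than to `a`.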
-
I've done something similar and asked for advice here. You'll need to register a number of handlers (see mattj's reply #4269 (comment)).
I also found this example very helpful: https://github.com/google/jax/blob/63c06ef77e84bb5b3582fe23b17d8dfd2f5ecd0c/tests/custom_object_test.py.
Unfortunately, I didn't find a good way to make my custom array class play nice with `jax.numpy` in the way I'd like, as `__array_wrap__` isn't implemented. See also: #4725
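One possible partial workaround for the missing `__array_wrap__`, sketched under the assumption that the custom class is registered as a pytree (the class `Wrapped` below is hypothetical): map `jax.numpy` functions over its leaves with `jax.tree_util.tree_map`, which rebuilds an instance of the same class from the results. This only covers functions applied leaf by leaf to the wrapped data; a direct call such as `jnp.sqrt(w)` still does not return a `Wrapped`.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_map


@register_pytree_node_class
class Wrapped:
    """Hypothetical custom container around a single jnp array."""

    def __init__(self, data):
        self.data = data

    def tree_flatten(self):
        return (self.data,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


w = Wrapped(jnp.array([0.0, 1.0, 4.0]))

# tree_map applies the function to each leaf and re-wraps the result in Wrapped.
roots = tree_map(jnp.sqrt, w)
scaled = tree_map(lambda x: 2.0 * x, w)
print(type(roots).__name__, roots.data)
```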