I did play around with this a bit at one point, but I'm not aware of any worked examples. I'd recommend taking a look at JAX's new foreign function interface (FFI). I expect the easiest approach would be to write a C++ shim that uses the FFI to dispatch to Rust, rather than trying to reimplement the FFI in Rust directly. With the FFI's support for variadic arguments and results, you could probably write a fairly general-purpose interface. Hope this helps!
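Here is a minimal sketch of the Rust half of that approach. It assumes the C++ shim (not shown) registers a handler with the XLA FFI, unpacks a variable number of `f32` buffers using the FFI's variadic argument/result support, checks that they all share one length, and then forwards raw pointers across a plain C ABI. The function `rust_sum_kernel`, its signature and the status codes are hypothetical. The kernel only writes the elementwise sum of all inputs into every output, as a stand-in for real work.

```rust
use std::slice;

// Hypothetical status codes shared with the C++ shim.
pub const OK: i32 = 0;
pub const ERR_NULL: i32 = 1;

/// Hypothetical C-ABI entry point called by a C++ XLA FFI shim.
/// The shim forwards `n_args` input buffers and `n_rets` output buffers,
/// all holding `len` f32 elements; every output receives the elementwise
/// sum of all inputs.
///
/// # Safety
/// `args` and `rets` must point to `n_args` / `n_rets` buffer pointers,
/// each valid for `len` elements, and no output may alias an input or
/// another output.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn rust_sum_kernel(
    args: *const *const f32,
    n_args: usize,
    rets: *const *mut f32,
    n_rets: usize,
    len: usize,
) -> i32 {
    // Reject null pointer tables before dereferencing them.
    if (n_args > 0 && args.is_null()) || (n_rets > 0 && rets.is_null()) {
        return ERR_NULL;
    }

    // View each input buffer as a slice (variadic arguments).
    let mut inputs: Vec<&[f32]> = Vec::with_capacity(n_args);
    for k in 0..n_args {
        let p = unsafe { *args.add(k) };
        if p.is_null() {
            return ERR_NULL;
        }
        inputs.push(unsafe { slice::from_raw_parts(p, len) });
    }

    // Fill each output buffer (variadic results).
    for k in 0..n_rets {
        let p = unsafe { *rets.add(k) };
        if p.is_null() {
            return ERR_NULL;
        }
        let out = unsafe { slice::from_raw_parts_mut(p, len) };
        for (i, o) in out.iter_mut().enumerate() {
            *o = inputs.iter().map(|x| x[i]).sum();
        }
    }
    OK
}
```

Keeping the boundary this narrow means the Rust crate would need no XLA headers at all: the shim owns every FFI-specific detail and could be linked against a `staticlib` or `cdylib` build of the crate.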
-
Hi everyone!
I really enjoy JAX, and I also really enjoy programming in Rust, so I would love to be able to implement some performance-critical functions in Rust and expose them to Python.
Currently, I do this with PyO3 and rust-numpy, but I lose all the benefits of XLA (mainly automatic differentiation, in my case), so I eventually have to convert NumPy arrays to JAX arrays in Python.
I am aware of "Extending JAX with custom C++ and CUDA code" and "Custom operations for GPUs with C++ and CUDA", which are two great tutorials, but I feel there should be an easier way to write a JAX extension in Rust, especially given that tools and libraries like PyO3, maturin, rust-numpy, and xla-rs exist.
Have any of you already tried implementing a JAX (or XLA) extension in Rust?
I was considering giving it a try myself, but I am of course interested to hear whether anyone has tried it before me.