Convert Nx.LinAlg.lu to optional callback #1388

Open
polvalente opened this issue Dec 3, 2023 · 6 comments
Labels
area:exla Applies to EXLA area:nx Applies to nx kind:feature New feature or request

Comments

@polvalente
Contributor

XLA doesn't provide an LU implementation, and we use LU decomposition to compute the determinant, which some other LinAlg functions depend on.
Therefore, we should either look into JAX's implementation or just translate the BinaryBackend one into a defn, like we did with SVD.

This will also enable LU on MLIR by default.
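For context on what translating the BinaryBackend LU into a defn involves, here is a minimal sketch of LU decomposition with partial pivoting in plain Elixir, together with the determinant it yields. With P A = L U, det(A) = (-1)^swaps · ∏ u_ii, because det(L) = 1 and every row swap flips the sign. The sketch works on lists of rows rather than Nx tensors, is not the BinaryBackend code, and the `LUSketch` module name is hypothetical. A defn version would additionally have to express the data-dependent pivot choice and the shrinking row range with static shapes.

```elixir
defmodule LUSketch do
  # Hypothetical module: textbook LU with partial pivoting on a square matrix
  # given as a list of rows. Not the Nx BinaryBackend implementation.

  # Returns {perm, l, u, swaps} with P·A = L·U, where perm[i] is the row of A
  # that ends up in row i of P·A and swaps counts the row exchanges.
  def lu(a) do
    n = length(a)
    zeros = List.duplicate(List.duplicate(0.0, n), n)

    {perm, l, u, swaps} =
      Enum.reduce(0..(n - 2)//1, {Enum.to_list(0..(n - 1)), zeros, a, 0}, fn k, {perm, l, u, swaps} ->
        # Partial pivoting: choose the row at or below k with the largest |u[i][k]|.
        p = Enum.max_by(k..(n - 1), fn i -> abs(at(u, i, k)) end)
        {perm, l, u} = {swap(perm, k, p), swap(l, k, p), swap(u, k, p)}
        swaps = if p != k, do: swaps + 1, else: swaps
        pivot = at(u, k, k)

        if pivot == 0 do
          # Nothing to eliminate in this column; the matrix is singular.
          {perm, l, u, swaps}
        else
          pivot_row = Enum.at(u, k)

          # Eliminate entries below the pivot and record the multipliers in L.
          {l, u} =
            Enum.reduce((k + 1)..(n - 1)//1, {l, u}, fn i, {l, u} ->
              f = at(u, i, k) / pivot
              row = Enum.zip_with(Enum.at(u, i), pivot_row, fn x, y -> x - f * y end)
              {put(l, i, k, f), List.replace_at(u, i, row)}
            end)

          {perm, l, u, swaps}
        end
      end)

    # L is unit lower triangular.
    l = Enum.reduce(0..(n - 1)//1, l, fn i, l -> put(l, i, i, 1.0) end)
    {perm, l, u, swaps}
  end

  # det(A) = (-1)^swaps * prod(diag(U)), since det(P) = ±1 and det(L) = 1.
  def det(a) do
    {_perm, _l, u, swaps} = lu(a)
    sign = if rem(swaps, 2) == 0, do: 1.0, else: -1.0
    sign * Enum.reduce(0..(length(u) - 1)//1, 1.0, fn i, acc -> acc * at(u, i, i) end)
  end

  defp at(m, i, j), do: m |> Enum.at(i) |> Enum.at(j)
  defp put(m, i, j, v), do: List.update_at(m, i, &List.replace_at(&1, j, v))

  defp swap(list, i, i), do: list

  defp swap(list, i, j) do
    a = Enum.at(list, i)
    b = Enum.at(list, j)
    list |> List.replace_at(i, b) |> List.replace_at(j, a)
  end
end
```

For example, `LUSketch.det([[4, 3], [6, 3]])` swaps the two rows once and returns approximately `-6.0`.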

@polvalente polvalente added kind:feature New feature or request area:exla Applies to EXLA area:nx Applies to nx labels Dec 3, 2023
@josevalim
Collaborator

We can start with the translation, I think that's the simplest, and then explore other routes if necessary!

@polvalente
Contributor Author

Agreed!

@polvalente polvalente self-assigned this Dec 3, 2023
@polvalente polvalente changed the title from "Reimplement Nx.LinAlg.lu as an optional callback" to "Convert Nx.LinAlg.lu to optional callback" on Dec 6, 2023
@jyc

jyc commented Oct 26, 2024

Sorry if this is necroing a thread, but I just ran into this while trying to compute determinants:

** (ArgumentError) XLA does not currently support the LU operation

Do you have any advice off-hand? If not, no worries; it doesn't block me, because I can try the Torchx backend or, in the worst case, try porting over the JAX implementation.

@polvalente
Contributor Author

> Sorry if this is necroing a thread, but I just ran into this while trying to compute determinants:
>
> ** (ArgumentError) XLA does not currently support the LU operation
>
> Do you have any advice off-hand? If not, no worries; it doesn't block me, because I can try the Torchx backend or, in the worst case, try porting over the JAX implementation.

I currently don't have the bandwidth to do a full-fledged implementation of LU in defn.
You might be able to take inspiration from #1510, as well as the eigh implementation and this special function:

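```elixir
# Hypothetical module name; the defn body itself is unchanged.
defmodule LUHelpers do
  import Nx.Defn

  # Sums u[k] * v[k] only over k >= u_start and k >= v_start by zeroing the
  # other entries, so no slice whose length depends on a runtime value is needed.
  defn vector_dot_slice(u, u_start, v, v_start) do
    {n} = Nx.shape(u)
    u = Nx.select(Nx.iota({n}) >= u_start, u, 0)
    {n} = Nx.shape(v)
    v = Nx.select(Nx.iota({n}) >= v_start, v, 0)
    Nx.dot(u, v)
  end
end
```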

I tried porting LU once, and it ended up requiring something like this, because it needs to compute something like `dot(a[[.., i..-1//1]], a[[i..-1//1, ..]])` where `i` is a scalar tensor. In Nx this isn't directly possible, since the length of those slices would depend on the runtime value of `i`. But if you realize that slicing and then taking the dot product has the same effect as setting the removed entries to 0, you can use the implementation above or something similar :)
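The slicing-versus-masking identity described above can be checked on plain Elixir lists. This is a minimal sketch with a hypothetical `MaskDot` module, not Nx code. It assumes both operands are cut at the same start index, which is the situation in the LU update, where both factors share the contraction range `i..-1//1`. Masking a single operand is already enough, because each zeroed entry contributes 0 to the sum whatever the other side holds.

```elixir
defmodule MaskDot do
  # Hypothetical helper module illustrating the slice-vs-mask identity on lists.

  def dot(u, v), do: Enum.zip_with(u, v, &(&1 * &2)) |> Enum.sum()

  # Slice first: drop the first `start` entries of both vectors, then dot.
  def sliced_dot(u, v, start), do: dot(Enum.drop(u, start), Enum.drop(v, start))

  # Mask instead: keep the full length but zero every entry of u with index < start,
  # mirroring Nx.select(Nx.iota({n}) >= start, u, 0).
  def masked_dot(u, v, start) do
    masked_u =
      u
      |> Enum.with_index()
      |> Enum.map(fn {x, idx} -> if idx >= start, do: x, else: 0 end)

    dot(masked_u, v)
  end
end
```

Both `MaskDot.sliced_dot([1, 2, 3, 4], [5, 6, 7, 8], 2)` and `MaskDot.masked_dot([1, 2, 3, 4], [5, 6, 7, 8], 2)` give `53` (that is, 3·7 + 4·8), but only the masked form keeps the operand length fixed, which is what defn needs.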

All of that being said, I believe we can add a custom call to Eigen, like we have for Nx.LinAlg.qr and Nx.LinAlg.eigh, that will at least make LU available on CPU. Would you be open to sending a PR for either?

@jyc

jyc commented Oct 28, 2024

Thanks so much for the detailed reply, I really appreciate it! For now, I am trying to see if computing the determinant using QR is sufficient for my application, even though it will be slower. If that is insufficient I will certainly write back and try to do something based on your excellent notes. Thank you again!
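For reference, here is the identity that makes QR usable for determinants (a standard linear-algebra fact, not something stated in the thread). With A = QR, where Q is orthogonal and R is upper triangular:

$$
\det(A) = \det(Q)\,\det(R) = \pm \prod_{i=1}^{n} r_{ii},
\qquad
\lvert \det(A) \rvert = \prod_{i=1}^{n} \lvert r_{ii} \rvert
$$

The magnitude comes straight from the diagonal of R, but the sign depends on det(Q) = ±1, which a QR routine does not usually return directly (if the QR is Householder-based, each applied reflector contributes a factor of -1). An LU factorization PA = LU instead gives the sign from the row swaps recorded in P.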

@polvalente
Contributor Author

polvalente commented Oct 29, 2024

@jyc I ended up using this as an excuse to try the Cursor AI editor out 😅
Please give this branch a try: pv-feat/custom-callback-lu

The branch adds at least the CPU implementation as a palliative solution.


3 participants