Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX implementation fol MatrixIsPositiveDefinite Op #6853

Merged
merged 5 commits into from
Aug 16, 2023

Conversation

juanitorduz
Copy link
Contributor

@juanitorduz juanitorduz commented Aug 11, 2023

pymc/sampling/jax.py Outdated Show resolved Hide resolved
@juanitorduz juanitorduz marked this pull request as draft August 11, 2023 13:10
@codecov
Copy link

codecov bot commented Aug 11, 2023

Codecov Report

Merging #6853 (5add5de) into main (9956991) will decrease coverage by 1.42%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6853      +/-   ##
==========================================
- Coverage   92.03%   90.62%   -1.42%     
==========================================
  Files          96       96              
  Lines       16398    16404       +6     
==========================================
- Hits        15092    14866     -226     
- Misses       1306     1538     +232     
Files Changed Coverage Δ
pymc/distributions/multivariate.py 92.22% <100.00%> (ø)
pymc/sampling/jax.py 98.30% <100.00%> (+0.04%) ⬆️

... and 13 files with indirect coverage changes

@juanitorduz juanitorduz marked this pull request as ready for review August 11, 2023 19:10
pymc/sampling/jax.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

We have to check that HSGP failing test

@juanitorduz
Copy link
Contributor Author

We have to check that HSGP failing test

The tests pass locally 🤔

@juanitorduz
Copy link
Contributor Author

Some random (unrelated) tests failed 🤷

@ricardoV94
Copy link
Member

I am not sure is random, doesn't it use the Op we modified?

@juanitorduz
Copy link
Contributor Author

Well, before the last commit where I simply changed the test parameterization all the tests passed. Can you maybe rerun the failed jobs? I can double check from my side.

@ricardoV94
Copy link
Member

I don't know if the test is deterministic, so passing once isn't a conclusive thing. I can have a look on Monday

@ricardoV94
Copy link
Member

ricardoV94 commented Aug 16, 2023

There was a bug in the MatrixNormal test, it was supposed to try a different draw in the k2_samp test when it failed (max 10 times), but it was reusing the same draws everytime.

@ricardoV94
Copy link
Member

I fixed the test and cleaned the PR git history so we can merge without squash. Let me know if it looks correct (and if you need chages be carefull with fetching the remote branch)

@ricardoV94 ricardoV94 changed the title Add JAX OP MatrixIsPositiveDefinite Add JAX implementation fol MatrixIsPositiveDefinite Op Aug 16, 2023
@ricardoV94 ricardoV94 changed the title Add JAX implementation fol MatrixIsPositiveDefinite Op Add JAX implementation fol MatrixIsPositiveDefinite Op Aug 16, 2023
@juanitorduz
Copy link
Contributor Author

Great, thank you! Looks great!

@@ -178,7 +178,7 @@ def test_prior(self, model, cov_func, X1, parameterization):
gp = pm.gp.Latent(cov_func=cov_func)
f2 = gp.prior("f2", X=X1)

idata = pm.sample_prior_predictive(samples=1000)
idata = pm.sample_prior_predictive(samples=1000, random_seed=rng)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day we will get this 5 line PR merged. Just you wait :D

@ricardoV94 ricardoV94 merged commit 13a2310 into pymc-devs:main Aug 16, 2023
20 of 21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: JAX OP MatrixIsPositiveDefinite
2 participants