Allow Minibatch of derived RVs and deprecate generators as data #7480

Merged: ricardoV94 merged 4 commits into pymc-devs:main from minibatch_censored on Sep 7, 2024

ricardoV94 (Member) commented Aug 27, 2024

Description

This PR fixes issues related to Minibatch indexing reported in https://discourse.pymc.io/t/warning-using-minibatch-and-censored-together-rng-variable-has-shared-clients/14943 and extends the MinibatchRV functionality for derived RVs.

Minibatch value variables are uniquely tricky because they are random graphs that can share an RNG with other variables in the forward / logp graph. As such, we need to make sure they are not mutated, so that the default updates keep working. We tried some tricks in the past but, as the Discourse issue revealed, they were not enough. This PR solves the problem by encapsulating the random graph in an OpFromGraph, so that the inner graph is not touched by PyMC's logprob derivation routines. It is still inlined in the final compiled functions to avoid overhead.
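
To make the shared-RNG point concrete, here is a minimal pure-Python sketch. It is not PyMC's implementation: the `minibatch` helper and its use of `random.Random` are illustrative assumptions standing in for the PyTensor random graph. It only shows that every minibatched array has to be sliced with the same single draw of indices.

```python
import random


def minibatch(*arrays, batch_size, rng):
    """Slice every array with one shared draw of random row indices.

    Hypothetical stand-in for the random graph behind pm.Minibatch;
    this is not PyMC code.
    """
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must share the same first dimension")
    # One draw of indices, reused for every array, keeps the rows aligned.
    idx = [rng.randrange(n) for _ in range(batch_size)]
    return [[a[i] for i in idx] for a in arrays]


rng = random.Random(0)
A = list(range(100))
B = list(range(100))
mA, mB = minibatch(A, B, batch_size=10, rng=rng)
```

Because `A` and `B` hold the same values and share one index draw, `mA` and `mB` come out identical, which is the property the PR's test checks with `pm.draw`. This is only an analogy for why the index draw needs to be kept intact as a single unit.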

I also decided to deprecate generators as data, which showed up in some of the refactoring. GeneratorOp is not a true Op, since an Op should not have any side effects. It is also not compatible with non-default backends like Numba and JAX, which we are moving towards. If needed, the logic should be handled by the sampler, by consuming the generator and setting the values before subsequent function calls.
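
The suggested replacement pattern could look roughly like the sketch below. Everything here is an assumption for illustration: `SharedValue` is a hypothetical stand-in that mirrors the `set_value` / `get_value` interface of PyTensor shared variables, `batches` is an ordinary Python generator, and `step` stands in for a compiled function. The point is that the calling loop consumes the generator, so no Op needs side effects.

```python
from itertools import islice


class SharedValue:
    """Hypothetical stand-in for a shared variable (set_value / get_value)."""

    def __init__(self, value=None):
        self._value = value

    def set_value(self, value):
        self._value = value

    def get_value(self):
        return self._value


def batches(data, batch_size):
    """Plain Python generator that cycles through consecutive batches."""
    start = 0
    while True:
        yield data[start:start + batch_size]
        start = (start + batch_size) % len(data)


def step(shared):
    """Stand-in for a compiled function that reads the shared value."""
    return sum(shared.get_value())


data = list(range(6))
shared = SharedValue()
results = []
# The loop, not an Op, consumes the generator and sets the value
# before each function call.
for batch in islice(batches(data, batch_size=2), 3):
    shared.set_value(batch)
    results.append(step(shared))
```

With this layout the function being called stays free of side effects, and the data source can be swapped without touching the graph.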

ricardoV94 added the enhancements and major (Include in major changes release notes section) labels Aug 27, 2024
mA, mB = pm.Minibatch(A, B, batch_size=10)

[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, draw_mB)

# Check invalid dims
Member Author

This was already checked in the test above

Member

ok

ricardoV94 force-pushed the minibatch_censored branch 2 times, most recently from 479c4a4 to 290a643, August 29, 2024 14:41

codecov bot commented Aug 29, 2024

Codecov Report

Attention: Patch coverage is 97.91667% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.15%. Comparing base (c92a9a9) to head (49542b5).
Report is 12 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc/data.py | 96.00% | 1 Missing ⚠️ |

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7480      +/-   ##
==========================================
- Coverage   92.16%   92.15%   -0.02%     
==========================================
  Files         103      103              
  Lines       17214    17224      +10     
==========================================
+ Hits        15866    15873       +7     
- Misses       1348     1351       +3     

| Files with missing lines | Coverage Δ |
|---|---|
| pymc/logprob/basic.py | 94.36% <100.00%> (ø) |
| pymc/logprob/rewriting.py | 89.75% <100.00%> (ø) |
| pymc/model/core.py | 91.75% <100.00%> (-0.03%) ⬇️ |
| pymc/pytensorf.py | 90.62% <100.00%> (+0.11%) ⬆️ |
| pymc/variational/minibatch_rv.py | 100.00% <100.00%> (ø) |
| pymc/variational/opvi.py | 87.42% <100.00%> (ø) |
| pymc/data.py | 89.09% <96.00%> (-0.36%) ⬇️ |

... and 2 files with indirect coverage changes

mb_tensors = [tensor[mb_indices] for tensor in tensors]

# Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
*mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)
Member

nice trick, did not know that

@@ -666,6 +672,9 @@ class GeneratorOp(Op):
    __props__ = ("generator",)

    def __init__(self, gen, default=None):
        warnings.warn(
Member

lgtm

@@ -162,22 +159,7 @@ def fit_kwargs(inference, use_minibatch):


def test_fit_oo(inference, fit_kwargs, simple_model_data):
# Minibatch data can't be extracted into the `observed_data` group in the final InferenceData
Member

no more issues there?

Member Author

Nope, I now allow the data to be extracted into the idata; it extracts the whole data.

-minibatch_idx = minibatch_index(0, 10, size=(9,))
-AD_mt = AD[minibatch_idx]
-TD_mt = TD[minibatch_idx]
+AD_mt, TD_mt = Minibatch(AD, TD, batch_size=9)
Member

this is much cleaner

ferrine (Member) commented Sep 7, 2024

I've created an issue (#7496) to continue this work later and improve the scalability of minibatches.

ricardoV94 merged commit 2856062 into pymc-devs:main Sep 7, 2024
22 checks passed
ricardoV94 deleted the minibatch_censored branch September 7, 2024 17:34