Remove ThreadSafeVarInfo
#1023
base: breaking
Conversation
Codecov Report
Additional details and impacted files:
```
@@            Coverage Diff             @@
##           breaking    #1023      +/-   ##
============================================
- Coverage     81.20%   80.91%    -0.29%
============================================
  Files            39       39
  Lines          3910     3810      -100
============================================
- Hits           3175     3083       -92
+ Misses          735      727        -8
```
DynamicPPL.jl documentation for PR #1023 is available at:
As a bonus, this PR completely fixes all Enzyme issues arising from DPPL 0.37. #947
Unfortunately I don't know how to deal with conditioned/fixed variables without a huge amount of faff and macro code duplication 😮💨
This requires a bit more discussion before we make a commitment -- not entirely sure we should introduce a new macro.
Yeah, I remember the discussion we had a few meetings ago.
Summary
This PR removes `ThreadSafeVarInfo`. In its place, a `@pobserve` macro is added to enable multithreaded tilde-observe statements, according to the plan outlined in #924 (comment). Broadly speaking, a loop of observe statements wrapped in `@pobserve` is converted (modulo variable names) into code that spawns one task per iteration and sums the resulting log-likelihood contributions afterwards; a sketch follows.
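As a rough illustration (assuming hypothetical data `y` and parameter values `mu` and `sigma`; this is not the code the macro actually generates):

```julia
# Minimal sketch, not the actual macro expansion: each observe statement's
# log-likelihood is computed on its own task, and the per-task results are
# summed once all tasks have finished.
using Distributions

y = randn(100)          # hypothetical observations
mu, sigma = 0.0, 1.0    # hypothetical parameter values

# What a user would write inside a model body:
#     @pobserve for i in eachindex(y)
#         y[i] ~ Normal(mu, sigma)
#     end

# Roughly what it turns into, conceptually:
tasks = map(eachindex(y)) do i
    Threads.@spawn logpdf(Normal(mu, sigma), y[i])
end
loglike = sum(fetch, tasks)
# `loglike` is then added to the varinfo's likelihood accumulator outside
# the spawned tasks (exact DynamicPPL call omitted here).
```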
No actual varinfo manipulation happens inside the `Threads.@spawn`: instead, the log-likelihood contributions are calculated in each thread, then summed after the individual threads have finished their tasks. Because of this, there is no need to maintain one log-likelihood accumulator per thread, and consequently no need for `ThreadSafeVarInfo`.

Closes #429.
Closes #924.
Closes #947.
Why?
Code simplification in DynamicPPL, and reducing the number of `AbstractVarInfo` subtypes, is obviously a big argument. But in fact, that's not my main motivation. I'm mostly motivated to do this because TSVI in general is IMO not good code: it works, but in many ways it's a hack.
- TSVI is built around `Threads.@threads for i in x ... end`: internally we use `Threads.threadid()` to index into a vector of accumulators (a sketch of this pattern is given after this list). This is now regarded as "incorrect parallel code that contains the possibility of race conditions which can give wrong results". See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/ and https://discourse.julialang.org/t/behavior-of-threads-threads-for-loop/76042.
- The vector of accumulators has length `Threads.nthreads() * 2`, which is a hacky heuristic. The correct solution would be `Threads.maxthreadid()`, but Mooncake couldn't differentiate through that.
- To quote the blog post above, the use of `threadid`, `nthreads` and even `maxthreadid` "[is] perilous. Any code that relies on a specific `threadid` staying constant, or on a constant number of threads during execution, is bound to be incorrect."
- Whether TSVI is used depends on `if Threads.nthreads() > 1`, which cannot be determined at compile time. This means that, for `evaluate!!` to be type stable, both the TSVI and non-TSVI code paths inside `evaluate!!` must be together type stable. That's just silly IMO.
- TSVI also causes problems for Enzyme; see `cacheForReverse` in EnzymeAD/Enzyme.jl#2518.
Does this actually work?
This PR has no tests yet, but I ran this locally and the log-likelihood gets accumulated correctly:
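Roughly along these lines (a hedged sketch: the model `demo` and the `getloglikelihood` accessor are assumed here, not taken from the PR's actual test script):

```julia
# Sketch only: assumes DPPL 0.37-style accessors and a made-up model `demo`.
using DynamicPPL, Distributions

@model function demo(y)
    mu ~ Normal()
    @pobserve for i in eachindex(y)
        y[i] ~ Normal(mu, 1.0)
    end
end

y = randn(10)
vi = VarInfo(demo(y))                      # evaluates the model once
mu = vi[@varname(mu)]
manual = sum(logpdf.(Normal(mu, 1.0), y))  # serially computed reference value
@assert DynamicPPL.getloglikelihood(vi) ≈ manual
```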
I can also confirm that the parallelisation is correctly occurring with this model:
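A hedged guess at what such a model might look like (the PR's actual model is not shown; `burn` is a made-up CPU-bound stand-in for an expensive likelihood term):

```julia
# Illustrative only: each observation does ~1s of CPU-bound work, so
# evaluating the model takes ~2s on one thread and ~1s on two threads.
using DynamicPPL, Distributions

# Busy-wait so the work is CPU-bound (a plain `sleep` would yield and run
# concurrently even on a single thread).
function burn(seconds)
    t0 = time()
    while time() - t0 < seconds
    end
end

@model function slow_demo(y)
    @pobserve for i in eachindex(y)
        burn(1.0)
        y[i] ~ Normal()
    end
end

@time VarInfo(slow_demo(randn(2)))
```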
If you run this with 1 thread it takes 2 seconds, and if you run it with 2 threads it takes 1 second.
It also works correctly with `MCMCThreads()` (with some minor adjustments to Turing.jl for compatibility with this branch). NOTE: Sampling with `@pobserve` is now fully reproducible, whereas `Threads.@threads` was not reproducible even when seeded.
What now?
There are a handful of limitations to this PR. These are the ones I can think of right now:
- It will crash if the VarInfo used for evaluation does not have a likelihood accumulator.
- It doesn't handle prior (assume) statements, i.e. it never calls `DynamicPPL.acclogprior!!()`.
- It doesn't work with `.~` (or maybe it does, I haven't tested, but my guess is that it will bug out).
- If `x` is not a model argument or conditioned upon, this will yield wrong results for the typical `x = Vector{Float64}(undef, 2); @pobserve for i in eachindex(x); x[i] ~ dist; end`, as it will naively accumulate `logpdf(dist, x[i])` even though this should be an assumption rather than an observation (see the sketch after this list).
- There is no way to extract other computations from the threads.
- Particle Gibbs cannot cope with `Threads.@spawn`, so PG will throw an error with `@pobserve`.
- The name `@pobserve` is a bit too unambitious. If one day we make it work with assume, then it will have to be renamed, i.e. a breaking change.
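To make the observe-versus-assume distinction concrete, a small sketch in plain DynamicPPL (model names made up; no `@pobserve` involved):

```julia
# The same `x[i] ~ dist` syntax means different things depending on whether
# `x` is data: it is an observation when `x` is a model argument (or is
# conditioned upon), and an assumption when `x` is created inside the model.
# `@pobserve` only handles the first case.
using DynamicPPL, Distributions

@model function obs_version(x)              # x passed in => observations
    for i in eachindex(x)
        x[i] ~ Normal()                     # contributes logpdf(Normal(), x[i])
    end
end

@model function assume_version()
    x = Vector{Float64}(undef, 2)           # x created here => assumptions
    for i in eachindex(x)
        x[i] ~ Normal()                     # must be sampled, not logpdf'd
    end
    return x
end
```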
I believe that all of these are either unimportant or can be worked around with some additional macro leg-work:
- Not important, nobody is running around evaluating their models with no likelihood accumulator. Not even Turing does this. Also easy enough for us to guard against by wrapping the entire thing in an if/else.
- The assume cases could be handled by having the macro call `acclogprior!!` outside the threaded bit.
- This can be fixed easily by changing the macro to return a tuple of `(retval, loglike)` rather than just `loglike` (sketched below).
- The same was already true of `Threads.@threads`.
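For illustration, a minimal sketch of the `(retval, loglike)` idea (assumed, not this PR's implementation; `transform` is a made-up extra computation):

```julia
# Each task returns both its computed value and its log-likelihood
# contribution, so per-iteration results are recoverable after the loop.
using Distributions

x = randn(5)
y = randn(5)
transform(z) = 2z + 1          # made-up per-iteration computation worth keeping

tasks = map(eachindex(y)) do i
    Threads.@spawn begin
        pred = transform(x[i])
        (pred, logpdf(Normal(pred), y[i]))
    end
end
results = fetch.(tasks)
retvals = first.(results)      # the values the user wants back
loglike = sum(last, results)   # the summed likelihood contribution
```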
So for now this should mostly be considered a proof of principle rather than a complete PR.
Finally, note that this PR already removes > 550 lines of code, but this is not a full picture of the simplification afforded. For example, I did not remove the `split`, `combine`, and `convert_eltype` methods on accumulators, which I believe can either be removed or simplified once TSVI is removed.