Massive speedup for Turing model with DDM by enabling AutoReverseDiff(;compile=true) #130
bfalandays started this conversation in Ideas
Replies: 1 comment
@bfalandays, thank you for sharing your performance tip. That is a very large speedup indeed! Have you tried the Mooncake AD? From what I gather, that might be the go-to reverse-mode AD option in the future.
Hey SequentialSamplingModels team, I wanted to share some major performance gains I got from modifying the `DDM` functions to enable `AutoReverseDiff(; compile=true)` as the automatic-differentiation backend. I have a pretty complex hierarchical DDM built with Turing.jl, and I'm using NUTS as the sampler. With the default AD backend (`AutoForwardDiff`), 250 iterations, 50 adaptation steps, and `target_accept = 0.65`, the model took over 8 hours to run. After altering the code to get `AutoReverseDiff(; compile=true)` working, the same settings ran in ~12 minutes!

In the current version of `SequentialSamplingModels.DDM`, using a compiled tape for the AD backend isn't an option due to branching control flow in two places:

1. the `pdf` function has a step that reflects the parameter values `if choice == 1`;
2. the `_pdf` sub-function has branches for computing the number of terms for the small-time vs. large-time expansion, and also for deciding which expansion to use.

The first branch (reflecting parameters) is pretty simple to eliminate with some arithmetic. Note that I'm coding `choice = 1` as the upper boundary and `choice = 0` as the lower boundary. So in `pdf` I compute `ν_new = (1 - 2*choice) * ν` to reflect `ν`, and `z_new = (1 - choice) * z + choice * (1 - z)` to reflect `z`.

The second was a bit trickier, but the approach I used was to always evaluate both expansions using the smaller value of K, then use `ifelse` to return the value I want. Unlike an `if` branch, `ifelse` always evaluates both sides, so it works fine with a compiled tape. I also leveraged NaNMath.jl to deal with domain errors that arise when forcing every statement to be evaluated.

Here are the key bits of my code:
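A minimal, self-contained sketch of the two tricks, with toy stand-ins for the small-time and large-time series (the real ones live in `_pdf`; the names `reflect_ν`, `reflect_z`, `toy_small`, `toy_large`, and `toy_density` below are illustrative, not SequentialSamplingModels API):

```julia
using NaNMath

# Trick 1: branchless parameter reflection (replaces `if choice == 1` in `pdf`).
# choice = 1 → upper boundary, choice = 0 → lower boundary, as in the post.
reflect_ν(ν, choice) = (1 - 2 * choice) * ν
reflect_z(z, choice) = (1 - choice) * z + choice * (1 - z)

# Trick 2: `ifelse` instead of `if`, plus NaNMath for domain safety.
# Toy stand-ins for the two expansions: `toy_large` is only valid for t > 1,
# but NaNMath.log returns NaN (instead of throwing a DomainError) when the
# branch that will be discarded is evaluated outside its domain.
toy_small(t) = exp(-1 / t) / t^2
toy_large(t) = NaNMath.log(t - 1)   # Base.log would throw for t ≤ 1

function toy_density(t)
    # `ifelse` is an ordinary function, so BOTH arguments are evaluated and a
    # compiled ReverseDiff tape records every operation on every call.
    ifelse(t < 1, toy_small(t), toy_large(t))
end
```

With the branches removed this way, the compiled-tape backend can be requested per sampler in recent Turing versions via the `adtype` keyword, e.g. `sample(model, NUTS(50, 0.65; adtype = AutoReverseDiff(; compile = true)), 250)` (exact signature depends on your Turing/ADTypes versions).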