Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.40.2"
version = "0.41.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
33 changes: 32 additions & 1 deletion src/mcmc/prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,36 @@ function AbstractMCMC.step(
),
)
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
return Transition(model, vi, nothing; reevaluate=false), nothing
vi = DynamicPPL.typed_varinfo(vi)
return Transition(model, vi, nothing; reevaluate=false), vi
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:Prior},
vi::AbstractVarInfo;
kwargs...,
)
# TODO(DPPL0.38/penelopeysm): replace this entire thing with init!!
#
# `vi` is a VarInfo from the previous step so already has all the
# right accumulators and stuff. The only thing we need to change is to make
# sure that the old values are overwritten when we resample.
#
# Note also that the values in the Transition (and hence the chain) are not
# obtained from the VarInfo's metadata itself, but are instead obtained
# from the ValuesAsInModelAccumulator, which is cleared in the evaluate!!
# call. Thus, the actual values in the VarInfo's metadata don't matter:
# we only set the del flag here to make sure that new values are sampled
# (and thus new values enter VAIMAcc), rather than the old ones being
# reused during the evaluation. Yes, SampleFromPrior really sucks.
for vn in keys(vi)
DynamicPPL.set_flag!(vi, vn, "del")
end
sampling_model = DynamicPPL.contextualize(
model, DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior(), model.context)
)
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
return Transition(model, vi, nothing; reevaluate=false), vi
end
Loading