
Commit 24142ee

(some) Compatibility with DynamicPPL 0.39
1 parent: db57a1d

11 files changed (+73, -109 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -92,4 +92,4 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 
 [sources]
-DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
+DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/not-experimental"}
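(The `[sources]` table pins DynamicPPL to a Git URL and branch rather than a registered release; the commit simply retargets the tracked branch from `py/ldf` to `py/not-experimental`. The same change is mirrored in `test/Project.toml` below.)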

src/mcmc/Inference.jl

Lines changed: 0 additions & 4 deletions
@@ -6,7 +6,6 @@ using DynamicPPL:
     Metadata,
     VarInfo,
     LogDensityFunction,
-    SimpleVarInfo,
    AbstractVarInfo,
     # TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL, because it is only
     # implemented for NTVarInfo. It is used by mh.jl. Either refactor mh.jl to not use it
@@ -92,9 +91,6 @@ function DynamicPPL.unflatten(vi::DynamicPPL.NTVarInfo, θ::NamedTuple)
     set_namedtuple!(deepcopy(vi), θ)
     return vi
 end
-function DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple)
-    return SimpleVarInfo(θ, vi.logp, vi.transformation)
-end
 
 """
     mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)

src/mcmc/emcee.jl

Lines changed: 4 additions & 2 deletions
@@ -86,11 +86,13 @@ function AbstractMCMC.step(
     densitymodel = AMH.DensityModel(Base.Fix1(LogDensityProblems.logdensity, state.ldf))
 
     # Compute the next states.
-    t, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states)
+    _, states = AbstractMCMC.step(rng, densitymodel, spl.ensemble, state.states)
 
     # Compute the next transition and state.
     transition = map(states) do _state
-        return DynamicPPL.ParamsWithStats(_state.params, state.ldf, t)
+        return DynamicPPL.ParamsWithStats(
+            _state.params, state.ldf, AbstractMCMC.getstats(_state)
+        )
     end
     newstate = EmceeState(state.ldf, states)
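Previously the single transition `t` returned by the ensemble step was reused as the stats for every walker; stats are now taken from each walker's own state. A minimal sketch of the pattern, outside the diff:

    # One NamedTuple of sampler statistics per ensemble walker.
    per_walker_stats = map(AbstractMCMC.getstats, states)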

src/mcmc/ess.jl

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ function Turing.Inference.initialstep(
         EllipticalSliceSampling.isgaussian(typeof(dist)) ||
             error("ESS only supports Gaussian prior distributions")
     end
-    return Transition(model, vi, nothing), vi
+    return DynamicPPL.ParamsWithStats(vi, model), vi
 end
 
 function AbstractMCMC.step(
@@ -56,7 +56,7 @@ function AbstractMCMC.step(
     vi = DynamicPPL.unflatten(vi, sample)
     vi = DynamicPPL.setloglikelihood!!(vi, state.loglikelihood)
 
-    return Transition(model, vi, nothing), vi
+    return DynamicPPL.ParamsWithStats(vi, model), vi
 end
 
 # Prior distribution of considered random variable
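Both returns swap the old `Transition(model, vi, nothing)` for `DynamicPPL.ParamsWithStats(vi, model)`. A hedged sketch of the constructor forms this commit relies on (signatures inferred from the diffs against a development branch, not a documented API):

    DynamicPPL.ParamsWithStats(vi)            # stats read from vi's accumulators
    DynamicPPL.ParamsWithStats(vi, model)     # model available for computing the stats
    DynamicPPL.ParamsWithStats(x, ldf, stats) # flat params x + LogDensityFunction + stats NamedTuple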

src/mcmc/external_sampler.jl

Lines changed: 9 additions & 11 deletions
@@ -122,12 +122,12 @@ function externalsampler(
 end
 
 # TODO(penelopeysm): Can't we clean this up somehow?
-struct TuringState{S,V1,M,V}
+struct TuringState{S,V,L<:DynamicPPL.LogDensityFunction}
     state::S
-    # Note that this varinfo must have the correct parameters set; but logp
-    # does not matter as it will be re-evaluated
-    varinfo::V1
-    ldf::DynamicPPL.LogDensityFunction{M,V}
+    # Note that this varinfo is used only for structure. Its parameters and other info do
+    # not need to be accurate
+    varinfo::V
+    ldf::L
 end
 
 # get_varinfo should return something from which the correct parameters can be
@@ -185,11 +185,10 @@ function AbstractMCMC.step(
     end
 
     new_parameters = AbstractMCMC.getparams(f.model, state_inner)
-    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        DynamicPPL.ParamsWithStats(new_vi, f.model, new_stats),
-        TuringState(state_inner, new_vi, f),
+        DynamicPPL.ParamsWithStats(new_parameters, f, new_stats),
+        TuringState(state_inner, varinfo, f),
     )
 end
 
@@ -209,10 +208,9 @@ function AbstractMCMC.step(
     )
 
     new_parameters = AbstractMCMC.getparams(f.model, state_inner)
-    new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
     new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        DynamicPPL.ParamsWithStats(new_vi, f.model, new_stats),
-        TuringState(state_inner, new_vi, f),
+        DynamicPPL.ParamsWithStats(new_parameters, f, new_stats),
+        TuringState(state_inner, state.varinfo, f),
     )
 end
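In both `step` methods the intermediate `unflatten` into a fresh VarInfo disappears: the flat parameter vector from the inner sampler feeds `ParamsWithStats` directly alongside the `LogDensityFunction`, and the state keeps recycling the purely structural varinfo. A sketch of the new flow (names as in the diff):

    ps = AbstractMCMC.getparams(f.model, state_inner)  # flat parameter vector
    st = AbstractMCMC.getstats(state_inner)            # sampler statistics (NamedTuple)
    DynamicPPL.ParamsWithStats(ps, f, st)              # transition handed back to AbstractMCMC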

src/mcmc/gibbs.jl

Lines changed: 5 additions & 11 deletions
@@ -488,18 +488,12 @@ function setparams_varinfo!!(
 end
 
 function setparams_varinfo!!(
-    model::DynamicPPL.Model,
-    sampler::ExternalSampler,
-    state::TuringState,
-    params::AbstractVarInfo,
+    ::DynamicPPL.Model, ::ExternalSampler, state::TuringState, params::AbstractVarInfo
 )
-    logdensity = DynamicPPL.LogDensityFunction(
-        model, DynamicPPL.getlogjoint_internal, state.ldf.varinfo; adtype=sampler.adtype
-    )
     new_inner_state = AbstractMCMC.setparams!!(
-        AbstractMCMC.LogDensityModel(logdensity), state.state, params[:]
+        AbstractMCMC.LogDensityModel(state.ldf), state.state, params[:]
     )
-    return TuringState(new_inner_state, params, logdensity)
+    return TuringState(new_inner_state, params, state.ldf)
 end
 
 function setparams_varinfo!!(
@@ -513,11 +507,11 @@ function setparams_varinfo!!(
     z = state.z
     resize!(z.θ, length(θ_new))
     z.θ .= θ_new
-    return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor)
+    return HMCState(params, state.i, state.kernel, hamiltonian, z, state.adaptor, state.ldf)
 end
 
 function setparams_varinfo!!(
-    model::DynamicPPL.Model, sampler::PG, state::PGState, params::AbstractVarInfo
+    ::DynamicPPL.Model, ::PG, state::PGState, params::AbstractVarInfo
 )
     return PGState(params, state.rng)
 end
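Two knock-on effects of the new state layout appear here: the `ExternalSampler` method reuses `state.ldf` instead of rebuilding a `LogDensityFunction` (and its AD preparation) on every parameter update, and `HMCState` evidently gains a trailing `LogDensityFunction` field that must now be threaded through when the state is reconstructed.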

src/mcmc/mh.jl

Lines changed: 48 additions & 35 deletions
@@ -179,33 +179,33 @@ get_varinfo(s::MHState) = s.varinfo
 #####################
 
 """
-    set_namedtuple!(vi::VarInfo, nt::NamedTuple)
+    OldLogDensityFunction
 
-Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
+This is a clone of pre-0.39 DynamicPPL.LogDensityFunction. It is needed for MH because MH
+doesn't actually obey the LogDensityProblems.jl interface: it evaluates
+'LogDensityFunctions' with a NamedTuple(!!)
+
+This means that we can't _really_ use DynamicPPL's LogDensityFunction, since that only
+promises to obey the interface of being called with a vector.
+
+In particular, because `set_namedtuple!` acts on a VarInfo, we need to store the VarInfo
+inside this struct (which DynamicPPL's LogDensityFunction no longer does).
+
+This SHOULD really be refactored to remove this requirement.
 """
-function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
-    for (n, vals) in pairs(nt)
-        vns = vi.metadata[n].vns
-        if vals isa AbstractVector
-            vals = unvectorize(vals)
-        end
-        if length(vns) == 1
-            # Only one variable, assign the values to it
-            DynamicPPL.setindex!(vi, vals, vns[1])
-        else
-            # Spread the values across the variables
-            length(vns) == length(vals) || error("Unequal number of variables and values")
-            for (vn, val) in zip(vns, vals)
-                DynamicPPL.setindex!(vi, val, vn)
-            end
-        end
-    end
+struct OldLogDensityFunction{M<:DynamicPPL.Model,V<:DynamicPPL.AbstractVarInfo}
+    model::M
+    varinfo::V
+end
+function (f::OldLogDensityFunction)(x::AbstractVector)
+    vi = DynamicPPL.unflatten(f.varinfo, x)
+    _, vi = DynamicPPL.evaluate!!(f.model, vi)
+    return DynamicPPL.getlogjoint_internal(vi)
 end
-
 # NOTE(penelopeysm): MH does not conform to the usual LogDensityProblems
 # interface in that it gets evaluated with a NamedTuple. Hence we need this
 # method just to deal with MH.
-function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple)
+function (f::OldLogDensityFunction)(x::NamedTuple)
     vi = deepcopy(f.varinfo)
     # Note that the NamedTuple `x` does NOT conform to the structure required for
     # `InitFromParams`. In particular, for models that look like this:
@@ -223,8 +223,31 @@ function LogDensityProblems.logdensity(f::LogDensityFunction, x::NamedTuple)
     set_namedtuple!(vi, x)
     # Update log probability.
     _, vi_new = DynamicPPL.evaluate!!(f.model, vi)
-    lj = f.getlogdensity(vi_new)
-    return lj
+    return DynamicPPL.getlogjoint_internal(vi_new)
+end
+
+"""
+    set_namedtuple!(vi::VarInfo, nt::NamedTuple)
+
+Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
+"""
+function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
+    for (n, vals) in pairs(nt)
+        vns = vi.metadata[n].vns
+        if vals isa AbstractVector
+            vals = unvectorize(vals)
+        end
+        if length(vns) == 1
+            # Only one variable, assign the values to it
+            DynamicPPL.setindex!(vi, vals, vns[1])
+        else
+            # Spread the values across the variables
+            length(vns) == length(vals) || error("Unequal number of variables and values")
+            for (vn, val) in zip(vns, vals)
+                DynamicPPL.setindex!(vi, val, vn)
+            end
+        end
+    end
 end
 
 # unpack a vector if possible
@@ -335,12 +358,7 @@ function propose!!(rng::AbstractRNG, prev_state::MHState, model::Model, spl::MH,
 
     # Make a new transition.
     model = DynamicPPL.setleafcontext(model, MHContext(rng))
-    densitymodel = AMH.DensityModel(
-        Base.Fix1(
-            LogDensityProblems.logdensity,
-            DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
-        ),
-    )
+    densitymodel = AMH.DensityModel(OldLogDensityFunction(model, vi))
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
     # trans.params isa NamedTuple
     set_namedtuple!(vi, trans.params)
@@ -370,12 +388,7 @@ function propose!!(
 
     # Make a new transition.
     model = DynamicPPL.setleafcontext(model, MHContext(rng))
-    densitymodel = AMH.DensityModel(
-        Base.Fix1(
-            LogDensityProblems.logdensity,
-            DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint_internal, vi),
-        ),
-    )
+    densitymodel = AMH.DensityModel(OldLogDensityFunction(model, vi))
    trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
     # trans.params isa AbstractVector
     vi = DynamicPPL.unflatten(vi, trans.params)
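For orientation, a hedged usage sketch of the new callable (a model `mdl` with scalar variables `s` and `m`, plus a matching `VarInfo` `vi`, are assumed for illustration; `OldLogDensityFunction` is internal to this file):

    f = OldLogDensityFunction(mdl, vi)
    f(vi[:])               # vector call: unflatten, re-evaluate, return the internal log joint
    f((s = 1.0, m = 0.5))  # NamedTuple call: set_namedtuple! into a copy of vi, then evaluate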

src/mcmc/prior.jl

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@ function AbstractMCMC.step(
         DynamicPPL.LogPriorAccumulator(),
         DynamicPPL.LogLikelihoodAccumulator(),
     ))
-    _, vi = DynamicPPL.fast_evaluate!!(rng, model, InitFromPrior(), accs)
+    _, vi = DynamicPPL.fast_evaluate!!(rng, model, DynamicPPL.InitFromPrior(), accs)
     return DynamicPPL.ParamsWithStats(vi), nothing
 end
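(The only change: `InitFromPrior` is now qualified as `DynamicPPL.InitFromPrior`, presumably because the name is no longer imported unqualified into this module.)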

test/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -79,4 +79,4 @@ TimerOutputs = "0.5"
 julia = "1.10"
 
 [sources]
-DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
+DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/not-experimental"}

test/mcmc/Inference.jl

Lines changed: 0 additions & 39 deletions
@@ -174,19 +174,6 @@ using Turing
        @test mean(chains, :m) ≈ 0 atol = 0.1
    end
 
-    @testset "Vector chain_type" begin
-        chains = sample(
-            StableRNG(seed), gdemo_d(), Prior(), N; chain_type=Vector{NamedTuple}
-        )
-        @test chains isa Vector{<:NamedTuple}
-        @test length(chains) == N
-        @test all(haskey(x, :lp) for x in chains)
-        @test all(haskey(x, :logprior) for x in chains)
-        @test all(haskey(x, :loglikelihood) for x in chains)
-        @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.11
-        @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1
-    end
-
     @testset "accumulators are set correctly" begin
         # Prior() does not reevaluate the model when constructing a
         # `DynamicPPL.ParamsWithStats`, so we had better make sure that it does capture
@@ -638,32 +625,6 @@ using Turing
         )
     end
 
-    @testset "getparams" begin
-        @model function e(x=1.0)
-            return x ~ Normal()
-        end
-        evi = DynamicPPL.VarInfo(e())
-        @test isempty(Turing.Inference.getparams(e(), evi))
-
-        @model function f()
-            return x ~ Normal()
-        end
-        fvi = DynamicPPL.VarInfo(f())
-        fparams = Turing.Inference.getparams(f(), fvi)
-        @test fparams[@varname(x)] == fvi[@varname(x)]
-        @test length(fparams) == 1
-
-        @model function g()
-            x ~ Normal()
-            return y ~ Poisson()
-        end
-        gvi = DynamicPPL.VarInfo(g())
-        gparams = Turing.Inference.getparams(g(), gvi)
-        @test gparams[@varname(x)] == gvi[@varname(x)]
-        @test gparams[@varname(y)] == gvi[@varname(y)]
-        @test length(gparams) == 2
-    end
-
     @testset "empty model" begin
         @model function e(x=1.0)
             return x ~ Normal()
