Skip to content

Commit d017e4b

Browse files
committed
Make pobserve work with VarInfos that have no likelihood accumulator
1 parent 13b0991 commit d017e4b

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

src/pobserve_macro.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ function _pobserve(expr::Expr)
2525
end
2626
retvals_and_likelihoods = fetch.(likelihood_tasks)
2727
total_likelihoods = sum(last, retvals_and_likelihoods)
28-
# println("Total likelihoods: ", total_likelihoods)
29-
$(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)(
30-
$(esc(:(__varinfo__))), total_likelihoods
31-
)
28+
if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood))
29+
$(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)(
30+
$(esc(:(__varinfo__))), total_likelihoods
31+
)
32+
end
3233
map(first, retvals_and_likelihoods)
3334
end
3435
return return_expr
@@ -49,8 +50,13 @@ function process_tilde_statements(expr::Expr)
4950
end
5051
) || error("expected block")
5152
@gensym loglike
52-
beginning_statement =
53-
:($loglike = zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))))
53+
beginning_expr = quote
54+
$loglike = if $(DynamicPPL.hasacc)($(esc(:(__varinfo__))), Val(:LogLikelihood))
55+
zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__)))))
56+
else
57+
zero($(DynamicPPL.LogProbType))
58+
end
59+
end
5460
n_statements = length(statements)
5561
transformed_statements::Vector{Vector{Expr}} = map(enumerate(statements)) do (i, stmt)
5662
is_last = i == n_statements
@@ -79,6 +85,6 @@ function process_tilde_statements(expr::Expr)
7985
e
8086
end
8187
end
82-
new_statements = [beginning_statement, reduce(vcat, transformed_statements)...]
88+
new_statements = [beginning_expr.args..., reduce(vcat, transformed_statements)...]
8389
return Expr(:block, new_statements...)
8490
end

test/pobserve_macro.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ using DynamicPPL, Distributions, Test
1515
@test isapprox(DynamicPPL.getloglikelihood(vi), expected_loglike)
1616
end
1717

18+
@testset "doesn't error when varinfo has no likelihood acc" begin
19+
@model function f(x)
20+
@pobserve for i in eachindex(x)
21+
x[i] ~ Normal()
22+
end
23+
end
24+
x = randn(3)
25+
vi = VarInfo()
26+
vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),))
27+
@test DynamicPPL.evaluate!!(f(x), vi) isa Any
28+
end
29+
1830
@testset "return values are correct" begin
1931
@testset "single expression at the end" begin
2032
@model function f(x)

0 commit comments

Comments
 (0)