@@ -25,10 +25,11 @@ function _pobserve(expr::Expr)
25
25
end
26
26
retvals_and_likelihoods = fetch .(likelihood_tasks)
27
27
total_likelihoods = sum (last, retvals_and_likelihoods)
28
- # println("Total likelihoods: ", total_likelihoods)
29
- $ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
30
- $ (esc (:(__varinfo__))), total_likelihoods
31
- )
28
+ if $ (DynamicPPL. hasacc)($ (esc (:(__varinfo__))), Val (:LogLikelihood ))
29
+ $ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
30
+ $ (esc (:(__varinfo__))), total_likelihoods
31
+ )
32
+ end
32
33
map (first, retvals_and_likelihoods)
33
34
end
34
35
return return_expr
@@ -49,8 +50,13 @@ function process_tilde_statements(expr::Expr)
49
50
end
50
51
) || error (" expected block" )
51
52
@gensym loglike
52
- beginning_statement =
53
- :($ loglike = zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__))))))
53
+ beginning_expr = quote
54
+ $ loglike = if $ (DynamicPPL. hasacc)($ (esc (:(__varinfo__))), Val (:LogLikelihood ))
55
+ zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__)))))
56
+ else
57
+ zero ($ (DynamicPPL. LogProbType))
58
+ end
59
+ end
54
60
n_statements = length (statements)
55
61
transformed_statements:: Vector{Vector{Expr}} = map (enumerate (statements)) do (i, stmt)
56
62
is_last = i == n_statements
@@ -79,6 +85,6 @@ function process_tilde_statements(expr::Expr)
79
85
e
80
86
end
81
87
end
82
- new_statements = [beginning_statement , reduce (vcat, transformed_statements)... ]
88
+ new_statements = [beginning_expr . args ... , reduce (vcat, transformed_statements)... ]
83
89
return Expr (:block , new_statements... )
84
90
end
0 commit comments