Skip to content

Commit

Permalink
tail recursive model
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-aws committed Oct 23, 2023
1 parent 711679e commit 9236c25
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/Distributions/DiscreteLaplace/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,43 @@

module DiscreteLaplace.Model {
import Monad
import Rand
import Rationals
import Uniform
import BernoulliExpNeg
import Bernoulli

ghost function Sample(scale: Rationals.Rational): Monad.Hurd<int>
requires scale.numer >= 1

ghost function SampleTailRecursive(scale: Rationals.Rational, b: bool := true, y: int := 0): Monad.Hurd<int>
requires scale.numer >= 1
{
assume {:axiom} false; // assume termination
(s: Rand.Bitstream) =>
if !(b && y == 0) then
Monad.Result(if b then -y else y, s)
else
var (u, s) := Monad.Extract(Uniform.Model.Sample(scale.numer)(s));
var (d, s) := Monad.Extract(BernoulliExpNeg.Model.Sample(Rationals.Rational(u, scale.numer))(s));
if !d then
SampleTailRecursive(scale, b, y)(s)
else
var (v, s) := Monad.Extract(SampleTailRecursiveHelper(scale)(s));
var x := u + scale.numer * v;
var y := x / scale.denom;
var (b, s) := Monad.Extract(Bernoulli.Model.Sample(1, 2)(s));
SampleTailRecursive(scale, b, y)(s)
}

ghost function SampleTailRecursiveHelper(scale: Rationals.Rational, v: int := 0, a: bool := true): Monad.Hurd<int> {
assume {:axiom} false; // assume termination
(s: Rand.Bitstream) =>
if !a then
Monad.Result(v, s)
else
var (a, s) := Monad.Extract(BernoulliExpNeg.Model.Sample(Rationals.Int(1))(s));
SampleTailRecursiveHelper(scale, if a then v + 1 else v, a)(s)
}

}
4 changes: 4 additions & 0 deletions src/ProbabilisticProgramming/Monad.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ module Monad {
// It consists of the computed value and the (unconsumed) rest of the bitstream.
datatype Result<A> = Result(value: A, rest: Rand.Bitstream)

function Extract<A>(r: Result): (A, Rand.Bitstream) {
(r.value, r.rest)
}

// Equation (2.38)
function Tail(s: Rand.Bitstream): (s': Rand.Bitstream) {
(n: nat) => s(n+1)
Expand Down

0 comments on commit 9236c25

Please sign in to comment.