diff --git a/src/Distributions/DiscreteLaplace/Model.dfy b/src/Distributions/DiscreteLaplace/Model.dfy index b7cb18ca..4e99ecdb 100644 --- a/src/Distributions/DiscreteLaplace/Model.dfy +++ b/src/Distributions/DiscreteLaplace/Model.dfy @@ -5,9 +5,43 @@ module DiscreteLaplace.Model { import Monad + import Rand import Rationals + import Uniform + import BernoulliExpNeg + import Bernoulli ghost function Sample(scale: Rationals.Rational): Monad.Hurd requires scale.numer >= 1 + ghost function SampleTailRecursive(scale: Rationals.Rational, b: bool := true, y: int := 0): Monad.Hurd + requires scale.numer >= 1 + { + assume {:axiom} false; // assume termination + (s: Rand.Bitstream) => + if !(b && y == 0) then + Monad.Result(if b then -y else y, s) + else + var (u, s) := Monad.Extract(Uniform.Model.Sample(scale.numer)(s)); + var (d, s) := Monad.Extract(BernoulliExpNeg.Model.Sample(Rationals.Rational(u, scale.numer))(s)); + if !d then + SampleTailRecursive(scale, b, y)(s) + else + var (v, s) := Monad.Extract(SampleTailRecursiveHelper(scale)(s)); + var x := u + scale.numer * v; + var y := x / scale.denom; + var (b, s) := Monad.Extract(Bernoulli.Model.Sample(1, 2)(s)); + SampleTailRecursive(scale, b, y)(s) + } + + ghost function SampleTailRecursiveHelper(scale: Rationals.Rational, v: int := 0, a: bool := true): Monad.Hurd { + assume {:axiom} false; // assume termination + (s: Rand.Bitstream) => + if !a then + Monad.Result(v, s) + else + var (a, s) := Monad.Extract(BernoulliExpNeg.Model.Sample(Rationals.Int(1))(s)); + SampleTailRecursiveHelper(scale, if a then v + 1 else v, a)(s) + } + } diff --git a/src/ProbabilisticProgramming/Monad.dfy b/src/ProbabilisticProgramming/Monad.dfy index 5892d4a1..904d2691 100644 --- a/src/ProbabilisticProgramming/Monad.dfy +++ b/src/ProbabilisticProgramming/Monad.dfy @@ -19,6 +19,10 @@ module Monad { // It consists of the computed value and the (unconsumed) rest of the bitstream. datatype Result = Result(value: A, rest: Rand.Bitstream) + function Extract(r: Result): (A, Rand.Bitstream) { + (r.value, r.rest) + } + // Equation (2.38) function Tail(s: Rand.Bitstream): (s': Rand.Bitstream) { (n: nat) => s(n+1)