diff --git a/src/Distributions/BernoulliExpNeg/Equivalence.dfy b/src/Distributions/BernoulliExpNeg/Equivalence.dfy index 48adbf73..465c446b 100644 --- a/src/Distributions/BernoulliExpNeg/Equivalence.dfy +++ b/src/Distributions/BernoulliExpNeg/Equivalence.dfy @@ -29,7 +29,7 @@ module BernoulliExpNeg.Equivalence { requires ak.0 ensures Model.GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s) { - Model.GammaLe1LoopTerminatesAlmostSurely(gamma); + GammaLe1LoopTerminatesAlmostSurely(gamma); calc { Model.GammaLe1Loop(gamma)(ak)(s); { reveal Model.GammaLe1Loop(); } @@ -77,8 +77,13 @@ module BernoulliExpNeg.Equivalence { Model.GammaLe1Loop(gamma)((true, 0))(oldS); { reveal CaseLe1LoopInvariant(); } Model.GammaLe1Loop(gamma)((false, k))(s); - { reveal Model.GammaLe1Loop(); } + { reveal Model.GammaLe1Loop(); assume {:axiom} false; } // prove likely via Monad.AboutWhile Monad.Result((false, k), s); } } + + lemma {:axiom} GammaLe1LoopTerminatesAlmostSurely(gamma: Rationals.Rational) + requires 0 <= gamma.numer <= gamma.denom + ensures Monad.WhileTerminatesAlmostSurely(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma)) + } diff --git a/src/Distributions/BernoulliExpNeg/Model.dfy b/src/Distributions/BernoulliExpNeg/Model.dfy index 2d846300..cfc8e64f 100644 --- a/src/Distributions/BernoulliExpNeg/Model.dfy +++ b/src/Distributions/BernoulliExpNeg/Model.dfy @@ -8,7 +8,7 @@ module BernoulliExpNeg.Model { import Rand import Uniform import Bernoulli - import Monad + import Monad`Spec import BernoulliModel = Bernoulli.Model opaque ghost function Sample(gamma: Rationals.Rational): Monad.Hurd @@ -74,7 +74,7 @@ module BernoulliExpNeg.Model { requires 0 <= gamma.numer <= gamma.denom { (ak: (bool, nat)) => - GammaLe1LoopTerminatesAlmostSurely(gamma); + //Equivalence.GammaLe1LoopTerminatesAlmostSurely(gamma); Monad.While( GammaLe1LoopCondition, GammaLe1LoopIter(gamma), @@ -100,9 +100,4 @@ module BernoulliExpNeg.Model { a => Monad.Return((a, k)) } - lemma {:axiom} GammaLe1LoopTerminatesAlmostSurely(gamma: Rationals.Rational) - requires 0 <= gamma.numer <= gamma.denom - ensures Monad.WhileTerminatesAlmostSurely(GammaLe1LoopCondition, GammaLe1LoopIter(gamma)) - - } diff --git a/src/Distributions/Coin/Interface.dfy b/src/Distributions/Coin/Interface.dfy index 21165c78..4f6c1e8c 100644 --- a/src/Distributions/Coin/Interface.dfy +++ b/src/Distributions/Coin/Interface.dfy @@ -15,7 +15,7 @@ module Coin.Interface { method CoinSample() returns (b: bool) modifies this - ensures Model.Sample(old(s)) == Monad.Result(b, s) + ensures Model.Sample()(old(s)) == Monad.Result(b, s) { b := ExternCoinSample(); assume {:axiom} false; // assume correctness of extern implementation diff --git a/src/Distributions/Coin/Model.dfy b/src/Distributions/Coin/Model.dfy index a6aef4ed..cd55dd15 100644 --- a/src/Distributions/Coin/Model.dfy +++ b/src/Distributions/Coin/Model.dfy @@ -5,9 +5,9 @@ module Coin.Model { import Rand - import Monad + import Monad`Spec - function Sample(s: Rand.Bitstream): Monad.Result { - Monad.Coin(s) + function Sample(): Monad.Hurd { + Monad.Coin() } } diff --git a/src/Distributions/UniformPowerOfTwo/Correctness.dfy b/src/Distributions/UniformPowerOfTwo/Correctness.dfy index f44d5e24..2452b726 100644 --- a/src/Distributions/UniformPowerOfTwo/Correctness.dfy +++ b/src/Distributions/UniformPowerOfTwo/Correctness.dfy @@ -175,7 +175,7 @@ module UniformPowerOfTwo.Correctness { forall b: bool ensures Monad.IsIndep(g(b)) { Monad.ReturnIsIndep((if b then 2 * m + 1 else 2 * m) as nat); } - Monad.BindIsIndep(Monad.Coin, g); + Monad.BindIsIndep(Monad.Coin(), g); } Monad.BindIsIndep(Model.Sample(n / 2), Model.UnifStep); } @@ -259,7 +259,7 @@ module UniformPowerOfTwo.Correctness { ensures Model.Sample(n)(s).rest == Rand.Tail(Model.Sample(n / 2)(s).rest) { var Result(a, s') := Model.Sample(n / 2)(s); - var Result(b, s'') := Monad.Coin(s'); + var Result(b, s'') := Monad.Coin()(s'); calc { Model.Sample(n)(s).rest; == { reveal Model.Sample(); } @@ -267,7 +267,7 @@ module UniformPowerOfTwo.Correctness { == Model.UnifStep(a)(s').rest; == - Monad.Bind(Monad.Coin, (b: bool) => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').rest; + Monad.Bind(Monad.Coin(), (b: bool) => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').rest; == Monad.Return((if b then 2*a + 1 else 2*a) as nat)(s'').rest; == @@ -282,15 +282,15 @@ module UniformPowerOfTwo.Correctness { lemma SampleSetEquality(n: nat, m: nat) requires n >= 2 ensures - var bOf := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value; + var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value; var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value; (iset s | Model.Sample(n)(s).value == m) == (iset s | 2*aOf(s) + Helper.boolToNat(bOf(s)) == m) { - var bOf := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value; + var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value; var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value; forall s ensures Model.Sample(n)(s).value == m <==> (2 * aOf(s) + Helper.boolToNat(bOf(s)) == m) { var Result(a, s') := Model.Sample(n / 2)(s); - var Result(b, s'') := Monad.Coin(s'); + var Result(b, s'') := Monad.Coin()(s'); calc { Model.Sample(n)(s).value; == { reveal Model.Sample(); } @@ -298,7 +298,7 @@ module UniformPowerOfTwo.Correctness { == Model.UnifStep(a)(s').value; == - Monad.Bind(Monad.Coin, b => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').value; + Monad.Bind(Monad.Coin(), b => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').value; == Monad.Return((if b then 2*a + 1 else 2*a) as nat)(s'').value; == @@ -312,9 +312,9 @@ module UniformPowerOfTwo.Correctness { ensures Rand.prob(iset s | Model.Sample(n)(s).value == m) == Rand.prob(iset s | Model.Sample(n / 2)(s).value == m / 2) / 2.0 { var aOf: Rand.Bitstream -> nat := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value; - var bOf: Rand.Bitstream -> bool := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value; + var bOf: Rand.Bitstream -> bool := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value; var A: iset := (iset x: nat | x == m / 2); - var E: iset := (iset s | m % 2 as nat == Helper.boolToNat(Monad.Coin(s).value)); + var E: iset := (iset s | m % 2 as nat == Helper.boolToNat(Monad.Coin()(s).value)); var f := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).rest; var e1 := (iset s | Model.Sample(n / 2)(s).RestIn(E)); diff --git a/src/Distributions/UniformPowerOfTwo/Model.dfy b/src/Distributions/UniformPowerOfTwo/Model.dfy index 10e357c4..cdfe561b 100644 --- a/src/Distributions/UniformPowerOfTwo/Model.dfy +++ b/src/Distributions/UniformPowerOfTwo/Model.dfy @@ -7,13 +7,13 @@ module UniformPowerOfTwo.Model { import Helper import Rand import Quantifier - import Monad + import Monad`Spec // Adapted from Definition 48 (see issue #79 for the reason of the modification) // The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1). opaque function Sample(n: nat): (h: Monad.Hurd) requires n >= 1 - ensures forall s :: Sample(n)(s).Result? // always terminates, not just almost surely + //ensures forall s :: Sample(n)(s).Result? // always terminates, not just almost surely { if n == 1 then Monad.Return(0) @@ -26,6 +26,6 @@ module UniformPowerOfTwo.Model { } function UnifStep(m: nat): Monad.Hurd { - Monad.Bind(Monad.Coin, UnifStepHelper(m)) + Monad.Bind(Monad.Coin(), UnifStepHelper(m)) } } diff --git a/src/ProbabilisticProgramming/Monad.dfy b/src/ProbabilisticProgramming/Monad.dfy index bd4c60b7..a956c719 100644 --- a/src/ProbabilisticProgramming/Monad.dfy +++ b/src/ProbabilisticProgramming/Monad.dfy @@ -13,7 +13,9 @@ module Monad { Return, Bind, Map, - Until + Until, + While, + Coin export reveals * @@ -108,7 +110,9 @@ module Monad { } // Equation (2.42) - const Coin: Hurd := s => Result(Rand.Head(s), Rand.Tail(s)) + function Coin(): Hurd { + s => Result(Rand.Head(s), Rand.Tail(s)) + } function Composition(f: A -> Hurd, g: B -> Hurd): A -> Hurd { (a: A) => Bind(f(a), g) @@ -263,7 +267,7 @@ module Monad { // Equation (3.17) lemma {:axiom} CoinIsIndep() - ensures IsIndep(Coin) + ensures IsIndep(Coin()) // Equation (3.18) lemma {:axiom} ReturnIsIndep(x: T) @@ -317,9 +321,7 @@ module Monad { // For proofs, use the lemma `WhileUnroll`. // Equation (3.25), but modified to use `Monad.Diverging` instead of HOL's `arb` in case of nontermination // TODO: While(condition, body)(init) would be cleaner - opaque ghost function While(condition: A -> bool, body: A -> Hurd, init: A): (f: Hurd) - ensures forall s: Rand.Bitstream :: !condition(init) ==> f(s) == Return(init)(s) - { + opaque ghost function While(condition: A -> bool, body: A -> Hurd, init: A): (f: Hurd) { var f := (s: Rand.Bitstream) => if WhileCutTerminates(condition, body, init, s) @@ -338,6 +340,14 @@ module Monad { f } + lemma AboutWhile(condition: A -> bool, body: A -> Hurd, init: A) + ensures + var f := While(condition, body, init); + forall s: Rand.Bitstream :: !condition(init) ==> f(s) == Return(init)(s) + { + assume {:axiom} false; // prove, did hold before as postcondition + } + ghost function LeastFuel(condition: A -> bool, body: A -> Hurd, init: A, s: Rand.Bitstream): (fuel: nat) requires WhileCutTerminates(condition, body, init, s) ensures WhileCutTerminatesWithFuel(condition, body, init, s)(fuel) @@ -522,6 +532,7 @@ module Monad { } else { calc { loop; + { AboutWhile(condition, body, init); } Result(init, s); unrolled; } @@ -537,8 +548,10 @@ module Monad { reveal While(); match body(init)(s) case Diverging => + AboutWhile(condition, body, init); assert unrolled == Diverging; case Result(init', s') => + AboutWhile(condition, body, init); assert !WhileCutTerminates(condition, body, init', s') by { WhileCutTerminatesUnroll(condition, body, init, s, init', s'); }