Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-aws committed Oct 31, 2023
1 parent b69e6a6 commit 8cbe62a
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 31 deletions.
9 changes: 7 additions & 2 deletions src/Distributions/BernoulliExpNeg/Equivalence.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ module BernoulliExpNeg.Equivalence {
requires ak.0
ensures Model.GammaLe1Loop(gamma)(ak)(s) == Monad.Bind(Model.GammaLe1LoopIter(gamma)(ak), Model.GammaLe1Loop(gamma))(s)
{
Model.GammaLe1LoopTerminatesAlmostSurely(gamma);
GammaLe1LoopTerminatesAlmostSurely(gamma);
calc {
Model.GammaLe1Loop(gamma)(ak)(s);
{ reveal Model.GammaLe1Loop(); }
Expand Down Expand Up @@ -77,8 +77,13 @@ module BernoulliExpNeg.Equivalence {
Model.GammaLe1Loop(gamma)((true, 0))(oldS);
{ reveal CaseLe1LoopInvariant(); }
Model.GammaLe1Loop(gamma)((false, k))(s);
{ reveal Model.GammaLe1Loop(); }
{ reveal Model.GammaLe1Loop(); assume {:axiom} false; } // prove likely via Monad.AboutWhile
Monad.Result((false, k), s);
}
}

lemma {:axiom} GammaLe1LoopTerminatesAlmostSurely(gamma: Rationals.Rational)
requires 0 <= gamma.numer <= gamma.denom
ensures Monad.WhileTerminatesAlmostSurely(Model.GammaLe1LoopCondition, Model.GammaLe1LoopIter(gamma))

}
9 changes: 2 additions & 7 deletions src/Distributions/BernoulliExpNeg/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module BernoulliExpNeg.Model {
import Rand
import Uniform
import Bernoulli
import Monad
import Monad`Spec
import BernoulliModel = Bernoulli.Model

opaque ghost function Sample(gamma: Rationals.Rational): Monad.Hurd<bool>
Expand Down Expand Up @@ -74,7 +74,7 @@ module BernoulliExpNeg.Model {
requires 0 <= gamma.numer <= gamma.denom
{
(ak: (bool, nat)) =>
GammaLe1LoopTerminatesAlmostSurely(gamma);
//Equivalence.GammaLe1LoopTerminatesAlmostSurely(gamma);
Monad.While(
GammaLe1LoopCondition,
GammaLe1LoopIter(gamma),
Expand All @@ -100,9 +100,4 @@ module BernoulliExpNeg.Model {
a => Monad.Return((a, k))
}

lemma {:axiom} GammaLe1LoopTerminatesAlmostSurely(gamma: Rationals.Rational)
requires 0 <= gamma.numer <= gamma.denom
ensures Monad.WhileTerminatesAlmostSurely(GammaLe1LoopCondition, GammaLe1LoopIter(gamma))


}
2 changes: 1 addition & 1 deletion src/Distributions/Coin/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module Coin.Interface {

method CoinSample() returns (b: bool)
modifies this
ensures Model.Sample(old(s)) == Monad.Result(b, s)
ensures Model.Sample()(old(s)) == Monad.Result(b, s)
{
b := ExternCoinSample();
assume {:axiom} false; // assume correctness of extern implementation
Expand Down
6 changes: 3 additions & 3 deletions src/Distributions/Coin/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

module Coin.Model {
import Rand
import Monad
import Monad`Spec

function Sample(s: Rand.Bitstream): Monad.Result<bool> {
Monad.Coin(s)
function Sample(): Monad.Hurd<bool> {
Monad.Coin()
}
}
18 changes: 9 additions & 9 deletions src/Distributions/UniformPowerOfTwo/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ module UniformPowerOfTwo.Correctness {
forall b: bool ensures Monad.IsIndep(g(b)) {
Monad.ReturnIsIndep((if b then 2 * m + 1 else 2 * m) as nat);
}
Monad.BindIsIndep(Monad.Coin, g);
Monad.BindIsIndep(Monad.Coin(), g);
}
Monad.BindIsIndep(Model.Sample(n / 2), Model.UnifStep);
}
Expand Down Expand Up @@ -259,15 +259,15 @@ module UniformPowerOfTwo.Correctness {
ensures Model.Sample(n)(s).rest == Rand.Tail(Model.Sample(n / 2)(s).rest)
{
var Result(a, s') := Model.Sample(n / 2)(s);
var Result(b, s'') := Monad.Coin(s');
var Result(b, s'') := Monad.Coin()(s');
calc {
Model.Sample(n)(s).rest;
== { reveal Model.Sample(); }
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s).rest;
==
Model.UnifStep(a)(s').rest;
==
Monad.Bind(Monad.Coin, (b: bool) => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').rest;
Monad.Bind(Monad.Coin(), (b: bool) => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').rest;
==
Monad.Return((if b then 2*a + 1 else 2*a) as nat)(s'').rest;
==
Expand All @@ -282,23 +282,23 @@ module UniformPowerOfTwo.Correctness {
lemma SampleSetEquality(n: nat, m: nat)
requires n >= 2
ensures
var bOf := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value;
var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
(iset s | Model.Sample(n)(s).value == m) == (iset s | 2*aOf(s) + Helper.boolToNat(bOf(s)) == m)
{
var bOf := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value;
var bOf := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var aOf := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
forall s ensures Model.Sample(n)(s).value == m <==> (2 * aOf(s) + Helper.boolToNat(bOf(s)) == m) {
var Result(a, s') := Model.Sample(n / 2)(s);
var Result(b, s'') := Monad.Coin(s');
var Result(b, s'') := Monad.Coin()(s');
calc {
Model.Sample(n)(s).value;
== { reveal Model.Sample(); }
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s).value;
==
Model.UnifStep(a)(s').value;
==
Monad.Bind(Monad.Coin, b => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').value;
Monad.Bind(Monad.Coin(), b => Monad.Return((if b then 2*a + 1 else 2*a) as nat))(s').value;
==
Monad.Return((if b then 2*a + 1 else 2*a) as nat)(s'').value;
==
Expand All @@ -312,9 +312,9 @@ module UniformPowerOfTwo.Correctness {
ensures Rand.prob(iset s | Model.Sample(n)(s).value == m) == Rand.prob(iset s | Model.Sample(n / 2)(s).value == m / 2) / 2.0
{
var aOf: Rand.Bitstream -> nat := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).value;
var bOf: Rand.Bitstream -> bool := (s: Rand.Bitstream) => Monad.Coin(Model.Sample(n / 2)(s).rest).value;
var bOf: Rand.Bitstream -> bool := (s: Rand.Bitstream) => Monad.Coin()(Model.Sample(n / 2)(s).rest).value;
var A: iset<nat> := (iset x: nat | x == m / 2);
var E: iset<Rand.Bitstream> := (iset s | m % 2 as nat == Helper.boolToNat(Monad.Coin(s).value));
var E: iset<Rand.Bitstream> := (iset s | m % 2 as nat == Helper.boolToNat(Monad.Coin()(s).value));
var f := (s: Rand.Bitstream) => Model.Sample(n / 2)(s).rest;

var e1 := (iset s | Model.Sample(n / 2)(s).RestIn(E));
Expand Down
6 changes: 3 additions & 3 deletions src/Distributions/UniformPowerOfTwo/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ module UniformPowerOfTwo.Model {
import Helper
import Rand
import Quantifier
import Monad
import Monad`Spec

// Adapted from Definition 48 (see issue #79 for the reason of the modification)
// The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1).
opaque function Sample(n: nat): (h: Monad.Hurd<nat>)
requires n >= 1
ensures forall s :: Sample(n)(s).Result? // always terminates, not just almost surely
//ensures forall s :: Sample(n)(s).Result? // always terminates, not just almost surely
{
if n == 1 then
Monad.Return(0)
Expand All @@ -26,6 +26,6 @@ module UniformPowerOfTwo.Model {
}

function UnifStep(m: nat): Monad.Hurd<nat> {
Monad.Bind(Monad.Coin, UnifStepHelper(m))
Monad.Bind(Monad.Coin(), UnifStepHelper(m))
}
}
25 changes: 19 additions & 6 deletions src/ProbabilisticProgramming/Monad.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ module Monad {
Return,
Bind,
Map,
Until
Until,
While,
Coin

export reveals *

Expand Down Expand Up @@ -108,7 +110,9 @@ module Monad {
}

// Equation (2.42)
const Coin: Hurd<bool> := s => Result(Rand.Head(s), Rand.Tail(s))
function Coin(): Hurd<bool> {
s => Result(Rand.Head(s), Rand.Tail(s))
}

function Composition<A,B,C>(f: A -> Hurd<B>, g: B -> Hurd<C>): A -> Hurd<C> {
(a: A) => Bind(f(a), g)
Expand Down Expand Up @@ -263,7 +267,7 @@ module Monad {

// Equation (3.17)
lemma {:axiom} CoinIsIndep()
ensures IsIndep(Coin)
ensures IsIndep(Coin())

// Equation (3.18)
lemma {:axiom} ReturnIsIndep<T>(x: T)
Expand Down Expand Up @@ -317,9 +321,7 @@ module Monad {
// For proofs, use the lemma `WhileUnroll`.
// Equation (3.25), but modified to use `Monad.Diverging` instead of HOL's `arb` in case of nontermination
// TODO: While(condition, body)(init) would be cleaner
opaque ghost function While<A>(condition: A -> bool, body: A -> Hurd<A>, init: A): (f: Hurd<A>)
ensures forall s: Rand.Bitstream :: !condition(init) ==> f(s) == Return(init)(s)
{
opaque ghost function While<A>(condition: A -> bool, body: A -> Hurd<A>, init: A): (f: Hurd<A>) {
var f :=
(s: Rand.Bitstream) =>
if WhileCutTerminates(condition, body, init, s)
Expand All @@ -338,6 +340,14 @@ module Monad {
f
}

lemma AboutWhile<A>(condition: A -> bool, body: A -> Hurd<A>, init: A)
ensures
var f := While(condition, body, init);
forall s: Rand.Bitstream :: !condition(init) ==> f(s) == Return(init)(s)
{
assume {:axiom} false; // prove, did hold before as postcondition
}

ghost function LeastFuel<A>(condition: A -> bool, body: A -> Hurd<A>, init: A, s: Rand.Bitstream): (fuel: nat)
requires WhileCutTerminates(condition, body, init, s)
ensures WhileCutTerminatesWithFuel(condition, body, init, s)(fuel)
Expand Down Expand Up @@ -522,6 +532,7 @@ module Monad {
} else {
calc {
loop;
{ AboutWhile(condition, body, init); }
Result(init, s);
unrolled;
}
Expand All @@ -537,8 +548,10 @@ module Monad {
reveal While();
match body(init)(s)
case Diverging =>
AboutWhile(condition, body, init);
assert unrolled == Diverging;
case Result(init', s') =>
AboutWhile(condition, body, init);
assert !WhileCutTerminates(condition, body, init', s') by {
WhileCutTerminatesUnroll(condition, body, init, s, init', s');
}
Expand Down

0 comments on commit 8cbe62a

Please sign in to comment.