Skip to content

Commit

Permalink
Make samplers and While loops opaque (#100)
Browse files Browse the repository at this point in the history
Reason: For while loops, the definition is not very helpful, it's better
to rely on lemmas around it for proofs. When calling samplers, we should
also not rely on their internals, but on their correctness properties.
It also reduces resource usage.

As a concession, this PR also adds a postcondition for `While` to avoid
having to call a lemma to handle the case of a while loop immediately
exiting.

By submitting this pull request, I confirm that my contribution is made
under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).
  • Loading branch information
fzaiser authored Oct 23, 2023
1 parent 214f283 commit ec4e2a1
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 17 deletions.
12 changes: 6 additions & 6 deletions audit.log
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ src/Distributions/BernoulliExpNeg/Implementation.dfy(53,6): BernoulliExpNegSampl
src/Distributions/BernoulliExpNeg/Model.dfy(30,4): GammaReductionLoop: Definition has `assume {:axiom}` statement in body.
src/Distributions/BernoulliExpNeg/Model.dfy(60,4): GammaLe1Loop: Definition has `assume {:axiom}` statement in body.
src/Distributions/Coin/Interface.dfy(21,6): CoinSample: Definition has `assume {:axiom}` statement in body.
src/Distributions/Uniform/Implementation.dfy(45,6): UniformSample: Definition has `assume {:axiom}` statement in body.
src/Distributions/Uniform/Implementation.dfy(46,6): UniformSample: Definition has `assume {:axiom}` statement in body.
src/Math/Analysis/Reals.dfy(35,17): LeastUpperBoundProperty: Declaration has explicit `{:axiom}` attribute.
src/Math/Exponential.dfy(11,17): EvalOne: Declaration has explicit `{:axiom}` attribute.
src/Math/Exponential.dfy(2,26): Exp: Declaration has explicit `{:axiom}` attribute.
Expand All @@ -21,11 +21,11 @@ src/ProbabilisticProgramming/Independence.dfy(46,17): IsIndepImpliesIsIndepFunct
src/ProbabilisticProgramming/Independence.dfy(51,17): CoinIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(55,17): ReturnIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(59,17): BindIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Loops.dfy(229,8): WhileUnroll: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Loops.dfy(257,17): EnsureWhileTerminates: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Loops.dfy(263,17): UntilProbabilityFraction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Loops.dfy(294,4): EnsureUntilTerminatesAndForAll: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Loops.dfy(317,4): WhileIsIndep: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Loops.dfy(237,8): WhileUnroll: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Loops.dfy(265,17): EnsureWhileTerminates: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Loops.dfy(271,17): UntilProbabilityFraction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Loops.dfy(302,4): EnsureUntilTerminatesAndForAll: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Loops.dfy(325,4): WhileIsIndep: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Monad.dfy(140,17): CoinHasProbOneHalf: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Monad.dfy(147,17): MeasureHeadDrop: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Monad.dfy(153,17): TailIsMeasurePreserving: Declaration has explicit `{:axiom}` attribute.
Expand Down
1 change: 1 addition & 0 deletions src/Distributions/Bernoulli/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ module Bernoulli.Correctness {
}

Independence.BindIsIndep(f, g);
reveal Model.Sample();
}


Expand Down
1 change: 1 addition & 0 deletions src/Distributions/Bernoulli/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module Bernoulli.Implementation {
requires 0 <= p.numer <= p.denom
ensures Model.Sample(p.numer, p.denom)(old(s)) == Monad.Result(c, s)
{
reveal Model.Sample();
var k := UniformSample(p.denom);
c := k < p.numer;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/Bernoulli/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ module Bernoulli.Model {
import Monad

// Footnote 5, p. 82
ghost function Sample(numer: nat, denom: nat): (f: Monad.Hurd<bool>)
opaque ghost function Sample(numer: nat, denom: nat): (f: Monad.Hurd<bool>)
requires denom != 0
requires numer <= denom
ensures forall s :: f(s).value == (Uniform.Model.Sample(denom)(s).value < numer)
Expand Down
6 changes: 4 additions & 2 deletions src/Distributions/BernoulliExpNeg/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module BernoulliExpNeg.Model {
import Loops
import BernoulliModel = Bernoulli.Model

ghost function Sample(gamma: Rationals.Rational): Monad.Hurd<bool>
opaque ghost function Sample(gamma: Rationals.Rational): Monad.Hurd<bool>
requires gamma.denom != 0
requires gamma.numer >= 0
{
Expand Down Expand Up @@ -70,7 +70,9 @@ module BernoulliExpNeg.Model {
{
Monad.Bind(
BernoulliModel.Sample(gamma.numer, (ak.1 + 1) * gamma.denom),
(a': bool) => Monad.Return((a', ak.1 + 1))
(a': bool) =>
var res: Monad.Hurd<(bool, nat)> := Monad.Return((a', ak.1 + 1));
res
)
}
}
2 changes: 2 additions & 0 deletions src/Distributions/Uniform/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ module Uniform.Correctness {

assert Eq: eventResultEqualsI == SampleEquals(n, i) by {
forall s ensures s in eventResultEqualsI <==> s in SampleEquals(n, i) {
reveal Model.Sample();
assert s in eventResultEqualsI <==> s in SampleEquals(n, i);
}
}
Expand Down Expand Up @@ -166,5 +167,6 @@ module Uniform.Correctness {
Model.SampleTerminates(n);
}
Loops.UntilIsIndep(Model.Proposal(n), Model.Accept(n));
reveal Model.Sample();
}
}
1 change: 1 addition & 0 deletions src/Distributions/Uniform/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module Uniform.Implementation {
prevS := s;
u := UniformPowerOfTwoSample(2 * n);
}
reveal Model.Sample();
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/Distributions/Uniform/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module Uniform.Model {
import UniformPowerOfTwo

// Definition 49
ghost function Sample(n: nat): Monad.Hurd<nat>
opaque ghost function Sample(n: nat): Monad.Hurd<nat>
requires n > 0
{
SampleTerminates(n);
Expand Down Expand Up @@ -88,6 +88,7 @@ module Uniform.Model {
ensures Sample(n)(s) == Monad.Bind(Proposal(n), (x: nat) => if Accept(n)(x) then Monad.Return(x) else Sample(n))(s)
{
SampleTerminates(n);
reveal Sample();
Loops.UntilUnroll(Proposal(n), Accept(n), s);
}
}
7 changes: 5 additions & 2 deletions src/Distributions/UniformPowerOfTwo/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ module UniformPowerOfTwo.Correctness {
forall m: nat ensures UnifIsCorrect(n, k, m) {
assert n >= 1 by { Helper.PowerGreater0(2, k); }
if k == 0 {
reveal Model.Sample();
if m == 0 {
assert (iset s | Model.Sample(1)(s).value == m) == (iset s);
} else {
Expand Down Expand Up @@ -162,6 +163,7 @@ module UniformPowerOfTwo.Correctness {
ensures Independence.IsIndep(Model.Sample(n))
{
var fn := Model.Sample(n);
reveal Model.Sample();
if n == 1 {
Independence.ReturnIsIndep(0 as nat);
} else {
Expand Down Expand Up @@ -193,6 +195,7 @@ module UniformPowerOfTwo.Correctness {
if n == 1 {
forall e | e in Rand.eventSpace ensures Rand.prob(Measures.PreImage(f, e)) == Rand.prob(e) {
forall s: Rand.Bitstream ensures f(s) == s {
reveal Model.Sample();
assert f(s) == s;
}
Measures.PreImageIdentity(f, e);
Expand Down Expand Up @@ -251,7 +254,7 @@ module UniformPowerOfTwo.Correctness {
var Result(b, s'') := Monad.Coin(s');
calc {
Model.Sample(n)(s).rest;
==
== { reveal Model.Sample(); }
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s).rest;
==
Model.UnifStep(a)(s').rest;
Expand Down Expand Up @@ -282,7 +285,7 @@ module UniformPowerOfTwo.Correctness {
var Result(b, s'') := Monad.Coin(s');
calc {
Model.Sample(n)(s).value;
==
== { reveal Model.Sample(); }
Monad.Bind(Model.Sample(n / 2), Model.UnifStep)(s).value;
==
Model.UnifStep(a)(s').value;
Expand Down
7 changes: 5 additions & 2 deletions src/Distributions/UniformPowerOfTwo/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ module UniformPowerOfTwo.Model {

// Adapted from Definition 48 (see issue #79 for the reason of the modification)
// The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1).
function Sample(n: nat): (h: Monad.Hurd<nat>)
opaque function Sample(n: nat): (h: Monad.Hurd<nat>)
requires n >= 1
{
if n == 1 then
Expand All @@ -47,6 +47,7 @@ module UniformPowerOfTwo.Model {
ensures SampleTailRecursive(n)(s) == Sample(n)(s)
{
if n == 1 {
reveal Sample();
assert SampleTailRecursive(n)(s) == Sample(n)(s);
} else {
var k := Helper.Log2Floor(n);
Expand Down Expand Up @@ -102,9 +103,11 @@ module UniformPowerOfTwo.Model {
assert 1 <= n;
calc {
Sample(m)(s);
{ reveal Sample(); }
Monad.Bind(Sample(m / 2), UnifStep)(s);
{ SampleEqualIfSameLog2Floor(m / 2, n / 2, k - 1, s); }
Monad.Bind(Sample(n / 2), UnifStep)(s);
{ reveal Sample(); }
Sample(n)(s);
}
}
Expand Down Expand Up @@ -134,7 +137,7 @@ module UniformPowerOfTwo.Model {
{ assert Helper.Power(2, l) / 2 == Helper.Power(2, l - 1); reveal L1GreaterZero; }
(var Result(u', s') := Monad.Bind(Sample(Helper.Power(2, m)), UnifStep)(s);
SampleTailRecursive(Helper.Power(2, l - 1), u')(s'));
{ assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); }
{ assert Helper.Power(2, m + 1) / 2 == Helper.Power(2, m); reveal Sample(); }
(var Result(u', s') := Sample(Helper.Power(2, m + 1))(s);
SampleTailRecursive(Helper.Power(2, l - 1), u')(s'));
Monad.Bind(Sample(Helper.Power(2, m + 1)), (u: nat) => SampleTailRecursive(Helper.Power(2, l - 1), u))(s);
Expand Down
15 changes: 12 additions & 3 deletions src/ProbabilisticProgramming/Loops.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ module Loops {
forall init :: Quantifier.AlmostSurely(p(init))
}

// Definition of while loops.
// This definition is opaque because the details are not very useful.
// For proofs, use the lemma `WhileUnroll`.
// Equation (3.25)
ghost function While<A>(condition: A -> bool, body: A -> Monad.Hurd<A>, init: A): (f: Monad.Hurd<A>)
opaque ghost function While<A>(condition: A -> bool, body: A -> Monad.Hurd<A>, init: A): (f: Monad.Hurd<A>)
ensures forall s: Rand.Bitstream :: !condition(init) ==> f(s) == Monad.Return(init)(s)
{
(s: Rand.Bitstream) =>
if WhileCutTerminates(condition, body, init, s)
Expand Down Expand Up @@ -77,6 +81,8 @@ module Loops {
WhileTerminatesAlmostSurely(reject, body)
}

// Definition of until loops (rejection sampling).
// For proofs, use the lemma `UntilUnroll`.
// Definition 44
ghost function Until<A>(proposal: Monad.Hurd<A>, accept: A -> bool): (f: Monad.Hurd<A>)
requires UntilTerminatesAlmostSurely(proposal, accept)
Expand Down Expand Up @@ -179,9 +185,11 @@ module Loops {
assert loop == unrolled by {
calc {
loop;
{ reveal While(); }
WhileCut(condition, body, init, fuel)(s);
{ WhileCutUnroll(condition, body, init, s, init', s', fuel'); }
WhileCut(condition, body, init', fuel')(s');
{ reveal While(); }
unrolled;
}
}
Expand Down Expand Up @@ -324,9 +332,10 @@ module Loops {
{
var reject := (a: A) => !accept(a);
var body := (a: A) => proposal;
forall init: A {
var f := (init: A) => While(reject, body, init);
forall init: A ensures Independence.IsIndep(f(init)) {
WhileIsIndep(reject, body, init);
}
Independence.BindIsIndep(proposal, (init: A) => While(reject, body, init));
Independence.BindIsIndep(proposal, f);
}
}

0 comments on commit ec4e2a1

Please sign in to comment.