Skip to content

Commit

Permalink
Revert Result Type, Proof of ResultsIndependent, Reduce Brittleness i…
Browse files Browse the repository at this point in the history
…n FisherYates (#161)

By submitting this pull request, I confirm that my contribution is made
under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).
  • Loading branch information
stefan-aws authored Mar 6, 2024
1 parent 9464263 commit 2697703
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 89 deletions.
6 changes: 3 additions & 3 deletions audit.log
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
src/Distributions/Uniform/Correctness.dfy(31,17): UniformFullCorrectness: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Correctness.dfy(36,17): SampleCoin: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Model.dfy(19,33): Sample: Declaration has explicit `{:axiom}` attribute.
src/Distributions/Uniform/Model.dfy(46,17): IntervalSampleIsMeasurePreserving: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(30,27): IsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(55,6): ResultsIndependent: Definition has `assume {:axiom}` statement in body.
src/ProbabilisticProgramming/Independence.dfy(60,17): IsIndepImpliesIsIndepFunction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(64,17): MapIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(70,17): IsIndepImpliesIsIndepFunction: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/Independence.dfy(74,17): MapIsIndep: Declaration has explicit `{:axiom}` attribute.
src/ProbabilisticProgramming/RandomSource.dfy(50,17): ProbIsProbabilityMeasure: Declaration has explicit `{:axiom}` attribute.
8 changes: 5 additions & 3 deletions src/Distributions/Uniform/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module Uniform.Model {
opaque ghost function {:axiom} Sample(n: nat): (h: Monad.Hurd<nat>)
requires n > 0
ensures Independence.IsIndep(h)
ensures forall s :: h(s).Result? ==> 0 <= h(s).value < n
ensures forall s :: 0 <= h(s).value < n

ghost function IntervalSample(a: int, b: int): (f: Monad.Hurd<int>)
requires a < b
Expand All @@ -33,16 +33,18 @@ module Uniform.Model {

lemma SampleBound(n: nat, s: Rand.Bitstream)
requires n > 0
requires Sample(n)(s).Result?
ensures 0 <= Sample(n)(s).value < n
{}

lemma IntervalSampleBound(a: int, b: int, s: Rand.Bitstream)
requires a < b
requires IntervalSample(a, b)(s).Result?
ensures a <= IntervalSample(a, b)(s).value < b
{
SampleBound(b-a, s);
}

lemma {:axiom} IntervalSampleIsMeasurePreserving(a: int, b: int)
requires a < b
ensures Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => IntervalSample(a, b)(s).rest)

}
12 changes: 11 additions & 1 deletion src/ProbabilisticProgramming/Independence.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module Independence {
)
requires hIndep: IsIndepFunction(h)
requires bMeasurable: bSeeds in Rand.eventSpace
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
ensures Rand.prob(Monad.BitstreamsWithValueIn(h, aSet) * Monad.BitstreamsWithRestIn(h, bSeeds)) == Rand.prob(Monad.BitstreamsWithValueIn(h, aSet)) * Rand.prob(bSeeds)
{
var aSeeds := Monad.BitstreamsWithValueIn(h, aSet);
Expand All @@ -52,7 +53,16 @@ module Independence {
assert Measures.AreIndepEvents(Rand.eventSpace, Rand.prob, aSeeds, restBSeeds);
}
assert Rand.prob(restBSeeds) == Rand.prob(bSeeds) by {
assume {:axiom} false; // TODO
calc {
Rand.prob(restBSeeds);
Rand.prob(Monad.BitstreamsWithRestIn(h, bSeeds));
{ assert Monad.BitstreamsWithRestIn(h, bSeeds) == iset s | h(s).rest in bSeeds; }
Rand.prob(iset s | h(s).rest in bSeeds);
{ assert (iset s | h(s).rest in bSeeds) == Measures.PreImage(s => h(s).rest, bSeeds); }
Rand.prob(Measures.PreImage(s => h(s).rest, bSeeds));
{ reveal bMeasurable; reveal hIsMeasurePreserving; }
Rand.prob(bSeeds);
}
}
}

Expand Down
59 changes: 11 additions & 48 deletions src/ProbabilisticProgramming/Monad.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,14 @@ module Monad {
// The result of a probabilistic computation on a bitstream.
// It either consists of the computed value and the (unconsumed) rest of the bitstream or indicates nontermination.
// It differs from Hurd's definition in that the result can be nontermination, which Hurd does not model explicitly.
datatype Result<A> =
| Result(value: A, rest: Rand.Bitstream)
| Diverging
datatype Result<A> = Result(value: A, rest: Rand.Bitstream)
{
function Map<B>(f: A -> B): Result<B> {
match this
case Diverging => Diverging
case Result(value, rest) => Result(f(value), rest)
Result(f(value), rest)
}

function Bind<B>(f: A -> Hurd<B>): Result<B> {
match this
case Diverging => Diverging
case Result(value, rest) => f(value)(rest)
f(value)(rest)
}

ghost predicate In(s: iset<A>) {
Expand All @@ -43,44 +37,30 @@ module Monad {
}

predicate Satisfies(property: A -> bool) {
match this
case Diverging => false
case Result(value, _) => property(value)
property(value)
}

ghost predicate RestIn(s: iset<Rand.Bitstream>) {
RestSatisfies(r => r in s)
}

predicate RestSatisfies(property: Rand.Bitstream -> bool) {
match this
case Diverging => false
case Result(_, rest) => property(rest)
property(rest)
}

predicate IsFailure() {
Diverging?
}

function PropagateFailure<B>(): Result<B>
requires Diverging?
{
Diverging
}

function Extract(): (A, Rand.Bitstream)
requires Result?
function Extract(): (x: (A, Rand.Bitstream))
ensures this == Result(x.0, x.1)
{
(this.value, this.rest)
}
}

ghost function Values<A>(results: iset<Result<A>>): iset<A> {
iset r <- results | r.Result? :: r.value
iset r <- results :: r.value
}

ghost function Rests<A>(results: iset<Result<A>>): iset<Rand.Bitstream> {
iset r <- results | r.Result? :: r.rest
iset r <- results :: r.rest
}

ghost function ResultEventSpace<A(!new)>(eventSpace: iset<iset<A>>): iset<iset<Result<A>>> {
Expand All @@ -92,11 +72,11 @@ module Monad {
ghost const natResultEventSpace: iset<iset<Result<nat>>> := ResultEventSpace(Measures.natEventSpace)

ghost function ResultsWithValueIn<A(!new)>(values: iset<A>): iset<Result<A>> {
iset result: Result<A> | result.Result? && result.value in values
iset result: Result<A> | result.value in values
}

ghost function ResultsWithRestIn<A(!new)>(rests: iset<Rand.Bitstream>): iset<Result<A>> {
iset result: Result<A> | result.Result? && result.rest in rests
iset result: Result<A> | result.rest in rests
}

ghost function BitstreamsWithValueIn<A(!new)>(h: Hurd<A>, aSet: iset<A>): iset<Rand.Bitstream> {
Expand All @@ -112,30 +92,13 @@ module Monad {
(s: Rand.Bitstream) => f(s).Bind(g)
}

function BindAlternative<A,B>(f: Hurd<A>, g: A -> Hurd<B>): (h: Hurd<B>)
ensures forall s :: h(s) == Bind(f, g)(s)
{
(s: Rand.Bitstream) =>
var (a, s') :- f(s);
g(a)(s')
}

// Equation (2.42)
const Coin: Hurd<bool> := s => Result(Rand.Head(s), Rand.Tail(s))

function Composition<A,B,C>(f: A -> Hurd<B>, g: B -> Hurd<C>): A -> Hurd<C> {
(a: A) => Bind(f(a), g)
}

function CompositionAlternative<A(!new),B,C>(f: A -> Hurd<B>, g: B -> Hurd<C>): (h: A -> Hurd<C>)
ensures forall a, s :: h(a)(s) == Composition(f, g)(a)(s)
{
(a: A) =>
(s: Rand.Bitstream) =>
var (b, s') :- f(a)(s);
g(b)(s')
}

// Equation (3.3)
function Return<A>(a: A): Hurd<A> {
(s: Rand.Bitstream) => Result(a, s)
Expand Down
97 changes: 67 additions & 30 deletions src/Util/FisherYates/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ module FisherYates.Correctness {
requires i <= |xs|
requires i <= |p|
{
iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]
iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..]
}

ghost predicate CorrectnessPredicate<T(!new)>(xs: seq<T>, p: seq<T>, i: nat)
Expand Down Expand Up @@ -64,7 +64,7 @@ module FisherYates.Correctness {
{
var e := iset s | Model.Shuffle(xs)(s).Equals(p);
var i := 0;
var e' := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e' := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert e == e';
assert |xs| == |p| by {
Model.PermutationsPreserveCardinality(xs, p);
Expand All @@ -81,7 +81,7 @@ module FisherYates.Correctness {
requires multiset(p[i..]) == multiset(xs[i..])
ensures CorrectnessPredicate(xs, p, i)
{
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
if |xs[i..]| <= 1 {
CorrectnessFisherYatesUniqueElementsGeneralLeq1(xs, p, i);
} else {
Expand All @@ -100,16 +100,16 @@ module FisherYates.Correctness {
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert e == Measures.SampleSpace() by {
forall s
ensures s in e
{
calc {
s in e;
Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
Model.Shuffle(xs, i)(s).value[i..] == p[i..];
{ assert Model.Shuffle(xs, i)(s) == Monad.Return(xs)(s); }
Monad.Return(xs)(s).Result? && Monad.Return(xs)(s).value[i..] == p[i..];
Monad.Return(xs)(s).value[i..] == p[i..];
{ assert Monad.Return(xs)(s).value == xs; }
xs[i..] == p[i..];
if |xs[i..]| == 0 then [] == p[i..] else [xs[i]] == p[i..];
Expand Down Expand Up @@ -147,9 +147,12 @@ module FisherYates.Correctness {
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
var e := iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..];
assert |xs| > i + 1;
var h := Uniform.Model.IntervalSample(i, |xs|);
assert hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest) by {
Uniform.Model.IntervalSampleIsMeasurePreserving(i, |xs|);
}
assert HIsIndependent: Independence.IsIndepFunction(h) by {
Uniform.Correctness.IntervalSampleIsIndep(i, |xs|);
Independence.IsIndepImpliesIsIndepFunction(h);
Expand Down Expand Up @@ -238,33 +241,68 @@ module FisherYates.Correctness {
assert DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by {
DecomposeE(xs, ys, p, i, j, h, A, e, e');
}

assert CorrectnessPredicate(xs, p, i) by {
reveal DecomposeE;
reveal HIsIndependent;
reveal BitStreamsInA;
assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
assert e' in Rand.eventSpace by {
assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1);
}
calc {
Rand.prob(e');
{ assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1); }
1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real);
{ assert |xs| == |ys|;
assert |ys|-(i+1) == |xs|-(i+1);
assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1));
assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); }
1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real);
}
reveal InductionHypothesis;
reveal hIsMeasurePreserving;
CorrectnessFisherYatesUniqueElementsGeneralGreater1Helper(xs, ys, p, i, j, h, A, e, e');
}

}

lemma CorrectnessFisherYatesUniqueElementsGeneralGreater1Helper<T(!new)>(xs: seq<T>, ys: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>, e: iset<Rand.Bitstream>, e': iset<Rand.Bitstream>)
decreases |xs| - i
requires i <= |xs|
requires i <= |p|
requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b]
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
requires |xs[i..]| > 1
requires i <= j < |xs| && xs[j] == p[i]
requires |xs| == |ys|
requires ys == Model.Swap(xs, i, j)
requires e == CorrectnessConstructEvent(xs, p, i)
requires e' == CorrectnessConstructEvent(ys, p, i+1)
requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
requires HIsIndependent: Independence.IsIndepFunction(h)
requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j))
requires InductionHypothesis: CorrectnessPredicate(ys, p, i+1)
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
ensures CorrectnessPredicate(xs, p, i)
{
reveal DecomposeE;
reveal HIsIndependent;
reveal BitStreamsInA;
assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
assert e' in Rand.eventSpace by {
assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1);
}
ProbabilityOfE(xs, ys, p, i, j, h, A, e, e');
calc {
Rand.prob(e');
{ assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1); }
1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real);
{ assert |xs| == |ys|;
assert |ys|-(i+1) == |xs|-(i+1);
assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1));
assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); }
1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real);
}
}
reveal hIsMeasurePreserving;
assert e in Rand.eventSpace by {
EInEventSpace(xs, p, h, A, e, e');
}

assert Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by {
ProbabilityOfE(xs, ys, p, i, j, h, A, e, e');
}
}


lemma BitStreamsInA<T(!new)>(xs: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>)
requires i <= |xs|
requires i <= |p|
Expand Down Expand Up @@ -302,8 +340,8 @@ module FisherYates.Correctness {
requires A == iset{j}
requires h == Uniform.Model.IntervalSample(i, |xs|)
requires ys == Model.Swap(xs, i, j)
requires e == iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]
requires e' == iset s | Model.Shuffle(ys, i+1)(s).Result? && Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..]
requires e == iset s | Model.Shuffle(xs, i)(s).value[i..] == p[i..]
requires e' == iset s | Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..]
ensures e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
{
assert forall s :: s in e <==> s in Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by {
Expand All @@ -313,7 +351,6 @@ module FisherYates.Correctness {
if s in e {
var zs := Model.Shuffle(xs, i)(s).value;
assert zs[i..] == p[i..];
assert h(s).Result?;
var k := Uniform.Model.IntervalSample(i, |xs|)(s).value;
Uniform.Model.IntervalSampleBound(i, |xs|, s);
var s' := Uniform.Model.IntervalSample(i, |xs|)(s).rest;
Expand All @@ -332,7 +369,6 @@ module FisherYates.Correctness {
assert k in A;
}
assert s in Monad.BitstreamsWithRestIn(h, e') by {
assert Model.Shuffle(ys, i+1)(s').Result?;
assert Model.Shuffle(ys, i+1)(s').value[i+1..] == p[i+1..];
}
assert s in Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e');
Expand Down Expand Up @@ -491,6 +527,7 @@ module FisherYates.Correctness {
requires |xs| == |ys|
requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
requires HIsIndependent: Independence.IsIndepFunction(h)
requires hIsMeasurePreserving: Measures.IsMeasurePreserving(Rand.eventSpace, Rand.prob, Rand.eventSpace, Rand.prob, s => h(s).rest)
requires InductionHypothesis: Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real)
requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j))
ensures
Expand All @@ -500,7 +537,7 @@ module FisherYates.Correctness {
Rand.prob(e);
{ reveal DecomposeE; }
Rand.prob(Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e'));
{ reveal HIsIndependent; reveal InductionHypothesis; Independence.ResultsIndependent(h, A, e'); }
{ reveal HIsIndependent; reveal InductionHypothesis; reveal hIsMeasurePreserving; Independence.ResultsIndependent(h, A, e'); }
Rand.prob(Monad.BitstreamsWithValueIn(h, A)) * Rand.prob(e');
{ assert Rand.prob(Monad.BitstreamsWithValueIn(h, A)) == Rand.prob(iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j)) by { reveal BitStreamsInA; } }
Rand.prob(iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j)) * Rand.prob(e');
Expand Down
3 changes: 3 additions & 0 deletions src/Util/FisherYates/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ module FisherYates.Implementation {
assert prevASeq == a[..]; // ghost
Swap(a, i, j);
}
} else {
assert prevASeq == a[..]; // ghost
}

}

method Swap<T>(a: array<T>, i: nat, j: nat)
Expand Down
Loading

0 comments on commit 2697703

Please sign in to comment.