Skip to content

Commit

Permalink
fix fisher yates proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-aws committed Mar 4, 2024
1 parent 46d6b86 commit 79e9093
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 31 deletions.
75 changes: 54 additions & 21 deletions src/Util/FisherYates/Correctness.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@ module FisherYates.Correctness {
import Independence
import RealArith

/************
Definitions
************/

ghost function CorrectnessConstructEvent<T(!new)>(xs: seq<T>, p: seq<T>, i: nat): iset<Rand.Bitstream>
requires i <= |xs|
requires i <= |p|
{
iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]
}

ghost predicate CorrectnessPredicate<T(!new)>(xs: seq<T>, p: seq<T>, i: nat)
requires i <= |xs|
requires i <= |p|
{
var e := CorrectnessConstructEvent(xs, p, i);
e in Rand.eventSpace
&& Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real)
}

/*******
Lemmas
*******/
Expand Down Expand Up @@ -59,10 +79,7 @@ module FisherYates.Correctness {
requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b]
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
ensures
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
e in Rand.eventSpace
&& Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real)
ensures CorrectnessPredicate(xs, p, i)
{
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
if |xs[i..]| <= 1 {
Expand All @@ -80,10 +97,7 @@ module FisherYates.Correctness {
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
requires |xs[i..]| <= 1
ensures
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
e in Rand.eventSpace
&& Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real)
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
Expand All @@ -107,7 +121,7 @@ module FisherYates.Correctness {
}
}
}
assert e in Rand.eventSpace && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by {
assert CorrectnessPredicate(xs, p, i) by {
reveal NatArith.FactorialTraditional();
Rand.ProbIsProbabilityMeasure();
assert Measures.IsProbability(Rand.eventSpace, Rand.prob);
Expand All @@ -130,10 +144,7 @@ module FisherYates.Correctness {
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
requires |xs[i..]| > 1
ensures
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
e in Rand.eventSpace
&& Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real)
ensures CorrectnessPredicate(xs, p, i)
{
Model.PermutationsPreserveCardinality(p[i..], xs[i..]);
var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..];
Expand Down Expand Up @@ -183,8 +194,8 @@ module FisherYates.Correctness {
BitStreamsInA(xs, p, i, j, h, A);
}
var ys := Model.Swap(xs, i, j);
var e' := iset s | Model.Shuffle(ys, i+1)(s).Result? && Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..];
assert InductionHypothesis: e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
var e' := CorrectnessConstructEvent(ys, p, i+1);
assert InductionHypothesis: CorrectnessPredicate(ys, p, i+1) by {
assert multiset(ys[i+1..]) == multiset(p[i+1..]) by {
InductionHypothesisPrecondition1(xs, ys, p, i, j);
}
Expand Down Expand Up @@ -227,14 +238,31 @@ module FisherYates.Correctness {
assert DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by {
DecomposeE(xs, ys, p, i, j, h, A, e, e');
}
assert e in Rand.eventSpace && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by {
assert CorrectnessPredicate(xs, p, i) by {
reveal DecomposeE;
reveal HIsIndependent;
reveal InductionHypothesis;
reveal BitStreamsInA;
ProbabilityOfE(xs, p, i, j, h, A, e, e');
assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by {
assert e' in Rand.eventSpace by {
assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1);
}
calc {
Rand.prob(e');
{ assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; }
assert e' == CorrectnessConstructEvent(ys, p, i+1); }
1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real);
{ assert |xs| == |ys|;
assert |ys|-(i+1) == |xs|-(i+1);
assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1));
assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); }
1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real);
}
}
ProbabilityOfE(xs, ys, p, i, j, h, A, e, e');
EInEventSpace(xs, p, h, A, e, e');
}

}

lemma BitStreamsInA<T(!new)>(xs: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>)
Expand Down Expand Up @@ -285,6 +313,7 @@ module FisherYates.Correctness {
if s in e {
var zs := Model.Shuffle(xs, i)(s).value;
assert zs[i..] == p[i..];
assert h(s).Result?;
var k := Uniform.Model.IntervalSample(i, |xs|)(s).value;
Uniform.Model.IntervalSampleBound(i, |xs|, s);
var s' := Uniform.Model.IntervalSample(i, |xs|)(s).rest;
Expand All @@ -296,6 +325,7 @@ module FisherYates.Correctness {
p[i];
zs[i];
zs'[i];
{ assert Model.ShuffleInvariancePredicatePointwise(ys', Model.Shuffle(ys', i+1)(s'), i); }
ys'[i];
xs[k];
}
Expand Down Expand Up @@ -328,7 +358,7 @@ module FisherYates.Correctness {
Model.Shuffle(ys, i+1)(s').value[i..];
{ SliceOfSequences(Model.Shuffle(ys, i+1)(s').value, i); }
[Model.Shuffle(ys, i+1)(s').value[i]] + Model.Shuffle(ys, i+1)(s').value[i+1..];
{ assert Model.Shuffle(ys, i+1)(s').value[i] == ys[i]; }
{ assert Model.ShuffleInvariancePredicatePointwise(ys, Model.Shuffle(ys, i+1)(s'), i); assert Model.Shuffle(ys, i+1)(s').value[i] == ys[i]; }
[ys[i]] + Model.Shuffle(ys, i+1)(s').value[i+1..];
{ assert ys[i] == xs[k]; }
[xs[k]] + Model.Shuffle(ys, i+1)(s').value[i+1..];
Expand Down Expand Up @@ -448,17 +478,20 @@ module FisherYates.Correctness {
reveal DecomposeE;
}

lemma ProbabilityOfE<T(!new)>(xs: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>, e: iset<Rand.Bitstream>, e': iset<Rand.Bitstream>)
lemma ProbabilityOfE<T(!new)>(xs: seq<T>, ys: seq<T>, p: seq<T>, i: nat, j: nat, h: Monad.Hurd<int>, A: iset<int>, e: iset<Rand.Bitstream>, e': iset<Rand.Bitstream>)
requires i <= |xs|
requires i <= |p|
requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b]
requires |xs| == |p|
requires multiset(p[i..]) == multiset(xs[i..])
requires i <= j < |xs| && xs[j] == p[i]
requires ys == Model.Swap(xs, i, j)
requires |xs|-i > 1
requires e' in Rand.eventSpace
requires |xs| == |ys|
requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e')
requires HIsIndependent: Independence.IsIndepFunction(h)
requires InductionHypothesis: e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real)
requires InductionHypothesis: Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real)
requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j))
ensures
Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real)
Expand Down
41 changes: 31 additions & 10 deletions src/Util/FisherYates/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,41 @@ module FisherYates.Model {
Definitions
************/

ghost predicate ShuffleInvariancePredicatePointwise<T>(xs: seq<T>, r: Monad.Result<seq<T>>, j: int)
requires 0 <= j < |xs|
{
r.Result? ==> |r.value| == |xs| && r.value[j] == xs[j]
}

ghost function Shuffle<T>(xs: seq<T>, i: nat := 0): (h: Monad.Hurd<seq<T>>)
requires i <= |xs|
decreases |xs| - i
ensures forall s :: h(s).Result? ==> multiset(h(s).value) == multiset(xs) && |h(s).value| == |xs|
ensures forall s, j | 0 <= j < i :: h(s).Result? ==> h(s).value[j] == xs[j]
ensures forall s, j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, h(s), j)
{
(s: Rand.Bitstream) =>
if |xs[i..]| > 1 then
var (j, s') :- Uniform.Model.IntervalSample(i, |xs|)(s);
assert i <= j < |xs| by { Uniform.Model.IntervalSampleBound(i, |xs|, s); }
var ys := Swap(xs, i, j);
Shuffle(ys, i + 1)(s')
else
Monad.Return(xs)(s)
(s: Rand.Bitstream) => ShuffleCurried(xs, s, i)
}

ghost function ShuffleCurried<T>(xs: seq<T>, s: Rand.Bitstream, i: nat := 0): (r: Monad.Result<seq<T>>)
requires i <= |xs|
decreases |xs| - i
ensures r.Result? ==> multiset(r.value) == multiset(xs) && |r.value| == |xs|
ensures forall j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, r, j)
{
if |xs| - i > 1 then
var (j, s') :- Uniform.Model.IntervalSample(i, |xs|)(s);
assert i <= j < |xs| by { Uniform.Model.IntervalSampleBound(i, |xs|, s); }
var ys := Swap(xs, i, j);
var r := ShuffleCurried(ys, s', i + 1);
assert forall j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, r, j) by {
forall j | 0 <= j < i
ensures ShuffleInvariancePredicatePointwise(xs, r, j)
{
assert ShuffleInvariancePredicatePointwise(ys, r, j);
}
}
r
else
Monad.Return(xs)(s)
}

function Swap<T>(s: seq<T>, i: nat, j: nat): (t: seq<T>)
Expand Down

0 comments on commit 79e9093

Please sign in to comment.