From 79e90933ad627333bea6201f58deb52d4916c3ac Mon Sep 17 00:00:00 2001 From: stefan-aws Date: Mon, 4 Mar 2024 15:44:58 +0000 Subject: [PATCH] fix fisher yates proofs --- src/Util/FisherYates/Correctness.dfy | 75 ++++++++++++++++++++-------- src/Util/FisherYates/Model.dfy | 41 +++++++++++---- 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/src/Util/FisherYates/Correctness.dfy b/src/Util/FisherYates/Correctness.dfy index 74fc81fd..809bb8e6 100644 --- a/src/Util/FisherYates/Correctness.dfy +++ b/src/Util/FisherYates/Correctness.dfy @@ -13,6 +13,26 @@ module FisherYates.Correctness { import Independence import RealArith + /************ + Definitions + ************/ + + ghost function CorrectnessConstructEvent(xs: seq, p: seq, i: nat): iset + requires i <= |xs| + requires i <= |p| + { + iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..] + } + + ghost predicate CorrectnessPredicate(xs: seq, p: seq, i: nat) + requires i <= |xs| + requires i <= |p| + { + var e := CorrectnessConstructEvent(xs, p, i); + e in Rand.eventSpace + && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) + } + /******* Lemmas *******/ @@ -59,10 +79,7 @@ module FisherYates.Correctness { requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b] requires |xs| == |p| requires multiset(p[i..]) == multiset(xs[i..]) - ensures - var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; - e in Rand.eventSpace - && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) + ensures CorrectnessPredicate(xs, p, i) { var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; if |xs[i..]| <= 1 { @@ -80,10 +97,7 @@ module FisherYates.Correctness { requires |xs| == |p| requires multiset(p[i..]) == multiset(xs[i..]) requires |xs[i..]| <= 1 - ensures - var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; - e in Rand.eventSpace - && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) + ensures CorrectnessPredicate(xs, p, i) { Model.PermutationsPreserveCardinality(p[i..], xs[i..]); var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; @@ -107,7 +121,7 @@ module FisherYates.Correctness { } } } - assert e in Rand.eventSpace && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by { + assert CorrectnessPredicate(xs, p, i) by { reveal NatArith.FactorialTraditional(); Rand.ProbIsProbabilityMeasure(); assert Measures.IsProbability(Rand.eventSpace, Rand.prob); @@ -130,10 +144,7 @@ module FisherYates.Correctness { requires |xs| == |p| requires multiset(p[i..]) == multiset(xs[i..]) requires |xs[i..]| > 1 - ensures - var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; - e in Rand.eventSpace - && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) + ensures CorrectnessPredicate(xs, p, i) { Model.PermutationsPreserveCardinality(p[i..], xs[i..]); var e := iset s | Model.Shuffle(xs, i)(s).Result? && Model.Shuffle(xs, i)(s).value[i..] == p[i..]; @@ -183,8 +194,8 @@ module FisherYates.Correctness { BitStreamsInA(xs, p, i, j, h, A); } var ys := Model.Swap(xs, i, j); - var e' := iset s | Model.Shuffle(ys, i+1)(s).Result? && Model.Shuffle(ys, i+1)(s).value[i+1..] == p[i+1..]; - assert InductionHypothesis: e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by { + var e' := CorrectnessConstructEvent(ys, p, i+1); + assert InductionHypothesis: CorrectnessPredicate(ys, p, i+1) by { assert multiset(ys[i+1..]) == multiset(p[i+1..]) by { InductionHypothesisPrecondition1(xs, ys, p, i, j); } @@ -227,14 +238,31 @@ module FisherYates.Correctness { assert DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') by { DecomposeE(xs, ys, p, i, j, h, A, e, e'); } - assert e in Rand.eventSpace && Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) by { + assert CorrectnessPredicate(xs, p, i) by { reveal DecomposeE; reveal HIsIndependent; - reveal InductionHypothesis; reveal BitStreamsInA; - ProbabilityOfE(xs, p, i, j, h, A, e, e'); + assert e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) by { + assert e' in Rand.eventSpace by { + assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; } + assert e' == CorrectnessConstructEvent(ys, p, i+1); + } + calc { + Rand.prob(e'); + { assert CorrectnessPredicate(ys, p, i+1) by { reveal InductionHypothesis; } + assert e' == CorrectnessConstructEvent(ys, p, i+1); } + 1.0 / (NatArith.FactorialTraditional(|ys|-(i+1)) as real); + { assert |xs| == |ys|; + assert |ys|-(i+1) == |xs|-(i+1); + assert NatArith.FactorialTraditional(|ys|-(i+1)) == NatArith.FactorialTraditional(|xs|-(i+1)); + assert (NatArith.FactorialTraditional(|ys|-(i+1)) as real) == (NatArith.FactorialTraditional(|xs|-(i+1)) as real); } + 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real); + } + } + ProbabilityOfE(xs, ys, p, i, j, h, A, e, e'); EInEventSpace(xs, p, h, A, e, e'); } + } lemma BitStreamsInA(xs: seq, p: seq, i: nat, j: nat, h: Monad.Hurd, A: iset) @@ -285,6 +313,7 @@ module FisherYates.Correctness { if s in e { var zs := Model.Shuffle(xs, i)(s).value; assert zs[i..] == p[i..]; + assert h(s).Result?; var k := Uniform.Model.IntervalSample(i, |xs|)(s).value; Uniform.Model.IntervalSampleBound(i, |xs|, s); var s' := Uniform.Model.IntervalSample(i, |xs|)(s).rest; @@ -296,6 +325,7 @@ module FisherYates.Correctness { p[i]; zs[i]; zs'[i]; + { assert Model.ShuffleInvariancePredicatePointwise(ys', Model.Shuffle(ys', i+1)(s'), i); } ys'[i]; xs[k]; } @@ -328,7 +358,7 @@ module FisherYates.Correctness { Model.Shuffle(ys, i+1)(s').value[i..]; { SliceOfSequences(Model.Shuffle(ys, i+1)(s').value, i); } [Model.Shuffle(ys, i+1)(s').value[i]] + Model.Shuffle(ys, i+1)(s').value[i+1..]; - { assert Model.Shuffle(ys, i+1)(s').value[i] == ys[i]; } + { assert Model.ShuffleInvariancePredicatePointwise(ys, Model.Shuffle(ys, i+1)(s'), i); assert Model.Shuffle(ys, i+1)(s').value[i] == ys[i]; } [ys[i]] + Model.Shuffle(ys, i+1)(s').value[i+1..]; { assert ys[i] == xs[k]; } [xs[k]] + Model.Shuffle(ys, i+1)(s').value[i+1..]; @@ -448,17 +478,20 @@ module FisherYates.Correctness { reveal DecomposeE; } - lemma ProbabilityOfE(xs: seq, p: seq, i: nat, j: nat, h: Monad.Hurd, A: iset, e: iset, e': iset) + lemma ProbabilityOfE(xs: seq, ys: seq, p: seq, i: nat, j: nat, h: Monad.Hurd, A: iset, e: iset, e': iset) requires i <= |xs| requires i <= |p| requires forall a, b | i <= a < b < |xs| :: xs[a] != xs[b] requires |xs| == |p| requires multiset(p[i..]) == multiset(xs[i..]) requires i <= j < |xs| && xs[j] == p[i] + requires ys == Model.Swap(xs, i, j) requires |xs|-i > 1 + requires e' in Rand.eventSpace + requires |xs| == |ys| requires DecomposeE: e == Monad.BitstreamsWithValueIn(h, A) * Monad.BitstreamsWithRestIn(h, e') requires HIsIndependent: Independence.IsIndepFunction(h) - requires InductionHypothesis: e' in Rand.eventSpace && Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) + requires InductionHypothesis: Rand.prob(e') == 1.0 / (NatArith.FactorialTraditional(|xs|-(i+1)) as real) requires BitStreamsInA: Monad.BitstreamsWithValueIn(h, A) == (iset s | Uniform.Model.IntervalSample(i, |xs|)(s).Equals(j)) ensures Rand.prob(e) == 1.0 / (NatArith.FactorialTraditional(|xs|-i) as real) diff --git a/src/Util/FisherYates/Model.dfy b/src/Util/FisherYates/Model.dfy index 4dda99a1..ff5fbe95 100644 --- a/src/Util/FisherYates/Model.dfy +++ b/src/Util/FisherYates/Model.dfy @@ -12,20 +12,41 @@ module FisherYates.Model { Definitions ************/ + ghost predicate ShuffleInvariancePredicatePointwise(xs: seq, r: Monad.Result>, j: int) + requires 0 <= j < |xs| + { + r.Result? ==> |r.value| == |xs| && r.value[j] == xs[j] + } + ghost function Shuffle(xs: seq, i: nat := 0): (h: Monad.Hurd>) requires i <= |xs| - decreases |xs| - i ensures forall s :: h(s).Result? ==> multiset(h(s).value) == multiset(xs) && |h(s).value| == |xs| - ensures forall s, j | 0 <= j < i :: h(s).Result? ==> h(s).value[j] == xs[j] + ensures forall s, j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, h(s), j) { - (s: Rand.Bitstream) => - if |xs[i..]| > 1 then - var (j, s') :- Uniform.Model.IntervalSample(i, |xs|)(s); - assert i <= j < |xs| by { Uniform.Model.IntervalSampleBound(i, |xs|, s); } - var ys := Swap(xs, i, j); - Shuffle(ys, i + 1)(s') - else - Monad.Return(xs)(s) + (s: Rand.Bitstream) => ShuffleCurried(xs, s, i) + } + + ghost function ShuffleCurried(xs: seq, s: Rand.Bitstream, i: nat := 0): (r: Monad.Result>) + requires i <= |xs| + decreases |xs| - i + ensures r.Result? ==> multiset(r.value) == multiset(xs) && |r.value| == |xs| + ensures forall j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, r, j) + { + if |xs| - i > 1 then + var (j, s') :- Uniform.Model.IntervalSample(i, |xs|)(s); + assert i <= j < |xs| by { Uniform.Model.IntervalSampleBound(i, |xs|, s); } + var ys := Swap(xs, i, j); + var r := ShuffleCurried(ys, s', i + 1); + assert forall j | 0 <= j < i :: ShuffleInvariancePredicatePointwise(xs, r, j) by { + forall j | 0 <= j < i + ensures ShuffleInvariancePredicatePointwise(xs, r, j) + { + assert ShuffleInvariancePredicatePointwise(ys, r, j); + } + } + r + else + Monad.Return(xs)(s) } function Swap(s: seq, i: nat, j: nat): (t: seq)