Skip to content

Commit

Permalink
Fisher-Yates: Equivalence proof (#142)
Browse files Browse the repository at this point in the history
By submitting this pull request, I confirm that my contribution is made
under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).
  • Loading branch information
stefan-aws authored Feb 5, 2024
1 parent 1f590ff commit 9b15eba
Show file tree
Hide file tree
Showing 19 changed files with 72 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/Distributions/Bernoulli/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Bernoulli.Implementation {
trait {:termination false} Trait extends Interface.Trait {

method BernoulliSample(p: Rationals.Rational) returns (c: bool)
modifies this
modifies `s
decreases *
requires 0 <= p.numer <= p.denom
ensures Model.Sample(p.numer, p.denom)(old(s)) == Monad.Result(c, s)
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/Bernoulli/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Bernoulli.Interface {
trait {:termination false} Trait extends Uniform.Interface.Trait {

method BernoulliSample(p: Rationals.Rational) returns (c: bool)
modifies this
modifies `s
decreases *
requires 0 <= p.numer <= p.denom
ensures Model.Sample(p.numer, p.denom)(old(s)) == Monad.Result(c, s)
Expand Down
4 changes: 2 additions & 2 deletions src/Distributions/BernoulliExpNeg/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ module BernoulliExpNeg.Implementation {

// Based on Algorithm 1 in https://arxiv.org/pdf/2004.00010.pdf; unverified
method BernoulliExpNegSample(gamma: Rationals.Rational) returns (c: bool)
modifies this
modifies `s
requires gamma.numer >= 0
decreases *
ensures Monad.Result(c, s) == Model.Sample(gamma)(old(s))
Expand All @@ -42,7 +42,7 @@ module BernoulliExpNeg.Implementation {
}

method BernoulliExpNegSampleCaseLe1(gamma: Rationals.Rational) returns (c: bool)
modifies this
modifies `s
requires 0 <= gamma.numer <= gamma.denom
decreases *
ensures Monad.Result(c, s) == Model.SampleLe1(gamma)(old(s))
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/BernoulliExpNeg/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module BernoulliExpNeg.Interface {
trait {:termination false} Trait extends Bernoulli.Interface.Trait {

method BernoulliExpNegSample(gamma: Rationals.Rational) returns (c: bool)
modifies this
modifies `s
decreases *
requires gamma.numer >= 0
ensures Monad.Result(c, s) == Model.Sample(gamma)(old(s))
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/Coin/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Coin.Implementation {
trait {:termination false} Trait extends Interface.Trait {

method CoinSample() returns (b: bool)
modifies this
modifies `s
ensures Model.Sample(old(s)) == Monad.Result(b, s)
{
var x := UniformPowerOfTwoSample(2);
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/Coin/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Coin.Interface {
trait {:termination false} Trait extends UniformPowerOfTwo.Interface.Trait {

method CoinSample() returns (b: bool)
modifies this
modifies `s
ensures Model.Sample(old(s)) == Monad.Result(b, s)

}
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/DiscreteGaussian/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module DiscreteGaussian.Implementation {
// Based on Algorithm 3 in https://arxiv.org/pdf/2004.00010.pdf; unverified
// Note that we take sigma as a parameter, not sigma^2, to avoid square roots.
method DiscreteGaussianSample(sigma: Rationals.Rational) returns (y: int)
modifies this
modifies `s
requires sigma.numer >= 1
decreases *
{
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/DiscreteGaussian/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module DiscreteGaussian.Interface {
trait {:termination false} Trait extends DiscreteLaplace.Interface.Trait, BernoulliExpNeg.Interface.Trait {
// Takes the sigma (not sigma^2!) as a fraction
method DiscreteGaussianSample(sigma: Rationals.Rational) returns (y: int)
modifies this
modifies `s
requires sigma.numer >= 1
decreases *
}
Expand Down
6 changes: 3 additions & 3 deletions src/Distributions/DiscreteLaplace/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ module DiscreteLaplace.Implementation {

// Based on Algorithm 2 in https://arxiv.org/pdf/2004.00010.pdf; unverified
method DiscreteLaplaceSample(scale: Rationals.Rational) returns (z: int)
modifies this
modifies `s
requires scale.numer >= 1
decreases *
ensures Model.Sample(scale)(old(s)) == Monad.Result(z, s)
Expand All @@ -30,7 +30,7 @@ module DiscreteLaplace.Implementation {
}

method {:rlimit 100000} DiscreteLaplaceSampleLoop(scale: Rationals.Rational) returns (bY: (bool, int))
modifies this
modifies `s
requires scale.numer >= 1
decreases *
ensures Model.SampleLoop(scale)(old(s)) == Monad.Result(bY, s)
Expand Down Expand Up @@ -61,7 +61,7 @@ module DiscreteLaplace.Implementation {
}

method DisceteLaplaceSampleInnerLoop() returns (v: int)
modifies this
modifies `s
decreases *
ensures Model.SampleInnerLoopFull()(old(s)) == Monad.Result(v, s)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/DiscreteLaplace/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ module DiscreteLaplace.Interface {

// Based on Algorithm 2 in https://arxiv.org/pdf/2004.00010.pdf; unverified
method DiscreteLaplaceSample(scale: Rationals.Rational) returns (z: int)
modifies this
modifies `s
requires scale.numer >= 1
decreases *

Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/Uniform/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Uniform.Implementation {

trait {:termination false} Trait extends Interface.Trait {
method UniformSample(n: nat) returns (u: nat)
modifies this
modifies `s
decreases *
requires n > 0
ensures u < n
Expand Down
4 changes: 2 additions & 2 deletions src/Distributions/Uniform/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ module Uniform.Interface {
trait {:termination false} Trait extends UniformPowerOfTwo.Interface.Trait {

method UniformSample(n: nat) returns (u: nat)
modifies this
modifies `s
decreases *
requires n > 0
ensures u < n
ensures Model.Sample(n)(old(s)) == Monad.Result(u, s)

method UniformIntervalSample(a: int, b: int) returns (u: int)
modifies this
modifies `s
decreases *
requires a < b
ensures a <= u < b
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/UniformPowerOfTwo/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module UniformPowerOfTwo.Implementation {
trait {:termination false} Trait extends Interface.Trait {
method UniformPowerOfTwoSample(n: nat) returns (u: nat)
requires n >= 1
modifies this
modifies `s
ensures Model.Sample(n)(old(s)) == Monad.Result(u, s)

}
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions/UniformPowerOfTwo/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module UniformPowerOfTwo.Interface {
// The return value u is uniformly distributed between 0 <= u < 2^k where 2^k <= n < 2^(k + 1).
method UniformPowerOfTwoSample(n: nat) returns (u: nat)
requires n >= 1
modifies this
modifies `s
ensures Model.Sample(n)(old(s)) == Monad.Result(u, s)

}
Expand Down
19 changes: 19 additions & 0 deletions src/Util/FisherYates/Equivalence.dfy
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*******************************************************************************
* Copyright by the contributors to the Dafny Project
* SPDX-License-Identifier: MIT
*******************************************************************************/

module FisherYates.Equivalence {
import Model
import Rand

ghost predicate LoopInvariant<T>(prevI: nat, i: nat, a: array<T>, prevASeq: seq<T>, oldASeq: seq<T>, oldS: Rand.Bitstream, prevS: Rand.Bitstream, s: Rand.Bitstream)
reads a
{
&& prevI <= |prevASeq|
&& i <= a.Length - 1
&& Model.Shuffle(oldASeq)(oldS) == Model.Shuffle(prevASeq, prevI)(prevS)
&& Model.Shuffle(prevASeq, prevI)(prevS) == Model.Shuffle(a[..], i)(s)
}

}
27 changes: 24 additions & 3 deletions src/Util/FisherYates/Implementation.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,41 @@

module FisherYates.Implementation {
import Interface
import Monad
import Model
import Uniform
import Equivalence

trait {:termination false} Trait extends Interface.Trait {

method Shuffle<T>(a: array<T>)
decreases *
modifies this, a
modifies `s, a
ensures Model.Shuffle(old(a[..]))(old(s)) == Monad.Result(a[..], s)
{
ghost var prevI, prevASeq, prevS := 0, a[..], s; // ghost
if a.Length > 1 {
for i := 0 to a.Length - 1 {
for i := 0 to a.Length - 1
invariant Equivalence.LoopInvariant(prevI, i, a, prevASeq, old(a[..]), old(s), prevS, s)
{
prevI, prevASeq, prevS := i, a[..], s; // ghost
var j := UniformIntervalSample(i, a.Length);
a[i], a[j] := a[j], a[i];
assert prevASeq == a[..]; // ghost
Swap(a, i, j);
}
}
}

method Swap<T>(a: array<T>, i: nat, j: nat)
modifies a
requires i <= j
requires 0 <= i < a.Length
requires 0 <= j < a.Length
ensures Model.Swap(old(a[..]), i, j) == a[..]
ensures old(s) == s
{
a[i], a[j] := a[j], a[i];
}

}
}
5 changes: 4 additions & 1 deletion src/Util/FisherYates/Interface.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@

module FisherYates.Interface {
import Uniform
import Model
import Monad

trait {:termination false} Trait extends Uniform.Interface.Trait {

method Shuffle<T>(a: array<T>)
decreases *
modifies this, a
modifies `s, a
ensures Model.Shuffle(old(a[..]))(old(s)) == Monad.Result(a[..], s)

}
}
4 changes: 2 additions & 2 deletions src/Util/FisherYates/Model.dfy
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ module FisherYates.Model {
(s: Rand.Bitstream) =>
if |xs[i..]| > 1 then
var (j, s) :- Uniform.Model.IntervalSample(i, |xs|)(s);
var xs := Swap(xs, i, j);
Shuffle(xs, i + 1)(s)
var ys := Swap(xs, i, j);
Shuffle(ys, i + 1)(s)
else
Monad.Return(xs)(s)
}
Expand Down
5 changes: 5 additions & 0 deletions src/interop/java/Full/Random.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,9 @@ public void Shuffle(long[] arr) {
public void Shuffle(short[] arr) {
FisherYates.Implementation._Companion_Trait.Shuffle(TypeDescriptor.SHORT, this, arr);
}

public <__T> void Swap(dafny.TypeDescriptor<__T> _td___T, java.lang.Object a, java.math.BigInteger i, java.math.BigInteger j) {
FisherYates.Implementation._Companion_Trait.Swap(_td___T, this, a, i, j);
}

}

0 comments on commit 9b15eba

Please sign in to comment.