From 74721a364ca90f5587acc9d0dbdd8d32c8ad7535 Mon Sep 17 00:00:00 2001 From: stefan-aws Date: Mon, 23 Oct 2023 15:03:50 +0100 Subject: [PATCH] divide by zero etc --- .../DiscreteGaussian/Implementation.dfy | 1 + .../DiscreteLaplace/Correctness.dfy | 18 ++++++++++----- .../DiscreteLaplace/Implementation.dfy | 22 +++++++++++++++++-- src/Math/Rationals.dfy | 10 ++++++--- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/Distributions/DiscreteGaussian/Implementation.dfy b/src/Distributions/DiscreteGaussian/Implementation.dfy index cdb9ef7d..c85c0248 100644 --- a/src/Distributions/DiscreteGaussian/Implementation.dfy +++ b/src/Distributions/DiscreteGaussian/Implementation.dfy @@ -18,6 +18,7 @@ module DiscreteGaussian.Implementation { { var sigmaSquared := Rationals.Mul(sigma, sigma); var t := Rationals.Floor(sigma) + 1; + assert t >= 1; while true decreases * { diff --git a/src/Distributions/DiscreteLaplace/Correctness.dfy b/src/Distributions/DiscreteLaplace/Correctness.dfy index b6173821..eefa6005 100644 --- a/src/Distributions/DiscreteLaplace/Correctness.dfy +++ b/src/Distributions/DiscreteLaplace/Correctness.dfy @@ -11,17 +11,25 @@ module DiscreteLaplace.Correctness { import Monad import RealArith - - ghost function Numerator1(scale: Rationals.Rational): real { + ghost function Numerator1(scale: Rationals.Rational): real + requires scale.numer >= 1 + { Exponential.Exp((1.0 / Rationals.ToReal(scale)) - 1.0) } - ghost function Numerator2(scale: Rationals.Rational, x: int): real { + ghost function Numerator2(scale: Rationals.Rational, x: int): real + requires scale.numer >= 1 + { Exponential.Exp(-((RealArith.Abs(x as real)) / Rationals.ToReal(scale))) } - ghost function Denominator(scale: Rationals.Rational): real { - Exponential.Exp((1.0 / Rationals.ToReal(scale)) + 1.0) + ghost function Denominator(scale: Rationals.Rational): (r: real) + requires scale.numer >= 1 + ensures r > 0.0 + { + var x := (1.0 / Rationals.ToReal(scale)) + 1.0; + Exponential.Positive(x); + Exponential.Exp(x) } lemma {:axiom} Correctness(scale: Rationals.Rational, x: int) diff --git a/src/Distributions/DiscreteLaplace/Implementation.dfy b/src/Distributions/DiscreteLaplace/Implementation.dfy index 3d75ac68..8cb33683 100644 --- a/src/Distributions/DiscreteLaplace/Implementation.dfy +++ b/src/Distributions/DiscreteLaplace/Implementation.dfy @@ -8,6 +8,9 @@ module DiscreteLaplace.Implementation { import Interface import Monad import Model + import Uniform + import BernoulliExpNeg + import Bernoulli trait {:termination false} Trait extends Interface.Trait { @@ -16,21 +19,33 @@ module DiscreteLaplace.Implementation { modifies this requires scale.numer >= 1 decreases * - ensures Monad.Result(z, s) == Model.Sample(scale)(old(s)) + ensures Monad.Result(z, s) == Model.SampleTailRecursive(scale)(old(s)) { var b := true; var y := 0; + ghost var b' := b; + ghost var y' := y; while b && y == 0 + invariant Model.SampleTailRecursive(scale, b', y')(old(s)) == Model.SampleTailRecursive(scale, b, y)(s) decreases * { + b' := b; + y' := y; + label L1: var u := UniformSample(scale.numer); + assert (u, s) == Monad.Extract(Uniform.Model.Sample(scale.numer)(old@L1(s))); + label L2: var d := BernoulliExpNegSample(Rationals.Rational(u, scale.numer)); + assert (d, s) == Monad.Extract(BernoulliExpNeg.Model.Sample(Rationals.Rational(u, scale.numer))(old@L2(s))); + assert Model.SampleTailRecursive(scale, b', y')(old@L1(s)) == Model.SampleTailRecursive(scale, b, y)(s); if !d { continue; } + label L3: var v := 0; var a := true; while a + invariant Model.SampleTailRecursiveHelper(scale)(old@L3(s)) == Model.SampleTailRecursiveHelper(scale, v, a)(s) decreases * { a := BernoulliExpNegSample(Rationals.Int(1)); @@ -38,12 +53,15 @@ module DiscreteLaplace.Implementation { v := v + 1; } } + assert (v, s) == Monad.Extract(Model.SampleTailRecursiveHelper(scale)(old@L3(s))); var x := u + scale.numer * v; y := x / scale.denom; + label L4: b := BernoulliSample(Rationals.Rational(1, 2)); + assert (b, s) == Monad.Extract(Bernoulli.Model.Sample(1, 2)(old@L4(s))); + assume {:axiom} false; // add equivalence proof later } z := if b then -y else y; - assume {:axiom} false; // add equivalence proof later } } diff --git a/src/Math/Rationals.dfy b/src/Math/Rationals.dfy index a1ef2669..f1c65590 100644 --- a/src/Math/Rationals.dfy +++ b/src/Math/Rationals.dfy @@ -8,7 +8,7 @@ module Rationals { Definitions ************/ - type PosInt = n: int | n >=1 witness 1 + type PosInt = n: int | n >= 1 witness 1 datatype Rational = Rational(numer: int, denom: PosInt) @@ -20,7 +20,9 @@ module Rationals { lhs.numer * rhs.denom <= rhs.numer * lhs.denom } - function Int(n: int): Rational { + function Int(n: int): (r: Rational) + ensures r.numer == n + { Rational(n, 1) } @@ -53,7 +55,9 @@ module Rationals { Mul(lhs, Inv(rhs)) } - function Floor(r: Rational): int { + function Floor(r: Rational): (f: int) + ensures r.numer >= 1 ==> f >= 0 + { r.numer / r.denom }