-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #27 from zksecurity/circuit-prototype
Circuit monad
- Loading branch information
Showing
8 changed files
with
1,289 additions
and
1 deletion.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import Mathlib.Algebra.Field.Basic | ||
import Mathlib.Data.ZMod.Basic | ||
import Clean.Utils.Primes | ||
import Clean.Utils.Vector | ||
|
||
variable {F: Type} | ||
|
||
structure Variable (F : Type) where | ||
index: ℕ | ||
witness: Unit → F | ||
|
||
instance : Repr (Variable F) where | ||
reprPrec v _ := "x" ++ repr v.index | ||
|
||
inductive Expression (F : Type) where | ||
| var : Variable F -> Expression F | ||
| const : F -> Expression F | ||
| add : Expression F -> Expression F -> Expression F | ||
| mul : Expression F -> Expression F -> Expression F | ||
|
||
namespace Expression | ||
variable [Field F] | ||
|
||
@[simp] | ||
def eval : Expression F → F | ||
| var v => v.witness () | ||
| const c => c | ||
| add x y => eval x + eval y | ||
| mul x y => eval x * eval y | ||
|
||
/-- | ||
Evaluate expression given an external `environment` that determines the assignment | ||
of all variables. | ||
This is needed when we want to make statements about a circuit in the adversarial | ||
situation where the prover can assign anything to variables. | ||
-/ | ||
@[simp] | ||
def eval_env (env: ℕ → F) : Expression F → F | ||
| var v => env v.index | ||
| const c => c | ||
| add x y => eval_env env x + eval_env env y | ||
| mul x y => eval_env env x * eval_env env y | ||
|
||
def toString [Repr F] : Expression F → String | ||
| var v => "x" ++ reprStr v.index | ||
| const c => reprStr c | ||
| add x y => "(" ++ toString x ++ " + " ++ toString y ++ ")" | ||
| mul x y => "(" ++ toString x ++ " * " ++ toString y ++ ")" | ||
|
||
instance [Repr F] : Repr (Expression F) where | ||
reprPrec e _ := toString e | ||
|
||
-- combine expressions elegantly | ||
instance : Zero (Expression F) where | ||
zero := const 0 | ||
|
||
instance : One (Expression F) where | ||
one := const 1 | ||
|
||
instance : Add (Expression F) where | ||
add := add | ||
|
||
instance : Neg (Expression F) where | ||
neg e := mul (const (-1)) e | ||
|
||
instance : Sub (Expression F) where | ||
sub e₁ e₂ := add e₁ (-e₂) | ||
|
||
instance : Mul (Expression F) where | ||
mul := mul | ||
|
||
instance : Coe F (Expression F) where | ||
coe f := const f | ||
|
||
instance : Coe (Variable F) (Expression F) where | ||
coe x := var x | ||
|
||
instance : Coe (Expression F) F where | ||
coe x := x.eval | ||
|
||
instance : HMul F (Expression F) (Expression F) where | ||
hMul := fun f e => mul f e | ||
end Expression |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import Mathlib.Algebra.Field.Basic | ||
import Mathlib.Data.ZMod.Basic | ||
import Clean.Utils.Primes | ||
import Clean.Utils.Vector | ||
import Clean.Circuit.Expression | ||
|
||
variable {F: Type} [Field F] | ||
|
||
structure TypePair where | ||
var: Type | ||
value: Type | ||
|
||
-- class of types that are composed of variables, | ||
-- and can be evaluated into something that is composed of field elements | ||
class ProvableType (F: Type) (α: TypePair) where | ||
size : ℕ | ||
to_vars : α.var → Vector (Expression F) size | ||
from_vars : Vector (Expression F) size → α.var | ||
to_values : α.value → Vector F size | ||
from_values : Vector F size → α.value | ||
|
||
-- or is it better as a structure? | ||
structure ProvableType' (F : Type) where | ||
var: Type | ||
value: Type | ||
size : ℕ | ||
to_vars : var → Vector (Expression F) size | ||
from_vars : Vector (Expression F) size → var | ||
to_values : value → Vector F size | ||
from_values : Vector F size → value | ||
|
||
-- or like this? | ||
def Provable' (F: Type) := { α : TypePair // ∃ p : Type, p = ProvableType F α } | ||
|
||
namespace Provable | ||
variable {α β γ: TypePair} [ProvableType F α] [ProvableType F β] [ProvableType F γ] | ||
|
||
@[simp] | ||
def eval (F: Type) [Field F] [ProvableType F α] (x: α.var) : α.value := | ||
let n := ProvableType.size F α | ||
let vars : Vector (Expression F) n := ProvableType.to_vars x | ||
let values := vars.map (fun v => v.eval) | ||
ProvableType.from_values values | ||
|
||
@[simp] | ||
def eval_env (env: ℕ → F) (x: α.var) : α.value := | ||
let n := ProvableType.size F α | ||
let vars : Vector (Expression F) n := ProvableType.to_vars x | ||
let values := vars.map (fun v => v.eval_env env) | ||
ProvableType.from_values values | ||
|
||
def const (F: Type) [ProvableType F α] (x: α.value) : α.var := | ||
let n := ProvableType.size F α | ||
let values : Vector F n := ProvableType.to_values x | ||
ProvableType.from_vars (values.map (fun v => Expression.const v)) | ||
|
||
@[reducible] | ||
def field (F : Type) : TypePair := ⟨ Expression F, F ⟩ | ||
|
||
instance : ProvableType F (field F) where | ||
size := 1 | ||
to_vars x := vec [x] | ||
from_vars v := v.get ⟨ 0, by norm_num ⟩ | ||
to_values x := vec [x] | ||
from_values v := v.get ⟨ 0, by norm_num ⟩ | ||
|
||
@[reducible] | ||
def pair (α β : TypePair) : TypePair := ⟨ α.var × β.var, α.value × β.value ⟩ | ||
|
||
@[reducible] | ||
def field2 (F : Type) : TypePair := pair (field F) (field F) | ||
|
||
instance : ProvableType F (field2 F) where | ||
size := 2 | ||
to_vars pair := vec [pair.1, pair.2] | ||
from_vars v := (v.get ⟨ 0, by norm_num ⟩, v.get ⟨ 1, by norm_num ⟩) | ||
to_values pair :=vec [pair.1, pair.2] | ||
from_values v := (v.get ⟨ 0, by norm_num ⟩, v.get ⟨ 1, by norm_num ⟩) | ||
|
||
variable {n: ℕ} | ||
def vec (α: TypePair) (n: ℕ) : TypePair := ⟨ Vector α.var n, Vector α.value n ⟩ | ||
|
||
@[reducible] | ||
def fields (F: Type) (n: ℕ) : TypePair := vec (field F) n | ||
|
||
instance : ProvableType F (fields F n) where | ||
size := n | ||
to_vars x := x | ||
from_vars v := v | ||
to_values x := x | ||
from_values v := v | ||
end Provable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import Clean.GenericConstraint | ||
import Clean.Expression | ||
import Clean.Gadgets.Boolean | ||
import Clean.Gadgets.ByteLookup | ||
import Clean.Utils.Field | ||
import Mathlib.Algebra.Field.Basic | ||
import Mathlib.Data.ZMod.Basic | ||
|
||
/- | ||
8-bit addition constraint gadget: the output `out` must be the sum of the | ||
inputs `x` and `y` modulo 256, and the carry `carry` must be the quotient | ||
of the sum of `x` and `y` divided by 256. | ||
-/ | ||
namespace Addition8 | ||
open Expression | ||
variable {p : ℕ} [p_is_prime: Fact p.Prime] [p_large_enough: Fact (p > 512)] | ||
instance : CommRing (F p) := ZMod.commRing p | ||
|
||
-- could use nats as expressions but seems to slightly complicate proofs | ||
instance {N: ℕ+} {M n : ℕ} : OfNat (Expression N M (F p)) n where | ||
ofNat := const (OfNat.ofNat n : (Fin p)) | ||
|
||
variable (N : ℕ+) (M : ℕ) | ||
|
||
def assumptions (x y z : Expression N M (F p)) := [ | ||
ByteLookup.lookup N M x, | ||
ByteLookup.lookup N M y, | ||
ByteLookup.lookup N M z, | ||
] | ||
|
||
def circuit (N : ℕ+) (M : ℕ) (x y out carry : Expression N M (F p)) : ConstraintGadget p N M := | ||
⟨ | ||
[ | ||
x + y - out - carry * (const 256) | ||
], | ||
assumptions N M x y out, | ||
[ | ||
Boolean.circuit N M carry | ||
] | ||
⟩ | ||
|
||
|
||
def spec (N : ℕ+) (M : ℕ) (x y z: Expression N M (F p)) : TraceOfLength N M (F p) -> Prop := | ||
fun trace => | ||
have x := trace.eval x | ||
have y := trace.eval y | ||
have z := trace.eval z | ||
z.val = (x.val + y.val) % 256 | ||
|
||
theorem equiv (N : ℕ+) (M : ℕ) (x y out: Expression N M (F p)) : | ||
(∀ X, | ||
(forallList (assumptions N M x y out) (fun lookup => lookup.prop X)) | ||
-> ( | ||
(∃ carry, constraints_hold (circuit N M x y out carry) X) | ||
↔ | ||
spec N M x y out X | ||
) | ||
) := by | ||
|
||
intro X | ||
simp [constraints_hold, forallList, ByteLookup.lookup] | ||
simp [TraceOfLength.eval, spec] | ||
intro hx_byte | ||
intro hy_byte | ||
intro hout_byte | ||
set x := X.eval x | ||
set y := X.eval y | ||
set out := X.eval out | ||
|
||
-- preliminaries | ||
have no_wrap_xy : (x + y).val = x.val + y.val := by | ||
rw [ZMod.val_add_of_lt] | ||
linarith [hx_byte, hy_byte, p_large_enough.elim] | ||
|
||
have val_self : (256 : ZMod p).val = 256 := ZMod.val_natCast_of_lt (by linarith [p_large_enough.elim]) | ||
|
||
have no_wrap_out : (out + 256).val = out.val + 256 := by | ||
rw [ZMod.val_add_of_lt, val_self] | ||
linarith [hout_byte, p_large_enough.elim] | ||
|
||
constructor | ||
-- soundness | ||
· rintro ⟨ carry, h ⟩ | ||
set carry := X.eval carry | ||
rcases (And.right h) with zero_carry | one_carry | ||
-- carry = 0 | ||
· rw [zero_carry] at h | ||
simp [←sub_eq_add_neg] at h | ||
rw [←Nat.mod_eq_of_lt hout_byte] | ||
have : out = x + y := calc | ||
_ = 0 + out := by ring | ||
_ = x + y - out + out := by rw [h] | ||
_ = x + y := by ring | ||
rw [this, no_wrap_xy] | ||
-- carry = 1 | ||
· have one_carry': carry = 1 := calc | ||
_ = carry - 0 := by ring | ||
_ = carry - (carry + -1) := by rw [one_carry] | ||
_ = 1 := by ring | ||
rw [one_carry'] at h | ||
simp [ZMod.val_add] at h | ||
rw [← Nat.mod_eq_of_lt hout_byte] | ||
rw [← no_wrap_xy] | ||
have : x + y = out + 256 := calc | ||
_ = x + y + -out + -256 + out + 256 := by ring | ||
_ = 0 + out + 256 := by rw [h] | ||
_ = _ := by ring | ||
rw [this, no_wrap_out] | ||
rw [Nat.add_mod, Nat.mod_self, Nat.add_zero, Nat.mod_mod] | ||
|
||
-- completeness | ||
· intro h | ||
have carry? := Nat.lt_or_ge (x.val + y.val) 256 | ||
rcases carry? with sum_lt_256 | sum_ge_256 | ||
|
||
-- first case: x + y <= 256, carry = 0 | ||
· use 0 | ||
simp [TraceOfLength.eval] | ||
rw [(Nat.mod_eq_iff_lt (by linarith)).mpr sum_lt_256, ← no_wrap_xy] at h | ||
rw [←sub_eq_add_neg, sub_eq_zero] | ||
apply_fun ZMod.val | ||
· symm; exact h | ||
· apply ZMod.val_injective | ||
|
||
-- second case: x + y > 256, carry = 1 | ||
· use 1 | ||
simp [TraceOfLength.eval] | ||
have one_lt : 1 < p := by linarith [p_large_enough.elim] | ||
rw [Nat.mod_eq_of_lt one_lt] | ||
have one_val : ((1 : ℕ) : ZMod p).val = 1 := ZMod.val_natCast_of_lt one_lt | ||
simp [one_val] | ||
suffices g : x + y = out + 256 from calc x + y + -out + -256 | ||
_ = out + 256 + -out + -256 := by rw [g] | ||
_ = 0 := by ring | ||
|
||
have sum_le_512 := Nat.add_lt_add hx_byte hy_byte | ||
simp at sum_le_512 | ||
have div_one : (x.val + y.val) / 256 = 1 := by | ||
apply Nat.div_eq_of_lt_le | ||
· simp; exact sum_ge_256 | ||
· simp; exact sum_le_512 | ||
have modulo_definition_div := Nat.mod_add_div (x.val + y.val) 256 | ||
rw [← h, div_one] at modulo_definition_div | ||
simp at modulo_definition_div | ||
symm | ||
apply_fun ZMod.val | ||
· rw [no_wrap_xy, no_wrap_out] | ||
exact modulo_definition_div | ||
· apply ZMod.val_injective | ||
|
||
end Addition8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import Mathlib.Data.ZMod.Basic | ||
|
||
theorem prime_1009 : Nat.Prime 1009 := by | ||
-- isn't there a more efficient way to prove primalitity? | ||
set_option maxRecDepth 900 in decide |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import Mathlib.Data.Fintype.Basic | ||
|
||
variable {α β : Type} {n : ℕ} | ||
|
||
def Vector (α : Type) (n: ℕ) := { l: List α // l.length = n } | ||
|
||
@[reducible] | ||
def vec (l: List α) : Vector α l.length := ⟨ l, rfl ⟩ | ||
|
||
namespace Vector | ||
theorem length_matches (v: Vector α n) : v.1.length = n := v.2 | ||
|
||
@[simp] | ||
def map (f: α → β) : Vector α n → Vector β n | ||
| ⟨ l, h ⟩ => ⟨ l.map f, by rw [List.length_map, h] ⟩ | ||
|
||
@[simp] | ||
def zip : Vector α n → Vector β n → Vector (α × β) n | ||
| ⟨ [], ha ⟩, ⟨ [], _ ⟩ => ⟨ [], ha ⟩ | ||
| ⟨ a::as, ha ⟩, ⟨ b::bs, hb ⟩ => ⟨ (a, b) :: List.zip as bs, by sorry ⟩ | ||
|
||
def get (v: Vector α n) (i: Fin n) : α := | ||
let i' : Fin v.1.length := Fin.cast (length_matches v).symm i | ||
v.val.get i' | ||
|
||
-- map over monad | ||
@[simp] | ||
def mapM { M : Type → Type } [Monad M] (v : Vector (M α) n) : M (Vector α n) := | ||
-- there `List.mapM` which we can use, but there doesn't seem to be an equivalent of `List.length_map` for monads | ||
do | ||
let l' ← List.mapM id v.val | ||
return ⟨ l', by sorry ⟩ | ||
|
||
-- other direction | ||
@[simp] | ||
def unmapM { M : Type → Type } [Monad M] (v : M (Vector α n)) : Vector (M α) n := | ||
sorry | ||
|
||
end Vector |