Skip to content

Commit

Permalink
Merge pull request #27 from zksecurity/circuit-prototype
Browse files Browse the repository at this point in the history
Circuit monad
  • Loading branch information
mitschabaude authored Dec 16, 2024
2 parents cf5600d + cee30fb commit 18f12d1
Show file tree
Hide file tree
Showing 8 changed files with 1,289 additions and 1 deletion.
914 changes: 914 additions & 0 deletions Clean/Circuit/Circuit.lean

Large diffs are not rendered by default.

84 changes: 84 additions & 0 deletions Clean/Circuit/Expression.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import Mathlib.Algebra.Field.Basic
import Mathlib.Data.ZMod.Basic
import Clean.Utils.Primes
import Clean.Utils.Vector

variable {F: Type}

structure Variable (F : Type) where
index: ℕ
witness: Unit → F

instance : Repr (Variable F) where
reprPrec v _ := "x" ++ repr v.index

inductive Expression (F : Type) where
| var : Variable F -> Expression F
| const : F -> Expression F
| add : Expression F -> Expression F -> Expression F
| mul : Expression F -> Expression F -> Expression F

namespace Expression
variable [Field F]

@[simp]
def eval : Expression F → F
| var v => v.witness ()
| const c => c
| add x y => eval x + eval y
| mul x y => eval x * eval y

/--
Evaluate expression given an external `environment` that determines the assignment
of all variables.
This is needed when we want to make statements about a circuit in the adversarial
situation where the prover can assign anything to variables.
-/
@[simp]
def eval_env (env: ℕ → F) : Expression F → F
| var v => env v.index
| const c => c
| add x y => eval_env env x + eval_env env y
| mul x y => eval_env env x * eval_env env y

def toString [Repr F] : Expression F → String
| var v => "x" ++ reprStr v.index
| const c => reprStr c
| add x y => "(" ++ toString x ++ " + " ++ toString y ++ ")"
| mul x y => "(" ++ toString x ++ " * " ++ toString y ++ ")"

instance [Repr F] : Repr (Expression F) where
reprPrec e _ := toString e

-- combine expressions elegantly
instance : Zero (Expression F) where
zero := const 0

instance : One (Expression F) where
one := const 1

instance : Add (Expression F) where
add := add

instance : Neg (Expression F) where
neg e := mul (const (-1)) e

instance : Sub (Expression F) where
sub e₁ e₂ := add e₁ (-e₂)

instance : Mul (Expression F) where
mul := mul

instance : Coe F (Expression F) where
coe f := const f

instance : Coe (Variable F) (Expression F) where
coe x := var x

instance : Coe (Expression F) F where
coe x := x.eval

instance : HMul F (Expression F) (Expression F) where
hMul := fun f e => mul f e
end Expression
92 changes: 92 additions & 0 deletions Clean/Circuit/Provable.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import Mathlib.Algebra.Field.Basic
import Mathlib.Data.ZMod.Basic
import Clean.Utils.Primes
import Clean.Utils.Vector
import Clean.Circuit.Expression

variable {F: Type} [Field F]

structure TypePair where
var: Type
value: Type

-- class of types that are composed of variables,
-- and can be evaluated into something that is composed of field elements
class ProvableType (F: Type) (α: TypePair) where
size : ℕ
to_vars : α.var → Vector (Expression F) size
from_vars : Vector (Expression F) size → α.var
to_values : α.value → Vector F size
from_values : Vector F size → α.value

-- or is it better as a structure?
structure ProvableType' (F : Type) where
var: Type
value: Type
size : ℕ
to_vars : var → Vector (Expression F) size
from_vars : Vector (Expression F) size → var
to_values : value → Vector F size
from_values : Vector F size → value

-- or like this?
def Provable' (F: Type) := { α : TypePair // ∃ p : Type, p = ProvableType F α }

namespace Provable
variable {α β γ: TypePair} [ProvableType F α] [ProvableType F β] [ProvableType F γ]

@[simp]
def eval (F: Type) [Field F] [ProvableType F α] (x: α.var) : α.value :=
let n := ProvableType.size F α
let vars : Vector (Expression F) n := ProvableType.to_vars x
let values := vars.map (fun v => v.eval)
ProvableType.from_values values

@[simp]
def eval_env (env: ℕ → F) (x: α.var) : α.value :=
let n := ProvableType.size F α
let vars : Vector (Expression F) n := ProvableType.to_vars x
let values := vars.map (fun v => v.eval_env env)
ProvableType.from_values values

def const (F: Type) [ProvableType F α] (x: α.value) : α.var :=
let n := ProvableType.size F α
let values : Vector F n := ProvableType.to_values x
ProvableType.from_vars (values.map (fun v => Expression.const v))

@[reducible]
def field (F : Type) : TypePair := ⟨ Expression F, F ⟩

instance : ProvableType F (field F) where
size := 1
to_vars x := vec [x]
from_vars v := v.get ⟨ 0, by norm_num ⟩
to_values x := vec [x]
from_values v := v.get ⟨ 0, by norm_num ⟩

@[reducible]
def pair (α β : TypePair) : TypePair := ⟨ α.var × β.var, α.value × β.value ⟩

@[reducible]
def field2 (F : Type) : TypePair := pair (field F) (field F)

instance : ProvableType F (field2 F) where
size := 2
to_vars pair := vec [pair.1, pair.2]
from_vars v := (v.get ⟨ 0, by norm_num ⟩, v.get ⟨ 1, by norm_num ⟩)
to_values pair :=vec [pair.1, pair.2]
from_values v := (v.get ⟨ 0, by norm_num ⟩, v.get ⟨ 1, by norm_num ⟩)

variable {n: ℕ}
def vec (α: TypePair) (n: ℕ) : TypePair := ⟨ Vector α.var n, Vector α.value n ⟩

@[reducible]
def fields (F: Type) (n: ℕ) : TypePair := vec (field F) n

instance : ProvableType F (fields F n) where
size := n
to_vars x := x
from_vars v := v
to_values x := x
from_values v := v
end Provable
2 changes: 1 addition & 1 deletion Clean/Expression.lean
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def TraceOfLength (N : ℕ+) (M : ℕ) (F : Type) : Type := { env : Trace N F /

def Trace.getLe {N: ℕ+} {F : Type} : (env : Trace N F) -> (row : Fin env.len) -> (j : Fin N) -> F
| _ +> currRow, ⟨0, _⟩, columnIndex => currRow columnIndex
| rest +> _, ⟨Nat.succ i, h⟩, j => getLe rest ⟨i, Nat.le_of_succ_le_succ h⟩ j
| rest +> _, ⟨i + 1, h⟩, j => getLe rest ⟨i, Nat.le_of_succ_le_succ h⟩ j

def TraceOfLength.get {N: ℕ+} {M : ℕ} {F : Type} : (env : TraceOfLength N M F) -> (i : Fin M) -> (j : Fin N) -> F
| ⟨env, h⟩, i, j => env.getLe (by rw [←h] at i; exact i) j
Expand Down
151 changes: 151 additions & 0 deletions Clean/Gadgets/Addition8New.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import Clean.GenericConstraint
import Clean.Expression
import Clean.Gadgets.Boolean
import Clean.Gadgets.ByteLookup
import Clean.Utils.Field
import Mathlib.Algebra.Field.Basic
import Mathlib.Data.ZMod.Basic

/-
8-bit addition constraint gadget: the output `out` must be the sum of the
inputs `x` and `y` modulo 256, and the carry `carry` must be the quotient
of the sum of `x` and `y` divided by 256.
-/
namespace Addition8
open Expression
variable {p : ℕ} [p_is_prime: Fact p.Prime] [p_large_enough: Fact (p > 512)]
instance : CommRing (F p) := ZMod.commRing p

-- could use nats as expressions but seems to slightly complicate proofs
instance {N: ℕ+} {M n : ℕ} : OfNat (Expression N M (F p)) n where
ofNat := const (OfNat.ofNat n : (Fin p))

variable (N : ℕ+) (M : ℕ)

def assumptions (x y z : Expression N M (F p)) := [
ByteLookup.lookup N M x,
ByteLookup.lookup N M y,
ByteLookup.lookup N M z,
]

def circuit (N : ℕ+) (M : ℕ) (x y out carry : Expression N M (F p)) : ConstraintGadget p N M :=
[
x + y - out - carry * (const 256)
],
assumptions N M x y out,
[
Boolean.circuit N M carry
]


def spec (N : ℕ+) (M : ℕ) (x y z: Expression N M (F p)) : TraceOfLength N M (F p) -> Prop :=
fun trace =>
have x := trace.eval x
have y := trace.eval y
have z := trace.eval z
z.val = (x.val + y.val) % 256

theorem equiv (N : ℕ+) (M : ℕ) (x y out: Expression N M (F p)) :
(∀ X,
(forallList (assumptions N M x y out) (fun lookup => lookup.prop X))
-> (
(∃ carry, constraints_hold (circuit N M x y out carry) X)
spec N M x y out X
)
) := by

intro X
simp [constraints_hold, forallList, ByteLookup.lookup]
simp [TraceOfLength.eval, spec]
intro hx_byte
intro hy_byte
intro hout_byte
set x := X.eval x
set y := X.eval y
set out := X.eval out

-- preliminaries
have no_wrap_xy : (x + y).val = x.val + y.val := by
rw [ZMod.val_add_of_lt]
linarith [hx_byte, hy_byte, p_large_enough.elim]

have val_self : (256 : ZMod p).val = 256 := ZMod.val_natCast_of_lt (by linarith [p_large_enough.elim])

have no_wrap_out : (out + 256).val = out.val + 256 := by
rw [ZMod.val_add_of_lt, val_self]
linarith [hout_byte, p_large_enough.elim]

constructor
-- soundness
· rintro ⟨ carry, h ⟩
set carry := X.eval carry
rcases (And.right h) with zero_carry | one_carry
-- carry = 0
· rw [zero_carry] at h
simp [←sub_eq_add_neg] at h
rw [←Nat.mod_eq_of_lt hout_byte]
have : out = x + y := calc
_ = 0 + out := by ring
_ = x + y - out + out := by rw [h]
_ = x + y := by ring
rw [this, no_wrap_xy]
-- carry = 1
· have one_carry': carry = 1 := calc
_ = carry - 0 := by ring
_ = carry - (carry + -1) := by rw [one_carry]
_ = 1 := by ring
rw [one_carry'] at h
simp [ZMod.val_add] at h
rw [← Nat.mod_eq_of_lt hout_byte]
rw [← no_wrap_xy]
have : x + y = out + 256 := calc
_ = x + y + -out + -256 + out + 256 := by ring
_ = 0 + out + 256 := by rw [h]
_ = _ := by ring
rw [this, no_wrap_out]
rw [Nat.add_mod, Nat.mod_self, Nat.add_zero, Nat.mod_mod]

-- completeness
· intro h
have carry? := Nat.lt_or_ge (x.val + y.val) 256
rcases carry? with sum_lt_256 | sum_ge_256

-- first case: x + y <= 256, carry = 0
· use 0
simp [TraceOfLength.eval]
rw [(Nat.mod_eq_iff_lt (by linarith)).mpr sum_lt_256, ← no_wrap_xy] at h
rw [←sub_eq_add_neg, sub_eq_zero]
apply_fun ZMod.val
· symm; exact h
· apply ZMod.val_injective

-- second case: x + y > 256, carry = 1
· use 1
simp [TraceOfLength.eval]
have one_lt : 1 < p := by linarith [p_large_enough.elim]
rw [Nat.mod_eq_of_lt one_lt]
have one_val : ((1 : ℕ) : ZMod p).val = 1 := ZMod.val_natCast_of_lt one_lt
simp [one_val]
suffices g : x + y = out + 256 from calc x + y + -out + -256
_ = out + 256 + -out + -256 := by rw [g]
_ = 0 := by ring

have sum_le_512 := Nat.add_lt_add hx_byte hy_byte
simp at sum_le_512
have div_one : (x.val + y.val) / 256 = 1 := by
apply Nat.div_eq_of_lt_le
· simp; exact sum_ge_256
· simp; exact sum_le_512
have modulo_definition_div := Nat.mod_add_div (x.val + y.val) 256
rw [← h, div_one] at modulo_definition_div
simp at modulo_definition_div
symm
apply_fun ZMod.val
· rw [no_wrap_xy, no_wrap_out]
exact modulo_definition_div
· apply ZMod.val_injective

end Addition8
3 changes: 3 additions & 0 deletions Clean/GenericConstraint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def forallList {α : Type} (v : List α) (p : α -> Prop) : Prop :=
| [] => true
| (x::xs) => p x ∧ forallList xs p

def constraints_hold {p M: ℕ} {N: ℕ+} [Fact p.Prime] (circuit : ConstraintGadget p N M) (trace : TraceOfLength N M (F p)) :=
(forallList (fullConstraintSet circuit) (fun constraint => trace.eval constraint = 0))

/-
A Constraint is a typeclass that packages the definition of the circuit together with its higher
level specification.
Expand Down
5 changes: 5 additions & 0 deletions Clean/Utils/Primes.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import Mathlib.Data.ZMod.Basic

theorem prime_1009 : Nat.Prime 1009 := by
-- isn't there a more efficient way to prove primalitity?
set_option maxRecDepth 900 in decide
39 changes: 39 additions & 0 deletions Clean/Utils/Vector.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import Mathlib.Data.Fintype.Basic

variable {α β : Type} {n : ℕ}

def Vector (α : Type) (n: ℕ) := { l: List α // l.length = n }

@[reducible]
def vec (l: List α) : Vector α l.length := ⟨ l, rfl ⟩

namespace Vector
theorem length_matches (v: Vector α n) : v.1.length = n := v.2

@[simp]
def map (f: α → β) : Vector α n → Vector β n
| ⟨ l, h ⟩ => ⟨ l.map f, by rw [List.length_map, h] ⟩

@[simp]
def zip : Vector α n → Vector β n → Vector (α × β) n
| ⟨ [], ha ⟩, ⟨ [], _ ⟩ => ⟨ [], ha ⟩
| ⟨ a::as, ha ⟩, ⟨ b::bs, hb ⟩ => ⟨ (a, b) :: List.zip as bs, by sorry

def get (v: Vector α n) (i: Fin n) : α :=
let i' : Fin v.1.length := Fin.cast (length_matches v).symm i
v.val.get i'

-- map over monad
@[simp]
def mapM { M : TypeType } [Monad M] (v : Vector (M α) n) : M (Vector α n) :=
-- there `List.mapM` which we can use, but there doesn't seem to be an equivalent of `List.length_map` for monads
do
let l' ← List.mapM id v.val
return ⟨ l', by sorry

-- other direction
@[simp]
def unmapM { M : TypeType } [Monad M] (v : M (Vector α n)) : Vector (M α) n :=
sorry

end Vector

0 comments on commit 18f12d1

Please sign in to comment.