Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add32 proofs in new formulation #43

Merged
merged 18 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Clean/Circuit/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def constraints_hold_default : List (PreOperation F) → Prop
table.contains (entry.map (fun e => e.eval)) ∧ constraints_hold_default ops
| _ => constraints_hold_default ops

@[simp]
def witness_length : List (PreOperation F) → ℕ
| [] => 0
| (Witness _) :: ops => witness_length ops + 1
Expand Down
1 change: 0 additions & 1 deletion Clean/GadgetsNew/Add8/Addition8Full.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def circuit : FormalCircuit (F p) (Inputs p) (field (F p)) where
let ⟨ asx, asy, as_carry_in ⟩ := as
have as': Add8FullCarry.circuit.assumptions { x, y, carry_in } := ⟨asx, asy, as_carry_in⟩
specialize h_holds (by assumption)
dsimp [ProvableType.from_values] at h_holds

guard_hyp h_holds : Add8FullCarry.circuit.spec
{ x, y, carry_in }
Expand Down
4 changes: 3 additions & 1 deletion Clean/GadgetsNew/Add8/Addition8FullCarry.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def Inputs (p : ℕ) : TypePair := ⟨
InputStruct (F p)

@[simp]
instance : ProvableType (F p) (Inputs p) where
size := 3
to_vars s := vec [s.x, s.y, s.carry_in]
Expand All @@ -40,6 +41,7 @@ def Outputs (p : ℕ) : TypePair := ⟨
OutputStruct (F p)

@[simp]
instance : ProvableType (F p) (Outputs p) where
size := 2
to_vars s := vec [s.z, s.carry_out]
Expand Down Expand Up @@ -169,7 +171,7 @@ def circuit : FormalCircuit (F p) (Inputs p) (Outputs p) where
linarith)

have ⟨as_x, as_y, as_carry_in⟩ := as
have carry_in_bound := FieldUtils.boolean_le_2 carry_in as_carry_in
have carry_in_bound := FieldUtils.boolean_lt_2 as_carry_in

have completeness2 : goal_bool := by
apply Add8Theorems.completeness_bool
Expand Down
2 changes: 1 addition & 1 deletion Clean/GadgetsNew/Add8/Theorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ theorem soundness (x y out carry_in carry_out: F p):
(out.val = (x.val + y.val + carry_in.val) % 256
∧ carry_out.val = (x.val + y.val + carry_in.val) / 256):= by
intros hx hy hout carry_in_bool carry_out_bool h
have carry_in_bound := FieldUtils.boolean_le_2 carry_in carry_in_bool
have carry_in_bound := FieldUtils.boolean_lt_2 carry_in_bool

rcases carry_out_bool with zero_carry | one_carry
-- case with zero carry
Expand Down
195 changes: 164 additions & 31 deletions Clean/GadgetsNew/Addition32Full.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import Clean.Types.U32

namespace Addition32Full
variable {p : ℕ} [Fact (p ≠ 0)] [Fact p.Prime]
variable [p_large_enough: Fact (p > 512)]
variable [p_large_enough: Fact (p > 2*2^32)]

open Provable (field field2 fields)

Expand All @@ -17,6 +17,7 @@ def Inputs (p : ℕ) : TypePair := ⟨
InputStruct (F p)

@[simp]
instance : ProvableType (F p) (Inputs p) where
size := 9 -- 4 + 4 + 1
to_vars s := vec [s.x.x0, s.x.x1, s.x.x2, s.x.x3, s.y.x0, s.y.x1, s.y.x2, s.y.x3, s.carry_in]
Expand Down Expand Up @@ -49,33 +50,15 @@ instance : ProvableType (F p) (Outputs p) where
let ⟨ [z0, z1, z2, z3, carry_out], _ ⟩ := v
⟨ ⟨ z0, z1, z2, z3 ⟩, carry_out ⟩

open Add8FullCarry (add8_full_carry)

def add32_full (input : (Inputs p).var) : Circuit (F p) (Outputs p).var := do
let ⟨x, y, carry_in⟩ := input

let {
z := z0,
carry_out := c0
} ← subcircuit Add8FullCarry.circuit ⟨ x.x0, y.x0, carry_in ⟩

let {
z := z1,
carry_out := c1
} ← subcircuit Add8FullCarry.circuit ⟨ x.x1, y.x1, c0 ⟩

let {
z := z2,
carry_out := c2
} ← subcircuit Add8FullCarry.circuit ⟨ x.x2, y.x2, c1 ⟩

let {
z := z3,
carry_out := c3
} ← subcircuit Add8FullCarry.circuit ⟨ x.x3, y.x3, c2 ⟩

return {
z := U32.mk z0 z1 z2 z3,
carry_out := c3
}
let { z := z0, carry_out := c0 } ← add8_full_carry ⟨ x.x0, y.x0, carry_in ⟩
let { z := z1, carry_out := c1 } ← add8_full_carry ⟨ x.x1, y.x1, c0 ⟩
let { z := z2, carry_out := c2 } ← add8_full_carry ⟨ x.x2, y.x2, c1 ⟩
let { z := z3, carry_out := c3 } ← add8_full_carry ⟨ x.x3, y.x3, c2 ⟩
return { z := U32.mk z0 z1 z2 z3, carry_out := c3 }

def assumptions (input : (Inputs p).value) :=
let ⟨x, y, carry_in⟩ := input
Expand All @@ -84,17 +67,167 @@ def assumptions (input : (Inputs p).value) :=
def spec (input : (Inputs p).value) (out: (Outputs p).value) :=
let ⟨x, y, carry_in⟩ := input
let ⟨z, carry_out⟩ := out
z.value = (x.value + y.value + carry_in.val) % 2^32 ∧
z.is_normalized ∧
carry_out.val = (x.value + y.value + carry_in.val) / 2^32
z.value = (x.value + y.value + carry_in.val) % 2^32
∧ carry_out.val = (x.value + y.value + carry_in.val) / 2^32
∧ z.is_normalized ∧ (carry_out = 0 ∨ carry_out = 1)

set_option linter.unusedVariables false

def circuit : FormalCircuit (F p) (Inputs p) (Outputs p) where
main := add32_full
assumptions := assumptions
spec := spec
soundness := by
sorry
rintro ctx env ⟨ x, y, carry_in ⟩ ⟨ x_var, y_var, carry_in_var ⟩ h_inputs as h
let ⟨ x0, x1, x2, x3 ⟩ := x
let ⟨ y0, y1, y2, y3 ⟩ := y
let ⟨ x0_var, x1_var, x2_var, x3_var ⟩ := x_var
let ⟨ y0_var, y1_var, y2_var, y3_var ⟩ := y_var
have : x0_var.eval_env env = x0 := by injections
have : x1_var.eval_env env = x1 := by injections
have : x2_var.eval_env env = x2 := by injections
have : x3_var.eval_env env = x3 := by injections
have : y0_var.eval_env env = y0 := by injections
have : y1_var.eval_env env = y1 := by injections
have : y2_var.eval_env env = y2 := by injections
have : y3_var.eval_env env = y3 := by injections
have : carry_in_var.eval_env env = carry_in := by injection h_inputs

-- simplify assumptions
dsimp [assumptions, U32.is_normalized] at as
have ⟨ x_norm, y_norm, carry_in_bool ⟩ := as
have ⟨ x0_byte, x1_byte, x2_byte, x3_byte ⟩ := x_norm
have ⟨ y0_byte, y1_byte, y2_byte, y3_byte ⟩ := y_norm

-- simplify circuit
dsimp [add32_full, Boolean.circuit, Circuit.formal_assertion_to_subcircuit] at h
set i0 := ctx.offset
have : ctx.offset = i0 := rfl
set z0 := env i0
set c0 := env (i0 + 1)
set z1 := env (i0 + 2)
set c1 := env (i0 + 3)
set z2 := env (i0 + 4)
set c2 := env (i0 + 5)
set z3 := env (i0 + 6)
set c3 := env (i0 + 7)
rw [‹x0_var.eval_env env = x0›, ‹y0_var.eval_env env = y0›, ‹carry_in_var.eval_env env = carry_in›] at h
rw [‹x1_var.eval_env env = x1›, ‹y1_var.eval_env env = y1›] at h
rw [‹x2_var.eval_env env = x2›, ‹y2_var.eval_env env = y2›] at h
rw [‹x3_var.eval_env env = x3›, ‹y3_var.eval_env env = y3›] at h
rw [ByteTable.equiv z0, ByteTable.equiv z1, ByteTable.equiv z2, ByteTable.equiv z3] at h
simp only [true_implies] at h
have ⟨ z0_byte, c0_bool, h0, z1_byte, c1_bool, h1, z2_byte, c2_bool, h2, z3_byte, c3_bool, h3 ⟩ := h

-- simplify spec
dsimp [spec, U32.value, U32.is_normalized]
rw [‹ctx.offset = i0›, (by rfl: env (i0 + 7) = c3)]
rw [(by rfl: env i0 = z0), (by rfl: env (i0 + 2) = z1), (by rfl: env (i0 + 4) = z2), (by rfl: env (i0 + 6) = z3)]

-- add up all the equations
let z := z0 + z1*256 + z2*256^2 + z3*256^3
let x := x0 + x1*256 + x2*256^2 + x3*256^3
let y := y0 + y1*256 + y2*256^2 + y3*256^3
let lhs := z + c3*2^32
let rhs₀ := x0 + y0 + carry_in + -1 * z0 + -1 * (c0 * 256) -- h0 expression
let rhs₁ := x1 + y1 + c0 + -1 * z1 + -1 * (c1 * 256) -- h1 expression
let rhs₂ := x2 + y2 + c1 + -1 * z2 + -1 * (c2 * 256) -- h2 expression
let rhs₃ := x3 + y3 + c2 + -1 * z3 + -1 * (c3 * 256) -- h3 expression

have h_add := calc z + c3*2^32
-- substitute equations
_ = lhs + 0 + 256*0 + 256^2*0 + 256^3*0 := by ring
_ = lhs + rhs₀ + 256*rhs₁ + 256^2*rhs₂ + 256^3*rhs₃ := by dsimp [rhs₀, rhs₁, rhs₂, rhs₃]; rw [h0, h1, h2, h3]
-- simplify
_ = x + y + carry_in := by ring

-- move added equation into Nat
let z_nat := z0.val + z1.val*256 + z2.val*256^2 + z3.val*256^3
let x_nat := x0.val + x1.val*256 + x2.val*256^2 + x3.val*256^3
let y_nat := y0.val + y1.val*256 + y2.val*256^2 + y3.val*256^3

have : c3.val < 2 := FieldUtils.boolean_lt_2 c3_bool
have : carry_in.val < 2 := FieldUtils.boolean_lt_2 carry_in_bool

have h_add_nat := calc z_nat + c3.val*2^32
_ = (z + c3*2^32).val := by dsimp only [z_nat]; field_to_nat_u32
_ = (x + y + carry_in).val := congrArg ZMod.val h_add
_ = x_nat + y_nat + carry_in.val := by dsimp only [x_nat, y_nat]; field_to_nat_u32

-- show that lhs splits into low and high 32 bits
have : z_nat < 2^32 := by dsimp only [z_nat]; linarith

have h_low : z_nat = (x_nat + y_nat + carry_in.val) % 2^32 := by
suffices h : z_nat = z_nat % 2^32 by
rw [← h_add_nat, ← Nat.add_mod_mod, ← Nat.mul_mod_mod]
simpa using h
rw [Nat.mod_eq_of_lt ‹z_nat < 2^32›]

have h_high : c3.val = (x_nat + y_nat + carry_in.val) / 2^32 := by
rw [← h_add_nat, Nat.add_mul_div_right _ _ (by norm_num)]
rw [Nat.div_eq_of_lt ‹z_nat < 2^32›, zero_add]

exact ⟨ h_low, h_high, ⟨ z0_byte, z1_byte, z2_byte, z3_byte ⟩, c3_bool ⟩

completeness := by
sorry
rintro ctx ⟨ x, y, carry_in ⟩ ⟨ x_var, y_var, carry_in_var ⟩ h_inputs as
let ⟨ x0, x1, x2, x3 ⟩ := x
let ⟨ y0, y1, y2, y3 ⟩ := y
let ⟨ x0_var, x1_var, x2_var, x3_var ⟩ := x_var
let ⟨ y0_var, y1_var, y2_var, y3_var ⟩ := y_var
have : x0_var.eval = x0 := by injections
have : x1_var.eval = x1 := by injections
have : x2_var.eval = x2 := by injections
have : x3_var.eval = x3 := by injections
have : y0_var.eval = y0 := by injections
have : y1_var.eval = y1 := by injections
have : y2_var.eval = y2 := by injections
have : y3_var.eval = y3 := by injections
have : carry_in_var.eval = carry_in := by injections

-- simplify assumptions
dsimp [assumptions, U32.is_normalized] at as
have ⟨ x_norm, y_norm, carry_in_bool ⟩ := as
have ⟨ x0_byte, x1_byte, x2_byte, x3_byte ⟩ := x_norm
have ⟨ y0_byte, y1_byte, y2_byte, y3_byte ⟩ := y_norm

-- simplify circuit
dsimp [add32_full, Boolean.circuit, Circuit.formal_assertion_to_subcircuit]
rw [‹x0_var.eval = x0›, ‹y0_var.eval = y0›, ‹carry_in_var.eval = carry_in›]
rw [‹x1_var.eval = x1›, ‹y1_var.eval = y1›]
rw [‹x2_var.eval = x2›, ‹y2_var.eval = y2›]
rw [‹x3_var.eval = x3›, ‹y3_var.eval = y3›]
set z0 := FieldUtils.mod_256 (x0 + y0 + carry_in)
set c0 := FieldUtils.floordiv (x0 + y0 + carry_in) 256
set z1 := FieldUtils.mod_256 (x1 + y1 + c0)
set c1 := FieldUtils.floordiv (x1 + y1 + c0) 256
set z2 := FieldUtils.mod_256 (x2 + y2 + c1)
set c2 := FieldUtils.floordiv (x2 + y2 + c1) 256
set z3 := FieldUtils.mod_256 (x3 + y3 + c2)
set c3 := FieldUtils.floordiv (x3 + y3 + c2) 256

simp only [true_and]

-- the add8 completeness proof, four times
have add8_completeness {x y c_in : F p} :
let z := FieldUtils.mod_256 (x + y + c_in);
let c_out := FieldUtils.floordiv (x + y + c_in) 256;
x.val < 256 → y.val < 256 → c_in = 0 ∨ c_in = 1 →
ByteTable.contains (vec [z]) ∧ (c_out = 0 ∨ c_out = 1) ∧ x + y + c_in + -1 * z + -1 * (c_out * 256) = 0
:= by
intro z c_out _ _ hc
have : z.val < 256 := FieldUtils.mod_256_lt (x + y + c_in)
use ByteTable.completeness z this
have : c_in.val < 2 := FieldUtils.boolean_lt_2 hc
have : (x + y + c_in).val < 512 := by field_to_nat_u32
use FieldUtils.floordiv_bool this
rw [FieldUtils.mod_add_div_256 (x + y + c_in)]
ring

have ⟨ z0_byte, c0_bool, h0 ⟩ := add8_completeness x0_byte y0_byte carry_in_bool
have ⟨ z1_byte, c1_bool, h1 ⟩ := add8_completeness x1_byte y1_byte c0_bool
have ⟨ z2_byte, c2_bool, h2 ⟩ := add8_completeness x2_byte y2_byte c1_bool
have ⟨ z3_byte, c3_bool, h3 ⟩ := add8_completeness x3_byte y3_byte c2_bool

exact ⟨ z0_byte, c0_bool, h0, z1_byte, c1_bool, h1, z2_byte, c2_bool, h2, z3_byte, c3_bool, h3 ⟩
end Addition32Full
3 changes: 3 additions & 0 deletions Clean/GadgetsNew/ByteLookup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def ByteTable.completeness (x: F p) : x.val < 256 → ByteTable.contains (vec [x
simp [h']
rw [FieldUtils.nat_to_field_of_val_eq_iff]

def ByteTable.equiv (x: F p) : ByteTable.contains (vec [x]) ↔ x.val < 256 :=
⟨ByteTable.soundness x, ByteTable.completeness x⟩

def byte_lookup (x: Expression (F p)) := lookup {
table := ByteTable
entry := vec [x]
Expand Down
58 changes: 56 additions & 2 deletions Clean/Types/U32.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import Clean.GadgetsNew.ByteLookup

section
variable {p : ℕ} [Fact (p ≠ 0)] [Fact p.Prime]
variable [p_large_enough: Fact (p > 512)]
variable {p : ℕ} [Fact p.Prime]
variable [p_large_enough: Fact (p > 2*2^32)]

instance : NeZero p := ⟨‹Fact p.Prime›.elim.ne_zero⟩
instance : Fact (p > 512) := by apply Fact.mk; linarith [p_large_enough.elim]

/--
A 32-bit unsigned integer is represented using four limbs of 8 bits each.
Expand Down Expand Up @@ -44,6 +47,12 @@ def is_normalized (x: U32 (F p)) :=
def value (x: U32 (F p)) :=
x.x0.val + x.x1.val * 256 + x.x2.val * 256^2 + x.x3.val * 256^3

/--
Return the value of a 32-bit unsigned integer as a field element.
-/
def value_field (x: U32 (F p)) : F p :=
x.x0 + x.x1 * 256 + x.x2 * 256^2 + x.x3 * 256^3

/--
Return a 32-bit unsigned integer from a natural number, by decomposing
it into four limbs of 8 bits each.
Expand All @@ -69,5 +78,50 @@ lemma wrapping_add_correct (x y z: U32 (F p)) :
x.wrapping_add y = z ↔ z.value = (x.value + y.value) % 2^32 := by
sorry

-- U32-related tactic and lemmas

lemma val_eq_256 : (256 : F p).val = 256 := FieldUtils.val_lt_p 256 (by linarith [p_large_enough.elim])
lemma val_eq_256p2 : (256^2 : F p).val = 256^2 := by ring_nf; exact FieldUtils.val_lt_p (256^2) (by linarith [p_large_enough.elim])
lemma val_eq_256p3 : (256^3 : F p).val = 256^3 := by ring_nf; exact FieldUtils.val_lt_p (256^3) (by linarith [p_large_enough.elim])
lemma val_eq_256p4 : (256^4 : F p).val = 256^4 := by ring_nf; exact FieldUtils.val_lt_p (256^4) (by linarith [p_large_enough.elim])
lemma val_eq_2p32 : (2^32 : F p).val = 2^32 := by have := val_eq_256p4 (p:=p); ring_nf at *; assumption

/--
tactic script to fully rewrite a ZMod expression to its Nat version, given that
the expression is smaller than the modulus.

```
example (x y : F p) (hx: x.val < 256) (hy: y.val < 256) :
(x + y * 256).val = x.val + y.val * 256 := by field_to_nat_u32
```

expected context:
- the equation to prove as the goal
- size assumptions on variables and a sufficient `p > ...` instance

if no sufficient inequalities are in the context, then the tactic will leave an equation of the form `expr : Nat < p` unsolved.

note: this version is optimized for uint32 arithmetic:
- specifically handles field constants 256, 256^2, 256^3, 256^4 = 2^32
- expects `[Fact (p > 2*256^4)]` in the context
-/
syntax "field_to_nat_u32" : tactic
macro_rules
| `(tactic|field_to_nat_u32) =>
`(tactic|(
repeat rw [ZMod.val_add] -- (a + b).val = (a.val + b.val) % p
repeat rw [ZMod.val_mul] -- (a * b).val = (a.val * b.val) % p
repeat rw [U32.val_eq_256]
repeat rw [U32.val_eq_256p2]
repeat rw [U32.val_eq_256p3]
repeat rw [U32.val_eq_256p4]
repeat rw [U32.val_eq_2p32]
simp only [Nat.reducePow, Nat.add_mod_mod, Nat.mod_add_mod, Nat.mul_mod_mod, Nat.mod_mul_mod]
rw [Nat.mod_eq_of_lt _]
repeat linarith [‹Fact (_ > 2 * 2^32)›.elim]))

lemma value_eq {x0 x1 x2 x3: F p} (h0 : x0.val < 256) (h1 : x1.val < 256) (h2 : x2.val < 256) (h3 : x3.val < 256) :
(x0 + x1 * 256 + x2 * 256^2 + x3 * 256^3).val = x0.val + x1.val * 256 + x2.val * 256^2 + x3.val * 256^3 := by
field_to_nat_u32
end U32
end
Loading
Loading