diff --git a/Clean/Circuit/Basic.lean b/Clean/Circuit/Basic.lean index b39f187..9db30ff 100644 --- a/Clean/Circuit/Basic.lean +++ b/Clean/Circuit/Basic.lean @@ -73,6 +73,7 @@ def constraints_hold_default : List (PreOperation F) → Prop table.contains (entry.map (fun e => e.eval)) ∧ constraints_hold_default ops | _ => constraints_hold_default ops +@[simp] def witness_length : List (PreOperation F) → ℕ | [] => 0 | (Witness _) :: ops => witness_length ops + 1 diff --git a/Clean/GadgetsNew/Add8/Addition8Full.lean b/Clean/GadgetsNew/Add8/Addition8Full.lean index 7c1342d..f359db8 100644 --- a/Clean/GadgetsNew/Add8/Addition8Full.lean +++ b/Clean/GadgetsNew/Add8/Addition8Full.lean @@ -74,7 +74,6 @@ def circuit : FormalCircuit (F p) (Inputs p) (field (F p)) where let ⟨ asx, asy, as_carry_in ⟩ := as have as': Add8FullCarry.circuit.assumptions { x, y, carry_in } := ⟨asx, asy, as_carry_in⟩ specialize h_holds (by assumption) - dsimp [ProvableType.from_values] at h_holds guard_hyp h_holds : Add8FullCarry.circuit.spec { x, y, carry_in } diff --git a/Clean/GadgetsNew/Add8/Addition8FullCarry.lean b/Clean/GadgetsNew/Add8/Addition8FullCarry.lean index a8a8fdf..60efe7f 100644 --- a/Clean/GadgetsNew/Add8/Addition8FullCarry.lean +++ b/Clean/GadgetsNew/Add8/Addition8FullCarry.lean @@ -19,6 +19,7 @@ def Inputs (p : ℕ) : TypePair := ⟨ InputStruct (F p) ⟩ +@[simp] instance : ProvableType (F p) (Inputs p) where size := 3 to_vars s := vec [s.x, s.y, s.carry_in] @@ -40,6 +41,7 @@ def Outputs (p : ℕ) : TypePair := ⟨ OutputStruct (F p) ⟩ +@[simp] instance : ProvableType (F p) (Outputs p) where size := 2 to_vars s := vec [s.z, s.carry_out] @@ -169,7 +171,7 @@ def circuit : FormalCircuit (F p) (Inputs p) (Outputs p) where linarith) have ⟨as_x, as_y, as_carry_in⟩ := as - have carry_in_bound := FieldUtils.boolean_le_2 carry_in as_carry_in + have carry_in_bound := FieldUtils.boolean_lt_2 as_carry_in have completeness2 : goal_bool := by apply Add8Theorems.completeness_bool diff --git a/Clean/GadgetsNew/Add8/Theorems.lean b/Clean/GadgetsNew/Add8/Theorems.lean index 2a8b857..5d038db 100644 --- a/Clean/GadgetsNew/Add8/Theorems.lean +++ b/Clean/GadgetsNew/Add8/Theorems.lean @@ -91,7 +91,7 @@ theorem soundness (x y out carry_in carry_out: F p): (out.val = (x.val + y.val + carry_in.val) % 256 ∧ carry_out.val = (x.val + y.val + carry_in.val) / 256):= by intros hx hy hout carry_in_bool carry_out_bool h - have carry_in_bound := FieldUtils.boolean_le_2 carry_in carry_in_bool + have carry_in_bound := FieldUtils.boolean_lt_2 carry_in_bool rcases carry_out_bool with zero_carry | one_carry -- case with zero carry diff --git a/Clean/GadgetsNew/Addition32Full.lean b/Clean/GadgetsNew/Addition32Full.lean index 6a805b2..6424a01 100644 --- a/Clean/GadgetsNew/Addition32Full.lean +++ b/Clean/GadgetsNew/Addition32Full.lean @@ -3,7 +3,7 @@ import Clean.Types.U32 namespace Addition32Full variable {p : ℕ} [Fact (p ≠ 0)] [Fact p.Prime] -variable [p_large_enough: Fact (p > 512)] +variable [p_large_enough: Fact (p > 2*2^32)] open Provable (field field2 fields) @@ -17,6 +17,7 @@ def Inputs (p : ℕ) : TypePair := ⟨ InputStruct (F p) ⟩ +@[simp] instance : ProvableType (F p) (Inputs p) where size := 9 -- 4 + 4 + 1 to_vars s := vec [s.x.x0, s.x.x1, s.x.x2, s.x.x3, s.y.x0, s.y.x1, s.y.x2, s.y.x3, s.carry_in] @@ -49,33 +50,15 @@ instance : ProvableType (F p) (Outputs p) where let ⟨ [z0, z1, z2, z3, carry_out], _ ⟩ := v ⟨ ⟨ z0, z1, z2, z3 ⟩, carry_out ⟩ +open Add8FullCarry (add8_full_carry) + def add32_full (input : (Inputs p).var) : Circuit (F p) (Outputs p).var := do let ⟨x, y, carry_in⟩ := input - - let { - z := z0, - carry_out := c0 - } ← subcircuit Add8FullCarry.circuit ⟨ x.x0, y.x0, carry_in ⟩ - - let { - z := z1, - carry_out := c1 - } ← subcircuit Add8FullCarry.circuit ⟨ x.x1, y.x1, c0 ⟩ - - let { - z := z2, - carry_out := c2 - } ← subcircuit Add8FullCarry.circuit ⟨ x.x2, y.x2, c1 ⟩ - - let { - z := z3, - carry_out := c3 - } ← subcircuit Add8FullCarry.circuit ⟨ x.x3, y.x3, c2 ⟩ - - return { - z := U32.mk z0 z1 z2 z3, - carry_out := c3 - } + let { z := z0, carry_out := c0 } ← add8_full_carry ⟨ x.x0, y.x0, carry_in ⟩ + let { z := z1, carry_out := c1 } ← add8_full_carry ⟨ x.x1, y.x1, c0 ⟩ + let { z := z2, carry_out := c2 } ← add8_full_carry ⟨ x.x2, y.x2, c1 ⟩ + let { z := z3, carry_out := c3 } ← add8_full_carry ⟨ x.x3, y.x3, c2 ⟩ + return { z := U32.mk z0 z1 z2 z3, carry_out := c3 } def assumptions (input : (Inputs p).value) := let ⟨x, y, carry_in⟩ := input @@ -84,17 +67,167 @@ def assumptions (input : (Inputs p).value) := def spec (input : (Inputs p).value) (out: (Outputs p).value) := let ⟨x, y, carry_in⟩ := input let ⟨z, carry_out⟩ := out - z.value = (x.value + y.value + carry_in.val) % 2^32 ∧ - z.is_normalized ∧ - carry_out.val = (x.value + y.value + carry_in.val) / 2^32 + z.value = (x.value + y.value + carry_in.val) % 2^32 + ∧ carry_out.val = (x.value + y.value + carry_in.val) / 2^32 + ∧ z.is_normalized ∧ (carry_out = 0 ∨ carry_out = 1) + +set_option linter.unusedVariables false def circuit : FormalCircuit (F p) (Inputs p) (Outputs p) where main := add32_full assumptions := assumptions spec := spec soundness := by - sorry + rintro ctx env ⟨ x, y, carry_in ⟩ ⟨ x_var, y_var, carry_in_var ⟩ h_inputs as h + let ⟨ x0, x1, x2, x3 ⟩ := x + let ⟨ y0, y1, y2, y3 ⟩ := y + let ⟨ x0_var, x1_var, x2_var, x3_var ⟩ := x_var + let ⟨ y0_var, y1_var, y2_var, y3_var ⟩ := y_var + have : x0_var.eval_env env = x0 := by injections + have : x1_var.eval_env env = x1 := by injections + have : x2_var.eval_env env = x2 := by injections + have : x3_var.eval_env env = x3 := by injections + have : y0_var.eval_env env = y0 := by injections + have : y1_var.eval_env env = y1 := by injections + have : y2_var.eval_env env = y2 := by injections + have : y3_var.eval_env env = y3 := by injections + have : carry_in_var.eval_env env = carry_in := by injection h_inputs + + -- simplify assumptions + dsimp [assumptions, U32.is_normalized] at as + have ⟨ x_norm, y_norm, carry_in_bool ⟩ := as + have ⟨ x0_byte, x1_byte, x2_byte, x3_byte ⟩ := x_norm + have ⟨ y0_byte, y1_byte, y2_byte, y3_byte ⟩ := y_norm + + -- simplify circuit + dsimp [add32_full, Boolean.circuit, Circuit.formal_assertion_to_subcircuit] at h + set i0 := ctx.offset + have : ctx.offset = i0 := rfl + set z0 := env i0 + set c0 := env (i0 + 1) + set z1 := env (i0 + 2) + set c1 := env (i0 + 3) + set z2 := env (i0 + 4) + set c2 := env (i0 + 5) + set z3 := env (i0 + 6) + set c3 := env (i0 + 7) + rw [‹x0_var.eval_env env = x0›, ‹y0_var.eval_env env = y0›, ‹carry_in_var.eval_env env = carry_in›] at h + rw [‹x1_var.eval_env env = x1›, ‹y1_var.eval_env env = y1›] at h + rw [‹x2_var.eval_env env = x2›, ‹y2_var.eval_env env = y2›] at h + rw [‹x3_var.eval_env env = x3›, ‹y3_var.eval_env env = y3›] at h + rw [ByteTable.equiv z0, ByteTable.equiv z1, ByteTable.equiv z2, ByteTable.equiv z3] at h + simp only [true_implies] at h + have ⟨ z0_byte, c0_bool, h0, z1_byte, c1_bool, h1, z2_byte, c2_bool, h2, z3_byte, c3_bool, h3 ⟩ := h + + -- simplify spec + dsimp [spec, U32.value, U32.is_normalized] + rw [‹ctx.offset = i0›, (by rfl: env (i0 + 7) = c3)] + rw [(by rfl: env i0 = z0), (by rfl: env (i0 + 2) = z1), (by rfl: env (i0 + 4) = z2), (by rfl: env (i0 + 6) = z3)] + + -- add up all the equations + let z := z0 + z1*256 + z2*256^2 + z3*256^3 + let x := x0 + x1*256 + x2*256^2 + x3*256^3 + let y := y0 + y1*256 + y2*256^2 + y3*256^3 + let lhs := z + c3*2^32 + let rhs₀ := x0 + y0 + carry_in + -1 * z0 + -1 * (c0 * 256) -- h0 expression + let rhs₁ := x1 + y1 + c0 + -1 * z1 + -1 * (c1 * 256) -- h1 expression + let rhs₂ := x2 + y2 + c1 + -1 * z2 + -1 * (c2 * 256) -- h2 expression + let rhs₃ := x3 + y3 + c2 + -1 * z3 + -1 * (c3 * 256) -- h3 expression + + have h_add := calc z + c3*2^32 + -- substitute equations + _ = lhs + 0 + 256*0 + 256^2*0 + 256^3*0 := by ring + _ = lhs + rhs₀ + 256*rhs₁ + 256^2*rhs₂ + 256^3*rhs₃ := by dsimp [rhs₀, rhs₁, rhs₂, rhs₃]; rw [h0, h1, h2, h3] + -- simplify + _ = x + y + carry_in := by ring + + -- move added equation into Nat + let z_nat := z0.val + z1.val*256 + z2.val*256^2 + z3.val*256^3 + let x_nat := x0.val + x1.val*256 + x2.val*256^2 + x3.val*256^3 + let y_nat := y0.val + y1.val*256 + y2.val*256^2 + y3.val*256^3 + + have : c3.val < 2 := FieldUtils.boolean_lt_2 c3_bool + have : carry_in.val < 2 := FieldUtils.boolean_lt_2 carry_in_bool + + have h_add_nat := calc z_nat + c3.val*2^32 + _ = (z + c3*2^32).val := by dsimp only [z_nat]; field_to_nat_u32 + _ = (x + y + carry_in).val := congrArg ZMod.val h_add + _ = x_nat + y_nat + carry_in.val := by dsimp only [x_nat, y_nat]; field_to_nat_u32 + + -- show that lhs splits into low and high 32 bits + have : z_nat < 2^32 := by dsimp only [z_nat]; linarith + + have h_low : z_nat = (x_nat + y_nat + carry_in.val) % 2^32 := by + suffices h : z_nat = z_nat % 2^32 by + rw [← h_add_nat, ← Nat.add_mod_mod, ← Nat.mul_mod_mod] + simpa using h + rw [Nat.mod_eq_of_lt ‹z_nat < 2^32›] + + have h_high : c3.val = (x_nat + y_nat + carry_in.val) / 2^32 := by + rw [← h_add_nat, Nat.add_mul_div_right _ _ (by norm_num)] + rw [Nat.div_eq_of_lt ‹z_nat < 2^32›, zero_add] + + exact ⟨ h_low, h_high, ⟨ z0_byte, z1_byte, z2_byte, z3_byte ⟩, c3_bool ⟩ completeness := by - sorry + rintro ctx ⟨ x, y, carry_in ⟩ ⟨ x_var, y_var, carry_in_var ⟩ h_inputs as + let ⟨ x0, x1, x2, x3 ⟩ := x + let ⟨ y0, y1, y2, y3 ⟩ := y + let ⟨ x0_var, x1_var, x2_var, x3_var ⟩ := x_var + let ⟨ y0_var, y1_var, y2_var, y3_var ⟩ := y_var + have : x0_var.eval = x0 := by injections + have : x1_var.eval = x1 := by injections + have : x2_var.eval = x2 := by injections + have : x3_var.eval = x3 := by injections + have : y0_var.eval = y0 := by injections + have : y1_var.eval = y1 := by injections + have : y2_var.eval = y2 := by injections + have : y3_var.eval = y3 := by injections + have : carry_in_var.eval = carry_in := by injections + + -- simplify assumptions + dsimp [assumptions, U32.is_normalized] at as + have ⟨ x_norm, y_norm, carry_in_bool ⟩ := as + have ⟨ x0_byte, x1_byte, x2_byte, x3_byte ⟩ := x_norm + have ⟨ y0_byte, y1_byte, y2_byte, y3_byte ⟩ := y_norm + + -- simplify circuit + dsimp [add32_full, Boolean.circuit, Circuit.formal_assertion_to_subcircuit] + rw [‹x0_var.eval = x0›, ‹y0_var.eval = y0›, ‹carry_in_var.eval = carry_in›] + rw [‹x1_var.eval = x1›, ‹y1_var.eval = y1›] + rw [‹x2_var.eval = x2›, ‹y2_var.eval = y2›] + rw [‹x3_var.eval = x3›, ‹y3_var.eval = y3›] + set z0 := FieldUtils.mod_256 (x0 + y0 + carry_in) + set c0 := FieldUtils.floordiv (x0 + y0 + carry_in) 256 + set z1 := FieldUtils.mod_256 (x1 + y1 + c0) + set c1 := FieldUtils.floordiv (x1 + y1 + c0) 256 + set z2 := FieldUtils.mod_256 (x2 + y2 + c1) + set c2 := FieldUtils.floordiv (x2 + y2 + c1) 256 + set z3 := FieldUtils.mod_256 (x3 + y3 + c2) + set c3 := FieldUtils.floordiv (x3 + y3 + c2) 256 + + simp only [true_and] + + -- the add8 completeness proof, four times + have add8_completeness {x y c_in : F p} : + let z := FieldUtils.mod_256 (x + y + c_in); + let c_out := FieldUtils.floordiv (x + y + c_in) 256; + x.val < 256 → y.val < 256 → c_in = 0 ∨ c_in = 1 → + ByteTable.contains (vec [z]) ∧ (c_out = 0 ∨ c_out = 1) ∧ x + y + c_in + -1 * z + -1 * (c_out * 256) = 0 + := by + intro z c_out _ _ hc + have : z.val < 256 := FieldUtils.mod_256_lt (x + y + c_in) + use ByteTable.completeness z this + have : c_in.val < 2 := FieldUtils.boolean_lt_2 hc + have : (x + y + c_in).val < 512 := by field_to_nat_u32 + use FieldUtils.floordiv_bool this + rw [FieldUtils.mod_add_div_256 (x + y + c_in)] + ring + + have ⟨ z0_byte, c0_bool, h0 ⟩ := add8_completeness x0_byte y0_byte carry_in_bool + have ⟨ z1_byte, c1_bool, h1 ⟩ := add8_completeness x1_byte y1_byte c0_bool + have ⟨ z2_byte, c2_bool, h2 ⟩ := add8_completeness x2_byte y2_byte c1_bool + have ⟨ z3_byte, c3_bool, h3 ⟩ := add8_completeness x3_byte y3_byte c2_bool + + exact ⟨ z0_byte, c0_bool, h0, z1_byte, c1_bool, h1, z2_byte, c2_bool, h2, z3_byte, c3_bool, h3 ⟩ end Addition32Full diff --git a/Clean/GadgetsNew/ByteLookup.lean b/Clean/GadgetsNew/ByteLookup.lean index 8794c09..5ac112e 100644 --- a/Clean/GadgetsNew/ByteLookup.lean +++ b/Clean/GadgetsNew/ByteLookup.lean @@ -32,6 +32,9 @@ def ByteTable.completeness (x: F p) : x.val < 256 → ByteTable.contains (vec [x simp [h'] rw [FieldUtils.nat_to_field_of_val_eq_iff] +def ByteTable.equiv (x: F p) : ByteTable.contains (vec [x]) ↔ x.val < 256 := + ⟨ByteTable.soundness x, ByteTable.completeness x⟩ + def byte_lookup (x: Expression (F p)) := lookup { table := ByteTable entry := vec [x] diff --git a/Clean/Types/U32.lean b/Clean/Types/U32.lean index 86e7c51..b9e2415 100644 --- a/Clean/Types/U32.lean +++ b/Clean/Types/U32.lean @@ -1,8 +1,11 @@ import Clean.GadgetsNew.ByteLookup section -variable {p : ℕ} [Fact (p ≠ 0)] [Fact p.Prime] -variable [p_large_enough: Fact (p > 512)] +variable {p : ℕ} [Fact p.Prime] +variable [p_large_enough: Fact (p > 2*2^32)] + +instance : NeZero p := ⟨‹Fact p.Prime›.elim.ne_zero⟩ +instance : Fact (p > 512) := by apply Fact.mk; linarith [p_large_enough.elim] /-- A 32-bit unsigned integer is represented using four limbs of 8 bits each. @@ -44,6 +47,12 @@ def is_normalized (x: U32 (F p)) := def value (x: U32 (F p)) := x.x0.val + x.x1.val * 256 + x.x2.val * 256^2 + x.x3.val * 256^3 +/-- +Return the value of a 32-bit unsigned integer as a field element. +-/ +def value_field (x: U32 (F p)) : F p := + x.x0 + x.x1 * 256 + x.x2 * 256^2 + x.x3 * 256^3 + /-- Return a 32-bit unsigned integer from a natural number, by decomposing it into four limbs of 8 bits each. @@ -69,5 +78,50 @@ lemma wrapping_add_correct (x y z: U32 (F p)) : x.wrapping_add y = z ↔ z.value = (x.value + y.value) % 2^32 := by sorry +-- U32-related tactic and lemmas + +lemma val_eq_256 : (256 : F p).val = 256 := FieldUtils.val_lt_p 256 (by linarith [p_large_enough.elim]) +lemma val_eq_256p2 : (256^2 : F p).val = 256^2 := by ring_nf; exact FieldUtils.val_lt_p (256^2) (by linarith [p_large_enough.elim]) +lemma val_eq_256p3 : (256^3 : F p).val = 256^3 := by ring_nf; exact FieldUtils.val_lt_p (256^3) (by linarith [p_large_enough.elim]) +lemma val_eq_256p4 : (256^4 : F p).val = 256^4 := by ring_nf; exact FieldUtils.val_lt_p (256^4) (by linarith [p_large_enough.elim]) +lemma val_eq_2p32 : (2^32 : F p).val = 2^32 := by have := val_eq_256p4 (p:=p); ring_nf at *; assumption + +/-- +tactic script to fully rewrite a ZMod expression to its Nat version, given that +the expression is smaller than the modulus. + +``` +example (x y : F p) (hx: x.val < 256) (hy: y.val < 256) : + (x + y * 256).val = x.val + y.val * 256 := by field_to_nat_u32 +``` + +expected context: +- the equation to prove as the goal +- size assumptions on variables and a sufficient `p > ...` instance + +if no sufficient inequalities are in the context, then the tactic will leave an equation of the form `expr : Nat < p` unsolved. + +note: this version is optimized for uint32 arithmetic: +- specifically handles field constants 256, 256^2, 256^3, 256^4 = 2^32 +- expects `[Fact (p > 2*256^4)]` in the context +-/ +syntax "field_to_nat_u32" : tactic +macro_rules + | `(tactic|field_to_nat_u32) => + `(tactic|( + repeat rw [ZMod.val_add] -- (a + b).val = (a.val + b.val) % p + repeat rw [ZMod.val_mul] -- (a * b).val = (a.val * b.val) % p + repeat rw [U32.val_eq_256] + repeat rw [U32.val_eq_256p2] + repeat rw [U32.val_eq_256p3] + repeat rw [U32.val_eq_256p4] + repeat rw [U32.val_eq_2p32] + simp only [Nat.reducePow, Nat.add_mod_mod, Nat.mod_add_mod, Nat.mul_mod_mod, Nat.mod_mul_mod] + rw [Nat.mod_eq_of_lt _] + repeat linarith [‹Fact (_ > 2 * 2^32)›.elim])) + +lemma value_eq {x0 x1 x2 x3: F p} (h0 : x0.val < 256) (h1 : x1.val < 256) (h2 : x2.val < 256) (h3 : x3.val < 256) : + (x0 + x1 * 256 + x2 * 256^2 + x3 * 256^3).val = x0.val + x1.val * 256 + x2.val * 256^2 + x3.val * 256^3 := by + field_to_nat_u32 end U32 end diff --git a/Clean/Utils/Field.lean b/Clean/Utils/Field.lean index 1e4c390..b122805 100644 --- a/Clean/Utils/Field.lean +++ b/Clean/Utils/Field.lean @@ -10,9 +10,13 @@ instance (p : ℕ) : CommRing (F p) := ZMod.commRing p namespace FieldUtils variable {p : ℕ} [p_prime: Fact p.Prime] +instance : NeZero p := ⟨p_prime.elim.ne_zero⟩ -theorem p_neq_zero : p ≠ 0 := - Nat.Prime.ne_zero p_prime.elim +theorem p_neq_zero : p ≠ 0 := p_prime.elim.ne_zero + +theorem ext {x y : F p} (h : x.val = y.val) : x = y := by + cases p; cases p_neq_zero rfl + exact Fin.ext h theorem sum_do_not_wrap_around (x y: F p) : x.val + y.val < p -> (x + y).val = x.val + y.val := by @@ -98,7 +102,7 @@ theorem val_lt_p (x: ℕ) : (x < p) -> (x : F p).val = x := by assumption -theorem boolean_le_2 (b : F p) (hb : b = 0 ∨ b = 1) : b.val < 2 := by +theorem boolean_lt_2 {b : F p} (hb : b = 0 ∨ b = 1) : b.val < 2 := by rcases hb with h0 | h1 · rw [h0]; simp · rw [h1]; simp [ZMod.val_one] @@ -123,9 +127,9 @@ theorem val_of_nat_to_field_eq {n: ℕ} {lt: n < p} : (nat_to_field n lt).val = · exact False.elim (Nat.not_lt_zero n lt) · rfl -def less_than_p [p_pos: Fact (p ≠ 0)] (x: F p) : x.val < p := by +def less_than_p [p_pos: NeZero p] (x: F p) : x.val < p := by rcases p - · have : 0 ≠ 0 := p_pos.elim; tauto + · have : 0 ≠ 0 := p_pos.out; contradiction · exact x.is_lt def mod (x: F p) (c: ℕ+) (lt: c < p) : F p := @@ -134,7 +138,32 @@ def mod (x: F p) (c: ℕ+) (lt: c < p) : F p := def mod_256 (x: F p) [p_large_enough: Fact (p > 512)] : F p := mod x 256 (by linarith [p_large_enough.elim]) -def floordiv [Fact (p ≠ 0)] (x: F p) (c: ℕ+) : F p := +def floordiv [NeZero p] (x: F p) (c: ℕ+) : F p := FieldUtils.nat_to_field (x.val / c) (by linarith [Nat.div_le_self x.val c, less_than_p x]) +theorem mod_256_lt [Fact (p > 512)] (x : F p) : (mod_256 x).val < 256 := by + rcases p with _ | n; cases p_neq_zero rfl + show (x.val % 256) < 256 + exact Nat.mod_lt x.val (by norm_num) + +theorem floordiv_bool [Fact (p > 512)] {x: F p} (h : x.val < 512) : + floordiv x 256 = 0 ∨ floordiv x 256 = 1 := by + rcases p with _ | n; cases p_neq_zero rfl + let z := x.val / 256 + have : z < 2 := Nat.div_lt_of_lt_mul h + -- show z = 0 ∨ z = 1 + rcases (Nat.lt_trichotomy z 1) with _ | h1 | _ + · left; apply ext; show z = 0; linarith + · right; apply ext; show z = ZMod.val 1; rw [h1, ZMod.val_one] + · linarith -- contradiction + +theorem mod_add_div_256 [Fact (p > 512)] (x : F p) : x = mod_256 x + 256 * (floordiv x 256) := by + rcases p with _ | n; cases p_neq_zero rfl + let p := n + 1 + apply ext + rw [ZMod.val_add, ZMod.val_mul] + have : ZMod.val 256 = 256 := val_lt_p (p:=p) 256 (by linarith [‹Fact (p > 512)›.elim]) + rw [this, Nat.add_mod_mod] + show x.val = (x.val % 256 + 256 * (x.val / 256)) % p + rw [Nat.mod_add_div, (Nat.mod_eq_of_lt x.is_lt : x.val % p = x.val)] end FieldUtils