From 3ba9bc2fde09b3372d4668df81f317e29f78ae81 Mon Sep 17 00:00:00 2001 From: Bartosz Nowak Date: Wed, 6 Nov 2024 15:06:19 -0800 Subject: [PATCH] sha formulas improvements --- crates/core/machine/src/operations/add3.rs | 134 ++++++++++++++++++ crates/core/machine/src/operations/mod.rs | 2 + .../precompiles/sha256/compress/air.rs | 68 ++++----- .../precompiles/sha256/compress/columns.rs | 19 +-- .../precompiles/sha256/compress/trace.rs | 18 +-- 5 files changed, 178 insertions(+), 63 deletions(-) create mode 100644 crates/core/machine/src/operations/add3.rs diff --git a/crates/core/machine/src/operations/add3.rs b/crates/core/machine/src/operations/add3.rs new file mode 100644 index 0000000000..616185ba83 --- /dev/null +++ b/crates/core/machine/src/operations/add3.rs @@ -0,0 +1,134 @@ +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use sp1_core_executor::events::ByteRecord; +use sp1_primitives::consts::WORD_SIZE; +use sp1_stark::{air::SP1AirBuilder, Word}; + +use crate::air::WordAirBuilder; + +/// A set of columns needed to compute the add of three words. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct Add3Operation { + /// The result of `a + b + c`. + pub value: Word, + + /// Indicates if the carry for the `i`th digit is 0. + pub is_carry_0: Word, + + /// Indicates if the carry for the `i`th digit is 1. + pub is_carry_1: Word, + + /// Indicates if the carry for the `i`th digit is 2. The carry when adding 3 words is at most + /// 2 + pub is_carry_2: Word, + + /// The carry for the `i`th digit. + pub carry: Word, +} + +impl Add3Operation { + #[allow(clippy::too_many_arguments)] + pub fn populate( + &mut self, + record: &mut impl ByteRecord, + shard: u32, + a_u32: u32, + b_u32: u32, + c_u32: u32, + ) -> u32 { + let expected = a_u32.wrapping_add(b_u32).wrapping_add(c_u32); + self.value = Word::from(expected); + let a = a_u32.to_le_bytes(); + let b = b_u32.to_le_bytes(); + let c = c_u32.to_le_bytes(); + + let base = 256; + let mut carry = [0u8, 0u8, 0u8, 0u8]; + for i in 0..WORD_SIZE { + let mut res = (a[i] as u32) + (b[i] as u32) + (c[i] as u32); + if i > 0 { + res += carry[i - 1] as u32; + } + carry[i] = (res / base) as u8; + self.is_carry_0[i] = F::from_bool(carry[i] == 0); + self.is_carry_1[i] = F::from_bool(carry[i] == 1); + self.is_carry_2[i] = F::from_bool(carry[i] == 2); + self.carry[i] = F::from_canonical_u8(carry[i]); + debug_assert!(carry[i] <= 2); + debug_assert_eq!(self.value[i], F::from_canonical_u32(res % base)); + } + + // Range check. + { + record.add_u8_range_checks(shard, &a); + record.add_u8_range_checks(shard, &b); + record.add_u8_range_checks(shard, &c); + record.add_u8_range_checks(shard, &expected.to_le_bytes()); + } + expected + } + + #[allow(clippy::too_many_arguments)] + pub fn eval( + builder: &mut AB, + a: Word, + b: Word, + c: Word, + cols: Add3Operation, + is_real: AB::Var, + ) { + // Range check each byte. + { + builder.slice_range_check_u8(&a.0, is_real); + builder.slice_range_check_u8(&b.0, is_real); + builder.slice_range_check_u8(&c.0, is_real); + builder.slice_range_check_u8(&cols.value.0, is_real); + } + + builder.assert_bool(is_real); + let mut builder_is_real = builder.when(is_real); + + // Each value in is_carry_{0,1,2} is 0 or 1, and exactly one of them is 1 per digit. + { + for i in 0..WORD_SIZE { + builder_is_real.assert_bool(cols.is_carry_0[i]); + builder_is_real.assert_bool(cols.is_carry_1[i]); + builder_is_real.assert_bool(cols.is_carry_2[i]); + builder_is_real.assert_eq( + cols.is_carry_0[i] + cols.is_carry_1[i] + cols.is_carry_2[i], + AB::Expr::one(), + ); + } + } + + // Calculates carry from is_carry_{0,1,2}. + { + let one = AB::Expr::one(); + let two = AB::F::from_canonical_u32(2); + + for i in 0..WORD_SIZE { + builder_is_real.assert_eq( + cols.carry[i], + cols.is_carry_1[i] * one.clone() + cols.is_carry_2[i] * two, + ); + } + } + + // Compare the sum and summands by looking at carry. + { + let base = AB::F::from_canonical_u32(256); + // For each limb, assert that difference between the carried result and the non-carried + // result is the product of carry and base. + for i in 0..WORD_SIZE { + let mut overflow = a[i] + b[i] + c[i] - cols.value[i]; + if i > 0 { + overflow = overflow.clone() + cols.carry[i - 1].into(); + } + builder_is_real.assert_eq(cols.carry[i] * base, overflow.clone()); + } + } + } +} diff --git a/crates/core/machine/src/operations/mod.rs b/crates/core/machine/src/operations/mod.rs index 394daf906b..d430b4738f 100644 --- a/crates/core/machine/src/operations/mod.rs +++ b/crates/core/machine/src/operations/mod.rs @@ -5,6 +5,7 @@ //! the constraints. mod add; +mod add3; mod add4; mod add5; mod and; @@ -22,6 +23,7 @@ mod or; mod xor; pub use add::*; +pub use add3::*; pub use add4::*; pub use add5::*; pub use and::*; diff --git a/crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs b/crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs index 2ecb8deb37..8cd10ff97a 100644 --- a/crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs +++ b/crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs @@ -17,7 +17,7 @@ use crate::{ air::{MemoryAirBuilder, WordAirBuilder}, memory::MemoryCols, operations::{ - Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation, + Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, XorOperation, }, }; @@ -329,24 +329,22 @@ impl ShaCompressChip { local.is_compression, ); - // Calculate ch := (e and f) xor ((not e) and g). - // Calculate e and f. - AndOperation::::eval(builder, local.e, local.f, local.e_and_f, local.is_compression); - // Calculate not e. - NotOperation::::eval(builder, local.e, local.e_not, local.is_compression); - // Calculate (not e) and g. + // Calculate ch := (e and f) xor ((not e) and g) = g xor (e and (f xor g)). + // Calculate f xor g. + XorOperation::::eval(builder, local.f, local.g, local.f_xor_g, local.is_compression); + // Calculate e and (f xor g). AndOperation::::eval( builder, - local.e_not.value, - local.g, - local.e_not_and_g, + local.e, + local.f_xor_g.value, + local.e_and_f_xor_g, local.is_compression, ); - // Calculate ch := (e and f) xor ((not e) and g). + // Calculate ch := g xor (e and (f xor g)). XorOperation::::eval( builder, - local.e_and_f.value, - local.e_not_and_g.value, + local.g, + local.e_and_f_xor_g.value, local.ch, local.is_compression, ); @@ -401,39 +399,28 @@ impl ShaCompressChip { local.is_compression, ); - // Calculate maj := (a and b) xor (a and c) xor (b and c). - // Calculate a and b. - AndOperation::::eval(builder, local.a, local.b, local.a_and_b, local.is_compression); - // Calculate a and c. - AndOperation::::eval(builder, local.a, local.c, local.a_and_c, local.is_compression); - // Calculate b and c. - AndOperation::::eval(builder, local.b, local.c, local.b_and_c, local.is_compression); - // Calculate (a and b) xor (a and c). - XorOperation::::eval( + // Calculate maj := (a and b) xor (a and c) xor (b and c) = (a and (b xor c)) xor (b and c). + // Calculate b xor c. + XorOperation::::eval(builder, local.b, local.c, local.b_xor_c, local.is_compression); + // Calculate a and (b xor c). + AndOperation::::eval( builder, - local.a_and_b.value, - local.a_and_c.value, - local.maj_intermediate, + local.a, + local.b_xor_c.value, + local.a_and_b_xor_c, local.is_compression, ); - // Calculate maj := ((a and b) xor (a and c)) xor (b and c). + // Calculate b and c. + AndOperation::::eval(builder, local.b, local.c, local.b_and_c, local.is_compression); + // Calculate maj := (a and (b xor c)) xor (b and c). XorOperation::::eval( builder, - local.maj_intermediate.value, + local.a_and_b_xor_c.value, local.b_and_c.value, local.maj, local.is_compression, ); - // Calculate temp2 := s0 + maj. - AddOperation::::eval( - builder, - local.s0.value, - local.maj.value, - local.temp2, - local.is_compression.into(), - ); - // Calculate d + temp1 for the new value of e. AddOperation::::eval( builder, @@ -443,13 +430,14 @@ impl ShaCompressChip { local.is_compression.into(), ); - // Calculate temp1 + temp2 for the new value of a. - AddOperation::::eval( + // Calculate temp1 + S0 + maj for the new value of a. + Add3Operation::::eval( builder, local.temp1.value, - local.temp2.value, + local.s0.value, + local.maj.value, local.temp1_add_temp2, - local.is_compression.into(), + local.is_compression, ); // h := g diff --git a/crates/core/machine/src/syscall/precompiles/sha256/compress/columns.rs b/crates/core/machine/src/syscall/precompiles/sha256/compress/columns.rs index 5d48b9edcc..e2f942d7d3 100644 --- a/crates/core/machine/src/syscall/precompiles/sha256/compress/columns.rs +++ b/crates/core/machine/src/syscall/precompiles/sha256/compress/columns.rs @@ -6,7 +6,7 @@ use sp1_stark::Word; use crate::{ memory::MemoryReadWriteCols, operations::{ - Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation, + Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, XorOperation, }, }; @@ -67,9 +67,8 @@ pub struct ShaCompressCols { /// `S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)`. pub s1: XorOperation, - pub e_and_f: AndOperation, - pub e_not: NotOperation, - pub e_not_and_g: AndOperation, + pub f_xor_g: XorOperation, + pub e_and_f_xor_g: AndOperation, /// `ch := (e and f) xor ((not e) and g)`. pub ch: XorOperation, @@ -83,20 +82,16 @@ pub struct ShaCompressCols { /// `S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)`. pub s0: XorOperation, - pub a_and_b: AndOperation, - pub a_and_c: AndOperation, + pub b_xor_c: XorOperation, + pub a_and_b_xor_c: AndOperation, pub b_and_c: AndOperation, - pub maj_intermediate: XorOperation, /// `maj := (a and b) xor (a and c) xor (b and c)`. pub maj: XorOperation, - /// `temp2 := S0 + maj`. - pub temp2: AddOperation, - /// The next value of `e` is `d + temp1`. pub d_add_temp1: AddOperation, - /// The next value of `a` is `temp1 + temp2`. - pub temp1_add_temp2: AddOperation, + /// The next value of `a` is `temp1 + S0 + maj`. + pub temp1_add_temp2: Add3Operation, /// During finalize, this is one of a-h and is being written into `mem`. pub finalized_operand: Word, diff --git a/crates/core/machine/src/syscall/precompiles/sha256/compress/trace.rs b/crates/core/machine/src/syscall/precompiles/sha256/compress/trace.rs index d6b61b67f2..696be20955 100644 --- a/crates/core/machine/src/syscall/precompiles/sha256/compress/trace.rs +++ b/crates/core/machine/src/syscall/precompiles/sha256/compress/trace.rs @@ -215,10 +215,9 @@ impl ShaCompressChip { let s1_intermediate = cols.s1_intermediate.populate(blu, shard, e_rr_6, e_rr_11); let s1 = cols.s1.populate(blu, shard, s1_intermediate, e_rr_25); - let e_and_f = cols.e_and_f.populate(blu, shard, e, f); - let e_not = cols.e_not.populate(blu, shard, e); - let e_not_and_g = cols.e_not_and_g.populate(blu, shard, e_not, g); - let ch = cols.ch.populate(blu, shard, e_and_f, e_not_and_g); + let f_xor_g = cols.f_xor_g.populate(blu, shard, f, g); + let e_and_f_xor_g = cols.e_and_f_xor_g.populate(blu, shard, a, f_xor_g); + let ch = cols.ch.populate(blu, shard, g, e_and_f_xor_g); let temp1 = cols.temp1.populate(blu, shard, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]); @@ -228,16 +227,13 @@ impl ShaCompressChip { let s0_intermediate = cols.s0_intermediate.populate(blu, shard, a_rr_2, a_rr_13); let s0 = cols.s0.populate(blu, shard, s0_intermediate, a_rr_22); - let a_and_b = cols.a_and_b.populate(blu, shard, a, b); - let a_and_c = cols.a_and_c.populate(blu, shard, a, c); + let b_xor_c = cols.b_xor_c.populate(blu, shard, b, c); + let a_and_b_xor_c = cols.a_and_b_xor_c.populate(blu, shard, a, b_xor_c); let b_and_c = cols.b_and_c.populate(blu, shard, b, c); - let maj_intermediate = cols.maj_intermediate.populate(blu, shard, a_and_b, a_and_c); - let maj = cols.maj.populate(blu, shard, maj_intermediate, b_and_c); - - let temp2 = cols.temp2.populate(blu, shard, s0, maj); + let maj = cols.maj.populate(blu, shard, a_and_b_xor_c, b_and_c); let d_add_temp1 = cols.d_add_temp1.populate(blu, shard, d, temp1); - let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, temp2); + let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, s0, maj); h_array[7] = g; h_array[6] = f;