Skip to content

Commit 3ba9bc2

Browse files
committed
sha formulas improvements
1 parent c804db2 commit 3ba9bc2

File tree

5 files changed

+178
-63
lines changed

5 files changed

+178
-63
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use p3_air::AirBuilder;
2+
use p3_field::{AbstractField, Field};
3+
use sp1_derive::AlignedBorrow;
4+
5+
use sp1_core_executor::events::ByteRecord;
6+
use sp1_primitives::consts::WORD_SIZE;
7+
use sp1_stark::{air::SP1AirBuilder, Word};
8+
9+
use crate::air::WordAirBuilder;
10+
11+
/// A set of columns needed to compute the add of three words.
12+
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
13+
#[repr(C)]
14+
pub struct Add3Operation<T> {
15+
/// The result of `a + b + c`.
16+
pub value: Word<T>,
17+
18+
/// Indicates if the carry for the `i`th digit is 0.
19+
pub is_carry_0: Word<T>,
20+
21+
/// Indicates if the carry for the `i`th digit is 1.
22+
pub is_carry_1: Word<T>,
23+
24+
/// Indicates if the carry for the `i`th digit is 2. The carry when adding 3 words is at most
25+
/// 2
26+
pub is_carry_2: Word<T>,
27+
28+
/// The carry for the `i`th digit.
29+
pub carry: Word<T>,
30+
}
31+
32+
impl<F: Field> Add3Operation<F> {
33+
#[allow(clippy::too_many_arguments)]
34+
pub fn populate(
35+
&mut self,
36+
record: &mut impl ByteRecord,
37+
shard: u32,
38+
a_u32: u32,
39+
b_u32: u32,
40+
c_u32: u32,
41+
) -> u32 {
42+
let expected = a_u32.wrapping_add(b_u32).wrapping_add(c_u32);
43+
self.value = Word::from(expected);
44+
let a = a_u32.to_le_bytes();
45+
let b = b_u32.to_le_bytes();
46+
let c = c_u32.to_le_bytes();
47+
48+
let base = 256;
49+
let mut carry = [0u8, 0u8, 0u8, 0u8];
50+
for i in 0..WORD_SIZE {
51+
let mut res = (a[i] as u32) + (b[i] as u32) + (c[i] as u32);
52+
if i > 0 {
53+
res += carry[i - 1] as u32;
54+
}
55+
carry[i] = (res / base) as u8;
56+
self.is_carry_0[i] = F::from_bool(carry[i] == 0);
57+
self.is_carry_1[i] = F::from_bool(carry[i] == 1);
58+
self.is_carry_2[i] = F::from_bool(carry[i] == 2);
59+
self.carry[i] = F::from_canonical_u8(carry[i]);
60+
debug_assert!(carry[i] <= 2);
61+
debug_assert_eq!(self.value[i], F::from_canonical_u32(res % base));
62+
}
63+
64+
// Range check.
65+
{
66+
record.add_u8_range_checks(shard, &a);
67+
record.add_u8_range_checks(shard, &b);
68+
record.add_u8_range_checks(shard, &c);
69+
record.add_u8_range_checks(shard, &expected.to_le_bytes());
70+
}
71+
expected
72+
}
73+
74+
#[allow(clippy::too_many_arguments)]
75+
pub fn eval<AB: SP1AirBuilder>(
76+
builder: &mut AB,
77+
a: Word<AB::Var>,
78+
b: Word<AB::Var>,
79+
c: Word<AB::Var>,
80+
cols: Add3Operation<AB::Var>,
81+
is_real: AB::Var,
82+
) {
83+
// Range check each byte.
84+
{
85+
builder.slice_range_check_u8(&a.0, is_real);
86+
builder.slice_range_check_u8(&b.0, is_real);
87+
builder.slice_range_check_u8(&c.0, is_real);
88+
builder.slice_range_check_u8(&cols.value.0, is_real);
89+
}
90+
91+
builder.assert_bool(is_real);
92+
let mut builder_is_real = builder.when(is_real);
93+
94+
// Each value in is_carry_{0,1,2} is 0 or 1, and exactly one of them is 1 per digit.
95+
{
96+
for i in 0..WORD_SIZE {
97+
builder_is_real.assert_bool(cols.is_carry_0[i]);
98+
builder_is_real.assert_bool(cols.is_carry_1[i]);
99+
builder_is_real.assert_bool(cols.is_carry_2[i]);
100+
builder_is_real.assert_eq(
101+
cols.is_carry_0[i] + cols.is_carry_1[i] + cols.is_carry_2[i],
102+
AB::Expr::one(),
103+
);
104+
}
105+
}
106+
107+
// Calculates carry from is_carry_{0,1,2}.
108+
{
109+
let one = AB::Expr::one();
110+
let two = AB::F::from_canonical_u32(2);
111+
112+
for i in 0..WORD_SIZE {
113+
builder_is_real.assert_eq(
114+
cols.carry[i],
115+
cols.is_carry_1[i] * one.clone() + cols.is_carry_2[i] * two,
116+
);
117+
}
118+
}
119+
120+
// Compare the sum and summands by looking at carry.
121+
{
122+
let base = AB::F::from_canonical_u32(256);
123+
// For each limb, assert that difference between the carried result and the non-carried
124+
// result is the product of carry and base.
125+
for i in 0..WORD_SIZE {
126+
let mut overflow = a[i] + b[i] + c[i] - cols.value[i];
127+
if i > 0 {
128+
overflow = overflow.clone() + cols.carry[i - 1].into();
129+
}
130+
builder_is_real.assert_eq(cols.carry[i] * base, overflow.clone());
131+
}
132+
}
133+
}
134+
}

crates/core/machine/src/operations/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
//! the constraints.
66
77
mod add;
8+
mod add3;
89
mod add4;
910
mod add5;
1011
mod and;
@@ -22,6 +23,7 @@ mod or;
2223
mod xor;
2324

2425
pub use add::*;
26+
pub use add3::*;
2527
pub use add4::*;
2628
pub use add5::*;
2729
pub use and::*;

crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use crate::{
1717
air::{MemoryAirBuilder, WordAirBuilder},
1818
memory::MemoryCols,
1919
operations::{
20-
Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation,
20+
Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation,
2121
XorOperation,
2222
},
2323
};
@@ -329,24 +329,22 @@ impl ShaCompressChip {
329329
local.is_compression,
330330
);
331331

332-
// Calculate ch := (e and f) xor ((not e) and g).
333-
// Calculate e and f.
334-
AndOperation::<AB::F>::eval(builder, local.e, local.f, local.e_and_f, local.is_compression);
335-
// Calculate not e.
336-
NotOperation::<AB::F>::eval(builder, local.e, local.e_not, local.is_compression);
337-
// Calculate (not e) and g.
332+
// Calculate ch := (e and f) xor ((not e) and g) = g xor (e and (f xor g)).
333+
// Calculate f xor g.
334+
XorOperation::<AB::F>::eval(builder, local.f, local.g, local.f_xor_g, local.is_compression);
335+
// Calculate e and (f xor g).
338336
AndOperation::<AB::F>::eval(
339337
builder,
340-
local.e_not.value,
341-
local.g,
342-
local.e_not_and_g,
338+
local.e,
339+
local.f_xor_g.value,
340+
local.e_and_f_xor_g,
343341
local.is_compression,
344342
);
345-
// Calculate ch := (e and f) xor ((not e) and g).
343+
// Calculate ch := g xor (e and (f xor g)).
346344
XorOperation::<AB::F>::eval(
347345
builder,
348-
local.e_and_f.value,
349-
local.e_not_and_g.value,
346+
local.g,
347+
local.e_and_f_xor_g.value,
350348
local.ch,
351349
local.is_compression,
352350
);
@@ -401,39 +399,28 @@ impl ShaCompressChip {
401399
local.is_compression,
402400
);
403401

404-
// Calculate maj := (a and b) xor (a and c) xor (b and c).
405-
// Calculate a and b.
406-
AndOperation::<AB::F>::eval(builder, local.a, local.b, local.a_and_b, local.is_compression);
407-
// Calculate a and c.
408-
AndOperation::<AB::F>::eval(builder, local.a, local.c, local.a_and_c, local.is_compression);
409-
// Calculate b and c.
410-
AndOperation::<AB::F>::eval(builder, local.b, local.c, local.b_and_c, local.is_compression);
411-
// Calculate (a and b) xor (a and c).
412-
XorOperation::<AB::F>::eval(
402+
// Calculate maj := (a and b) xor (a and c) xor (b and c) = (a and (b xor c)) xor (b and c).
403+
// Calculate b xor c.
404+
XorOperation::<AB::F>::eval(builder, local.b, local.c, local.b_xor_c, local.is_compression);
405+
// Calculate a and (b xor c).
406+
AndOperation::<AB::F>::eval(
413407
builder,
414-
local.a_and_b.value,
415-
local.a_and_c.value,
416-
local.maj_intermediate,
408+
local.a,
409+
local.b_xor_c.value,
410+
local.a_and_b_xor_c,
417411
local.is_compression,
418412
);
419-
// Calculate maj := ((a and b) xor (a and c)) xor (b and c).
413+
// Calculate b and c.
414+
AndOperation::<AB::F>::eval(builder, local.b, local.c, local.b_and_c, local.is_compression);
415+
// Calculate maj := (a and (b xor c)) xor (b and c).
420416
XorOperation::<AB::F>::eval(
421417
builder,
422-
local.maj_intermediate.value,
418+
local.a_and_b_xor_c.value,
423419
local.b_and_c.value,
424420
local.maj,
425421
local.is_compression,
426422
);
427423

428-
// Calculate temp2 := s0 + maj.
429-
AddOperation::<AB::F>::eval(
430-
builder,
431-
local.s0.value,
432-
local.maj.value,
433-
local.temp2,
434-
local.is_compression.into(),
435-
);
436-
437424
// Calculate d + temp1 for the new value of e.
438425
AddOperation::<AB::F>::eval(
439426
builder,
@@ -443,13 +430,14 @@ impl ShaCompressChip {
443430
local.is_compression.into(),
444431
);
445432

446-
// Calculate temp1 + temp2 for the new value of a.
447-
AddOperation::<AB::F>::eval(
433+
// Calculate temp1 + S0 + maj for the new value of a.
434+
Add3Operation::<AB::F>::eval(
448435
builder,
449436
local.temp1.value,
450-
local.temp2.value,
437+
local.s0.value,
438+
local.maj.value,
451439
local.temp1_add_temp2,
452-
local.is_compression.into(),
440+
local.is_compression,
453441
);
454442

455443
// h := g

crates/core/machine/src/syscall/precompiles/sha256/compress/columns.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use sp1_stark::Word;
66
use crate::{
77
memory::MemoryReadWriteCols,
88
operations::{
9-
Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation,
9+
Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation,
1010
XorOperation,
1111
},
1212
};
@@ -67,9 +67,8 @@ pub struct ShaCompressCols<T> {
6767
/// `S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)`.
6868
pub s1: XorOperation<T>,
6969

70-
pub e_and_f: AndOperation<T>,
71-
pub e_not: NotOperation<T>,
72-
pub e_not_and_g: AndOperation<T>,
70+
pub f_xor_g: XorOperation<T>,
71+
pub e_and_f_xor_g: AndOperation<T>,
7372
/// `ch := (e and f) xor ((not e) and g)`.
7473
pub ch: XorOperation<T>,
7574

@@ -83,20 +82,16 @@ pub struct ShaCompressCols<T> {
8382
/// `S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)`.
8483
pub s0: XorOperation<T>,
8584

86-
pub a_and_b: AndOperation<T>,
87-
pub a_and_c: AndOperation<T>,
85+
pub b_xor_c: XorOperation<T>,
86+
pub a_and_b_xor_c: AndOperation<T>,
8887
pub b_and_c: AndOperation<T>,
89-
pub maj_intermediate: XorOperation<T>,
9088
/// `maj := (a and b) xor (a and c) xor (b and c)`.
9189
pub maj: XorOperation<T>,
9290

93-
/// `temp2 := S0 + maj`.
94-
pub temp2: AddOperation<T>,
95-
9691
/// The next value of `e` is `d + temp1`.
9792
pub d_add_temp1: AddOperation<T>,
98-
/// The next value of `a` is `temp1 + temp2`.
99-
pub temp1_add_temp2: AddOperation<T>,
93+
/// The next value of `a` is `temp1 + S0 + maj`.
94+
pub temp1_add_temp2: Add3Operation<T>,
10095

10196
/// During finalize, this is one of a-h and is being written into `mem`.
10297
pub finalized_operand: Word<T>,

crates/core/machine/src/syscall/precompiles/sha256/compress/trace.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,9 @@ impl ShaCompressChip {
215215
let s1_intermediate = cols.s1_intermediate.populate(blu, shard, e_rr_6, e_rr_11);
216216
let s1 = cols.s1.populate(blu, shard, s1_intermediate, e_rr_25);
217217

218-
let e_and_f = cols.e_and_f.populate(blu, shard, e, f);
219-
let e_not = cols.e_not.populate(blu, shard, e);
220-
let e_not_and_g = cols.e_not_and_g.populate(blu, shard, e_not, g);
221-
let ch = cols.ch.populate(blu, shard, e_and_f, e_not_and_g);
218+
let f_xor_g = cols.f_xor_g.populate(blu, shard, f, g);
219+
// ch = g xor (e and (f xor g)): the AND operand is `e`, matching the AIR constraint on `local.e`.
let e_and_f_xor_g = cols.e_and_f_xor_g.populate(blu, shard, e, f_xor_g);
220+
let ch = cols.ch.populate(blu, shard, g, e_and_f_xor_g);
222221

223222
let temp1 = cols.temp1.populate(blu, shard, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]);
224223

@@ -228,16 +227,13 @@ impl ShaCompressChip {
228227
let s0_intermediate = cols.s0_intermediate.populate(blu, shard, a_rr_2, a_rr_13);
229228
let s0 = cols.s0.populate(blu, shard, s0_intermediate, a_rr_22);
230229

231-
let a_and_b = cols.a_and_b.populate(blu, shard, a, b);
232-
let a_and_c = cols.a_and_c.populate(blu, shard, a, c);
230+
let b_xor_c = cols.b_xor_c.populate(blu, shard, b, c);
231+
let a_and_b_xor_c = cols.a_and_b_xor_c.populate(blu, shard, a, b_xor_c);
233232
let b_and_c = cols.b_and_c.populate(blu, shard, b, c);
234-
let maj_intermediate = cols.maj_intermediate.populate(blu, shard, a_and_b, a_and_c);
235-
let maj = cols.maj.populate(blu, shard, maj_intermediate, b_and_c);
236-
237-
let temp2 = cols.temp2.populate(blu, shard, s0, maj);
233+
let maj = cols.maj.populate(blu, shard, a_and_b_xor_c, b_and_c);
238234

239235
let d_add_temp1 = cols.d_add_temp1.populate(blu, shard, d, temp1);
240-
let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, temp2);
236+
let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, s0, maj);
241237

242238
h_array[7] = g;
243239
h_array[6] = f;

0 commit comments

Comments
 (0)