Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sha2 air formulas optimizations #7

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions crates/core/machine/src/operations/add3.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use p3_air::AirBuilder;
use p3_field::{AbstractField, Field};
use sp1_derive::AlignedBorrow;

use sp1_core_executor::events::ByteRecord;
use sp1_primitives::consts::WORD_SIZE;
use sp1_stark::{air::SP1AirBuilder, Word};

use crate::air::WordAirBuilder;

/// Witness columns for the byte-wise addition of three words, `a + b + c`.
///
/// The sum is computed limb by limb over the `WORD_SIZE` little-endian bytes of each operand.
/// For every byte position `i` the columns record the carry out of that position, both as a
/// numeric value and as a one-hot selector over the possible carries `{0, 1, 2}`.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct Add3Operation<T> {
/// The result of `a + b + c`, reduced modulo 2^32 (`populate` uses wrapping addition).
pub value: Word<T>,

/// `is_carry_0[i]` is 1 when the carry out of the `i`th byte is 0, and 0 otherwise.
pub is_carry_0: Word<T>,

/// `is_carry_1[i]` is 1 when the carry out of the `i`th byte is 1, and 0 otherwise.
pub is_carry_1: Word<T>,

/// `is_carry_2[i]` is 1 when the carry out of the `i`th byte is 2, and 0 otherwise. The carry
/// when adding 3 words is at most 2, since `3 * 255 + 2 = 767 < 3 * 256`.
pub is_carry_2: Word<T>,

/// The carry out of the `i`th byte, as a field element in `{0, 1, 2}`. It is fed into the
/// sum of byte `i + 1`; the carry out of the last byte is discarded.
pub carry: Word<T>,
}

impl<F: Field> Add3Operation<F> {
    /// Fills in the witness columns for `a_u32 + b_u32 + c_u32` and returns the wrapping sum.
    ///
    /// The addition is replayed byte by byte (little-endian) so that the carry out of every
    /// limb can be recorded, both numerically and as a one-hot flag. The bytes of the three
    /// operands and of the result are registered with `record` as u8 range checks for `shard`.
    #[allow(clippy::too_many_arguments)]
    pub fn populate(
        &mut self,
        record: &mut impl ByteRecord,
        shard: u32,
        a_u32: u32,
        b_u32: u32,
        c_u32: u32,
    ) -> u32 {
        // Each limb is one byte wide.
        const BASE: u32 = 256;

        let sum = a_u32.wrapping_add(b_u32).wrapping_add(c_u32);
        self.value = Word::from(sum);

        let a_bytes = a_u32.to_le_bytes();
        let b_bytes = b_u32.to_le_bytes();
        let c_bytes = c_u32.to_le_bytes();

        // Carry propagated from the previous (less significant) limb; zero for limb 0.
        let mut carry_in = 0u32;
        for i in 0..WORD_SIZE {
            let limb_total =
                u32::from(a_bytes[i]) + u32::from(b_bytes[i]) + u32::from(c_bytes[i]) + carry_in;
            let carry_out = (limb_total / BASE) as u8;

            self.is_carry_0[i] = F::from_bool(carry_out == 0);
            self.is_carry_1[i] = F::from_bool(carry_out == 1);
            self.is_carry_2[i] = F::from_bool(carry_out == 2);
            self.carry[i] = F::from_canonical_u8(carry_out);

            // Three bytes plus an incoming carry of at most 2 can never carry more than 2.
            debug_assert!(carry_out <= 2);
            debug_assert_eq!(self.value[i], F::from_canonical_u32(limb_total % BASE));

            carry_in = u32::from(carry_out);
        }

        // Range check the operands first, then the result.
        let sum_bytes = sum.to_le_bytes();
        for bytes in [&a_bytes, &b_bytes, &c_bytes, &sum_bytes] {
            record.add_u8_range_checks(shard, bytes);
        }

        sum
    }

    /// Constrains `cols.value` to be the sum of `a`, `b` and `c` modulo 2^32.
    ///
    /// `is_real` is asserted to be boolean. The byte range checks are issued with `is_real` as
    /// their multiplicity, and every other constraint is only enforced when `is_real` is
    /// non-zero.
    #[allow(clippy::too_many_arguments)]
    pub fn eval<AB: SP1AirBuilder>(
        builder: &mut AB,
        a: Word<AB::Var>,
        b: Word<AB::Var>,
        c: Word<AB::Var>,
        cols: Add3Operation<AB::Var>,
        is_real: AB::Var,
    ) {
        // Every limb of the operands and of the result must be a byte.
        builder.slice_range_check_u8(&a.0, is_real);
        builder.slice_range_check_u8(&b.0, is_real);
        builder.slice_range_check_u8(&c.0, is_real);
        builder.slice_range_check_u8(&cols.value.0, is_real);

        builder.assert_bool(is_real);
        let mut when_real = builder.when(is_real);

        // The carry selectors are boolean and one-hot: exactly one of them is set per limb.
        for i in 0..WORD_SIZE {
            let (flag_0, flag_1, flag_2) =
                (cols.is_carry_0[i], cols.is_carry_1[i], cols.is_carry_2[i]);
            when_real.assert_bool(flag_0);
            when_real.assert_bool(flag_1);
            when_real.assert_bool(flag_2);
            when_real.assert_eq(flag_0 + flag_1 + flag_2, AB::Expr::one());
        }

        // The numeric carry is the value selected by the flags: 0, 1 or 2.
        let two = AB::F::from_canonical_u32(2);
        for i in 0..WORD_SIZE {
            when_real.assert_eq(cols.carry[i], cols.is_carry_1[i] + cols.is_carry_2[i] * two);
        }

        // Per limb: (a + b + c + carry_in) - value must equal carry_out * 256.
        let base = AB::F::from_canonical_u32(256);
        for i in 0..WORD_SIZE {
            let limb_diff: AB::Expr = a[i] + b[i] + c[i] - cols.value[i];
            let overflow = if i == 0 {
                // The least significant limb has no incoming carry.
                limb_diff
            } else {
                limb_diff + cols.carry[i - 1].into()
            };
            when_real.assert_eq(cols.carry[i] * base, overflow);
        }
    }
}
2 changes: 2 additions & 0 deletions crates/core/machine/src/operations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! the constraints.

mod add;
mod add3;
mod add4;
mod add5;
mod and;
Expand All @@ -22,6 +23,7 @@ mod or;
mod xor;

pub use add::*;
pub use add3::*;
pub use add4::*;
pub use add5::*;
pub use and::*;
Expand Down
68 changes: 28 additions & 40 deletions crates/core/machine/src/syscall/precompiles/sha256/compress/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
air::{MemoryAirBuilder, WordAirBuilder},
memory::MemoryCols,
operations::{
Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation,
Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation,
XorOperation,
},
};
Expand Down Expand Up @@ -329,24 +329,22 @@ impl ShaCompressChip {
local.is_compression,
);

// Calculate ch := (e and f) xor ((not e) and g).
// Calculate e and f.
AndOperation::<AB::F>::eval(builder, local.e, local.f, local.e_and_f, local.is_compression);
// Calculate not e.
NotOperation::<AB::F>::eval(builder, local.e, local.e_not, local.is_compression);
// Calculate (not e) and g.
// Calculate ch := (e and f) xor ((not e) and g) = g xor (e and (f xor g)).
// Calculate f xor g.
XorOperation::<AB::F>::eval(builder, local.f, local.g, local.f_xor_g, local.is_compression);
// Calculate e and (f xor g).
AndOperation::<AB::F>::eval(
builder,
local.e_not.value,
local.g,
local.e_not_and_g,
local.e,
local.f_xor_g.value,
local.e_and_f_xor_g,
local.is_compression,
);
// Calculate ch := (e and f) xor ((not e) and g).
// Calculate ch := g xor (e and (f xor g)).
XorOperation::<AB::F>::eval(
builder,
local.e_and_f.value,
local.e_not_and_g.value,
local.g,
local.e_and_f_xor_g.value,
local.ch,
local.is_compression,
);
Expand Down Expand Up @@ -401,39 +399,28 @@ impl ShaCompressChip {
local.is_compression,
);

// Calculate maj := (a and b) xor (a and c) xor (b and c).
// Calculate a and b.
AndOperation::<AB::F>::eval(builder, local.a, local.b, local.a_and_b, local.is_compression);
// Calculate a and c.
AndOperation::<AB::F>::eval(builder, local.a, local.c, local.a_and_c, local.is_compression);
// Calculate b and c.
AndOperation::<AB::F>::eval(builder, local.b, local.c, local.b_and_c, local.is_compression);
// Calculate (a and b) xor (a and c).
XorOperation::<AB::F>::eval(
// Calculate maj := (a and b) xor (a and c) xor (b and c) = (a and (b xor c)) xor (b and c).
// Calculate b xor c.
XorOperation::<AB::F>::eval(builder, local.b, local.c, local.b_xor_c, local.is_compression);
// Calculate a and (b xor c).
AndOperation::<AB::F>::eval(
builder,
local.a_and_b.value,
local.a_and_c.value,
local.maj_intermediate,
local.a,
local.b_xor_c.value,
local.a_and_b_xor_c,
local.is_compression,
);
// Calculate maj := ((a and b) xor (a and c)) xor (b and c).
// Calculate b and c.
AndOperation::<AB::F>::eval(builder, local.b, local.c, local.b_and_c, local.is_compression);
// Calculate maj := (a and (b xor c)) xor (b and c).
XorOperation::<AB::F>::eval(
builder,
local.maj_intermediate.value,
local.a_and_b_xor_c.value,
local.b_and_c.value,
local.maj,
local.is_compression,
);

// Calculate temp2 := s0 + maj.
AddOperation::<AB::F>::eval(
builder,
local.s0.value,
local.maj.value,
local.temp2,
local.is_compression.into(),
);

// Calculate d + temp1 for the new value of e.
AddOperation::<AB::F>::eval(
builder,
Expand All @@ -443,13 +430,14 @@ impl ShaCompressChip {
local.is_compression.into(),
);

// Calculate temp1 + temp2 for the new value of a.
AddOperation::<AB::F>::eval(
// Calculate temp1 + S0 + maj for the new value of a.
Add3Operation::<AB::F>::eval(
builder,
local.temp1.value,
local.temp2.value,
local.s0.value,
local.maj.value,
local.temp1_add_temp2,
local.is_compression.into(),
local.is_compression,
);

// h := g
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use sp1_stark::Word;
use crate::{
memory::MemoryReadWriteCols,
operations::{
Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation, NotOperation,
Add3Operation, Add5Operation, AddOperation, AndOperation, FixedRotateRightOperation,
XorOperation,
},
};
Expand Down Expand Up @@ -67,9 +67,8 @@ pub struct ShaCompressCols<T> {
/// `S1 := (e rightrotate 6) xor (e rightrotate 11) xor (e rightrotate 25)`.
pub s1: XorOperation<T>,

pub e_and_f: AndOperation<T>,
pub e_not: NotOperation<T>,
pub e_not_and_g: AndOperation<T>,
pub f_xor_g: XorOperation<T>,
pub e_and_f_xor_g: AndOperation<T>,
/// `ch := (e and f) xor ((not e) and g)`.
pub ch: XorOperation<T>,

Expand All @@ -83,20 +82,16 @@ pub struct ShaCompressCols<T> {
/// `S0 := (a rightrotate 2) xor (a rightrotate 13) xor (a rightrotate 22)`.
pub s0: XorOperation<T>,

pub a_and_b: AndOperation<T>,
pub a_and_c: AndOperation<T>,
pub b_xor_c: XorOperation<T>,
pub a_and_b_xor_c: AndOperation<T>,
pub b_and_c: AndOperation<T>,
pub maj_intermediate: XorOperation<T>,
/// `maj := (a and b) xor (a and c) xor (b and c)`.
pub maj: XorOperation<T>,

/// `temp2 := S0 + maj`.
pub temp2: AddOperation<T>,

/// The next value of `e` is `d + temp1`.
pub d_add_temp1: AddOperation<T>,
/// The next value of `a` is `temp1 + temp2`.
pub temp1_add_temp2: AddOperation<T>,
/// The next value of `a` is `temp1 + S0 + maj`.
pub temp1_add_temp2: Add3Operation<T>,

/// During finalize, this is one of a-h and is being written into `mem`.
pub finalized_operand: Word<T>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,9 @@ impl ShaCompressChip {
let s1_intermediate = cols.s1_intermediate.populate(blu, shard, e_rr_6, e_rr_11);
let s1 = cols.s1.populate(blu, shard, s1_intermediate, e_rr_25);

let e_and_f = cols.e_and_f.populate(blu, shard, e, f);
let e_not = cols.e_not.populate(blu, shard, e);
let e_not_and_g = cols.e_not_and_g.populate(blu, shard, e_not, g);
let ch = cols.ch.populate(blu, shard, e_and_f, e_not_and_g);
// ch := g xor (e and (f xor g)). The AND must take `e` (not `a`) as its first operand so the
// populated columns match the AIR, which constrains `e_and_f_xor_g` as `e AND f_xor_g`.
let f_xor_g = cols.f_xor_g.populate(blu, shard, f, g);
let e_and_f_xor_g = cols.e_and_f_xor_g.populate(blu, shard, e, f_xor_g);
let ch = cols.ch.populate(blu, shard, g, e_and_f_xor_g);

let temp1 = cols.temp1.populate(blu, shard, h, s1, ch, event.w[j], SHA_COMPRESS_K[j]);

Expand All @@ -228,16 +227,13 @@ impl ShaCompressChip {
let s0_intermediate = cols.s0_intermediate.populate(blu, shard, a_rr_2, a_rr_13);
let s0 = cols.s0.populate(blu, shard, s0_intermediate, a_rr_22);

let a_and_b = cols.a_and_b.populate(blu, shard, a, b);
let a_and_c = cols.a_and_c.populate(blu, shard, a, c);
let b_xor_c = cols.b_xor_c.populate(blu, shard, b, c);
let a_and_b_xor_c = cols.a_and_b_xor_c.populate(blu, shard, a, b_xor_c);
let b_and_c = cols.b_and_c.populate(blu, shard, b, c);
let maj_intermediate = cols.maj_intermediate.populate(blu, shard, a_and_b, a_and_c);
let maj = cols.maj.populate(blu, shard, maj_intermediate, b_and_c);

let temp2 = cols.temp2.populate(blu, shard, s0, maj);
let maj = cols.maj.populate(blu, shard, a_and_b_xor_c, b_and_c);

let d_add_temp1 = cols.d_add_temp1.populate(blu, shard, d, temp1);
let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, temp2);
let temp1_add_temp2 = cols.temp1_add_temp2.populate(blu, shard, temp1, s0, maj);

h_array[7] = g;
h_array[6] = f;
Expand Down
Loading