diff --git a/Cargo.lock b/Cargo.lock index 1296a001b..95d28b7d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1199,6 +1199,7 @@ source = "git+https://github.com/scroll-tech/ceno-patch.git?branch=main#bc61dbd5 name = "ceno_zkvm" version = "0.1.0" dependencies = [ + "arrayref", "base64", "bincode 1.3.3", "ceno-examples", diff --git a/ceno_emul/src/syscalls/sha256.rs b/ceno_emul/src/syscalls/sha256.rs index ba0252ff8..10e13afeb 100644 --- a/ceno_emul/src/syscalls/sha256.rs +++ b/ceno_emul/src/syscalls/sha256.rs @@ -1,8 +1,12 @@ -use crate::{Change, EmuContext, Platform, Tracer, VMState, Word, WriteOp, utils::MemoryView}; +use crate::{ + ByteAddr, Change, EmuContext, Platform, Tracer, VMState, WORD_SIZE, Word, WordAddr, WriteOp, + utils::MemoryView, +}; use super::{SyscallEffects, SyscallSpec, SyscallWitness}; pub const SHA_EXTEND_WORDS: usize = 64; // u64 cells +pub const SHA_EXTEND_ROUND_MEM_OPS: usize = 5; pub struct Sha256ExtendSpec; @@ -10,7 +14,7 @@ impl SyscallSpec for Sha256ExtendSpec { const NAME: &'static str = "SHA256_EXTEND"; const REG_OPS_COUNT: usize = 1; - const MEM_OPS_COUNT: usize = SHA_EXTEND_WORDS; + const MEM_OPS_COUNT: usize = SHA_EXTEND_ROUND_MEM_OPS; const CODE: u32 = ceno_syscall::SHA_EXTEND; } @@ -29,18 +33,6 @@ impl From for [Word; SHA_EXTEND_WORDS] { } } -/// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/core/machine/src/syscall/precompiles/sha256/extend/mod.rs#L22 -pub fn sha_extend(w: &mut [u32]) { - for i in 16..64 { - let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); - let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); - w[i] = w[i - 16] - .wrapping_add(s0) - .wrapping_add(w[i - 7]) - .wrapping_add(s1); - } -} - pub fn extend(vm: &VMState) -> SyscallEffects { let state_ptr = vm.peek_register(Platform::reg_arg0()); @@ -48,18 +40,51 @@ pub fn extend(vm: &VMState) -> SyscallEffects { let reg_ops = 
vec![WriteOp::new_register_op( Platform::reg_arg0(), Change::new(state_ptr, state_ptr), - 0, // Cycle set later in finalize(). + 0, )]; - let mut state_view = MemoryView::<_, SHA_EXTEND_WORDS>::new(vm, state_ptr); - let mut sha_extend_words = ShaExtendWords::from(state_view.words()); - sha_extend(&mut sha_extend_words.0); - let output_words: [Word; SHA_EXTEND_WORDS] = sha_extend_words.into(); + let w_i_minus_2 = MemoryView::<_, 1>::new(vm, state_ptr - 2 * WORD_SIZE as u32).words()[0]; + let w_i_minus_7 = MemoryView::<_, 1>::new(vm, state_ptr - 7 * WORD_SIZE as u32).words()[0]; + let w_i_minus_15 = MemoryView::<_, 1>::new(vm, state_ptr - 15 * WORD_SIZE as u32).words()[0]; + let w_i_minus_16 = MemoryView::<_, 1>::new(vm, state_ptr - 16 * WORD_SIZE as u32).words()[0]; + let old_word = MemoryView::<_, 1>::new(vm, state_ptr).words()[0]; + + let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3); + let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10); + let new_word = w_i_minus_16 + .wrapping_add(s0) + .wrapping_add(w_i_minus_7) + .wrapping_add(s1); - state_view.write(output_words); - let mem_ops = state_view.mem_ops().to_vec(); + let base = ByteAddr::from(state_ptr).waddr(); + let mem_ops = vec![ + WriteOp { + addr: WordAddr(base.0 - 2), + value: Change::new(w_i_minus_2, w_i_minus_2), + previous_cycle: 0, + }, + WriteOp { + addr: WordAddr(base.0 - 7), + value: Change::new(w_i_minus_7, w_i_minus_7), + previous_cycle: 0, + }, + WriteOp { + addr: WordAddr(base.0 - 15), + value: Change::new(w_i_minus_15, w_i_minus_15), + previous_cycle: 0, + }, + WriteOp { + addr: WordAddr(base.0 - 16), + value: Change::new(w_i_minus_16, w_i_minus_16), + previous_cycle: 0, + }, + WriteOp { + addr: base, + value: Change::new(old_word, new_word), + previous_cycle: 0, + }, + ]; - assert_eq!(mem_ops.len(), SHA_EXTEND_WORDS); SyscallEffects { witness: SyscallWitness::new(mem_ops, reg_ops), next_pc: None, diff --git 
a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index 83f56a3e1..ae06f92a4 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,8 +3,8 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, SHA_EXTEND_WORDS, - StepRecord, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; use ceno_host::CenoStdin; @@ -466,19 +466,6 @@ fn test_secp256k1_ecrecover() -> Result<()> { #[test] fn test_sha256_extend() -> Result<()> { let program_elf = ceno_examples::sha_extend_syscall; - let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); - assert_eq!(syscalls.len(), 1); - - let witness = syscalls[0]; - assert_eq!(witness.reg_ops.len(), 1); - assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); - - let state_ptr = witness.reg_ops[0].value.after; - assert_eq!(state_ptr, witness.reg_ops[0].value.before); - let state_ptr: WordAddr = state_ptr.into(); let expected = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, @@ -491,14 +478,31 @@ fn test_sha256_extend() -> Result<()> { 634956631, ]; - assert_eq!(witness.mem_ops.len(), SHA_EXTEND_WORDS); + for round in 0..48 { + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + let steps = run(&mut state)?; + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 1); + let witness = syscalls[0]; - for (i, write_op) 
in witness.mem_ops.iter().enumerate() { - assert_eq!(write_op.addr, state_ptr + i); - assert_eq!(write_op.value.after, expected[i]); - if i < 16 { - // sanity check: first 16 entries remain unchanged - assert_eq!(write_op.value.before, write_op.value.after); + assert_eq!(witness.reg_ops.len(), 1); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + + let state_ptr = witness.reg_ops[0].value.before; + assert_eq!(state_ptr + WORD_SIZE as u32, witness.reg_ops[0].value.after); + let state_ptr: WordAddr = state_ptr.into(); + + assert_eq!(witness.mem_ops.len(), 5); + + let offsets = [-2, -7, -15, -16, 0]; + for (i, write_op) in witness.mem_ops.iter().enumerate() { + let mem_round_id = round + (16 + offsets[i]) as usize; + assert_eq!(write_op.addr, state_ptr + mem_round_id as u32); + if i < 4 { + assert_eq!(write_op.value.before, write_op.value.after); + } else { + assert_eq!(write_op.value.after, expected[mem_round_id]); + } } } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 68c935fdb..6883cdd45 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -1370,12 +1370,14 @@ pub fn evaluate_selector( (expr, sel) } - SelectorType::OrderedSparse32 { + SelectorType::OrderedSparse { + num_vars, indices, expression, } => { - let out_point_slice = out_point.slice(builder, 0, 5); - let in_point_slice = in_point.slice(builder, 0, 5); + let num_vars = *num_vars; + let out_point_slice = out_point.slice(builder, 0, num_vars); + let in_point_slice = in_point.slice(builder, 0, num_vars); let out_subgroup_eq = build_eq_x_r_vec_sequential(builder, &out_point_slice); let in_subgroup_eq = build_eq_x_r_vec_sequential(builder, &in_point_slice); @@ -1386,8 +1388,8 @@ pub fn evaluate_selector( builder.assign(&eval, eval + out_val * in_val); } - let out_point_slice = out_point.slice(builder, 5, out_point.len()); - let in_point_slice = 
in_point.slice(builder, 5, in_point.len()); + let out_point_slice = out_point.slice(builder, num_vars, out_point.len()); + let in_point_slice = in_point.slice(builder, num_vars, in_point.len()); let n_bits = builder.get(&ctx.num_instances_bit_decomps, 0); let sel = diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index ca4955e4b..42d4e869b 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -42,6 +42,7 @@ tracing.workspace = true tracing-forest.workspace = true tracing-subscriber.workspace = true +arrayref = "0.3.9" bincode.workspace = true cfg-if.workspace = true clap.workspace = true diff --git a/ceno_zkvm/src/gadgets/add4.rs b/ceno_zkvm/src/gadgets/add4.rs new file mode 100644 index 000000000..d3fcf16b5 --- /dev/null +++ b/ceno_zkvm/src/gadgets/add4.rs @@ -0,0 +1,198 @@ +// The add4 gadget is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +use ceno_emul::WORD_SIZE; +use derive::AlignedBorrow; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::error::CircuitBuilderError; +use multilinear_extensions::{Expression, ToExpr, WitIn}; + +use crate::{circuit_builder::CircuitBuilder, gadgets::word::Word, witness::LkMultiplicity}; + +/// A set of columns needed to compute the add of four words. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct Add4Operation { + /// The result of `a + b + c + d`. + pub value: Word, + + /// Indicates if the carry for the `i`th digit is 0. + pub is_carry_0: Word, + + /// Indicates if the carry for the `i`th digit is 1. + pub is_carry_1: Word, + + /// Indicates if the carry for the `i`th digit is 2. + pub is_carry_2: Word, + + /// Indicates if the carry for the `i`th digit is 3. The carry when adding 4 words is at most + /// 3. + pub is_carry_3: Word, + + /// The carry for the `i`th digit. 
+ pub carry: Word, +} + +impl Add4Operation { + pub fn create(cb: &mut CircuitBuilder, name_fn: N) -> Self + where + NR: Into, + N: FnOnce() -> NR, + { + let name: String = name_fn().into(); + Self { + value: Word::create(cb, || format!("{}_value", name)), + is_carry_0: Word::create(cb, || format!("{}_is_carry_0", name)), + is_carry_1: Word::create(cb, || format!("{}_is_carry_1", name)), + is_carry_2: Word::create(cb, || format!("{}_is_carry_2", name)), + is_carry_3: Word::create(cb, || format!("{}_is_carry_3", name)), + carry: Word::create(cb, || format!("{}_carry", name)), + } + } +} + +impl Add4Operation { + #[allow(clippy::too_many_arguments)] + pub fn populate( + &mut self, + record: &mut LkMultiplicity, + a_u32: u32, + b_u32: u32, + c_u32: u32, + d_u32: u32, + ) -> u32 { + let expected = a_u32 + .wrapping_add(b_u32) + .wrapping_add(c_u32) + .wrapping_add(d_u32); + self.value = Word::from(expected); + let a = a_u32.to_le_bytes(); + let b = b_u32.to_le_bytes(); + let c = c_u32.to_le_bytes(); + let d = d_u32.to_le_bytes(); + + let base = 256; + let mut carry = [0u8, 0u8, 0u8, 0u8]; + for i in 0..WORD_SIZE { + let mut res = (a[i] as u32) + (b[i] as u32) + (c[i] as u32) + (d[i] as u32); + if i > 0 { + res += carry[i - 1] as u32; + } + carry[i] = (res / base) as u8; + self.is_carry_0[i] = F::from_bool(carry[i] == 0); + self.is_carry_1[i] = F::from_bool(carry[i] == 1); + self.is_carry_2[i] = F::from_bool(carry[i] == 2); + self.is_carry_3[i] = F::from_bool(carry[i] == 3); + self.carry[i] = F::from_canonical_u8(carry[i]); + debug_assert!(carry[i] <= 3); + debug_assert_eq!(self.value[i], F::from_canonical_u32(res % base)); + } + + // Range check. 
+ { + record.assert_bytes(&a); + record.assert_bytes(&b); + record.assert_bytes(&c); + record.assert_bytes(&d); + record.assert_bytes(&expected.to_le_bytes()); + } + expected + } +} + +impl Add4Operation { + #[allow(clippy::too_many_arguments)] + pub fn eval( + &self, + builder: &mut CircuitBuilder, + a: Word> + Clone>, + b: Word> + Clone>, + c: Word> + Clone>, + d: Word> + Clone>, + ) -> Result<(), CircuitBuilderError> + where + E: ExtensionField, + Expr: ToExpr>, + Expression: From, + { + // Range check each byte. + { + builder.assert_bytes(|| "add4 operation a.0", &a.0)?; + builder.assert_bytes(|| "add4 operation b.0", &b.0)?; + builder.assert_bytes(|| "add4 operation c.0", &c.0)?; + builder.assert_bytes(|| "add4 operation d.0", &d.0)?; + builder.assert_bytes(|| "add4 operation self.value.0", &self.value.0)?; + } + + // Each value in is_carry_{0,1,2,3} is 0 or 1, and exactly one of them is 1 per digit. + { + for i in 0..WORD_SIZE { + builder.assert_bit(|| "add4 is_carry_0", self.is_carry_0[i].expr())?; + builder.assert_bit(|| "add4 is_carry_1", self.is_carry_1[i].expr())?; + builder.assert_bit(|| "add4 is_carry_2", self.is_carry_2[i].expr())?; + builder.assert_bit(|| "add4 is_carry_3", self.is_carry_3[i].expr())?; + builder.require_equal( + || "add4 is_carry sum to 1", + self.is_carry_0[i].expr() + + self.is_carry_1[i].expr() + + self.is_carry_2[i].expr() + + self.is_carry_3[i].expr(), + 1.into(), + )?; + } + } + + // Calculates carry from is_carry_{0,1,2,3}. + { + for i in 0..WORD_SIZE { + builder.require_equal( + || "add4 carry from is_carry", + self.carry[i].expr(), + self.is_carry_1[i].expr() * 1 + + self.is_carry_2[i].expr() * 2 + + self.is_carry_3[i].expr() * 3, + )?; + } + } + + // Compare the sum and summands by looking at carry. + { + // For each limb, assert that difference between the carried result and the non-carried + // result is the product of carry and base. 
+ for i in 0..WORD_SIZE { + let mut overflow = + a[i].expr() + b[i].expr() + c[i].expr() + d[i].expr() - self.value[i].expr(); + if i > 0 { + overflow = overflow.expr() + self.carry[i - 1].expr(); + } + builder.require_equal( + || "add4 carry overflow", + self.carry[i].expr() * 256, + overflow.clone(), + )?; + } + } + Ok(()) + } +} diff --git a/ceno_zkvm/src/gadgets/fixed_rotate_right.rs b/ceno_zkvm/src/gadgets/fixed_rotate_right.rs new file mode 100644 index 000000000..a1fa8499f --- /dev/null +++ b/ceno_zkvm/src/gadgets/fixed_rotate_right.rs @@ -0,0 +1,188 @@ +// The fixed_rotate_right gadget is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +use ceno_emul::WORD_SIZE; +use derive::AlignedBorrow; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::error::CircuitBuilderError; +use multilinear_extensions::{Expression, ToExpr, WitIn}; + +use crate::{ + circuit_builder::CircuitBuilder, + gadgets::{util::shr_carry, word::Word}, + witness::LkMultiplicity, +}; +/// A set of columns needed to compute `rotateright` of a word with a fixed offset R. +/// +/// Note that we decompose shifts into a byte shift and a bit shift. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct FixedRotateRightOperation { + /// The output value. + pub value: Word, + + /// The shift output of `shrcarry` on each byte of a word. + pub shift: Word, + + /// The carry output of `shrcarry` on each byte of a word. + pub carry: Word, +} + +impl FixedRotateRightOperation { + pub fn create(cb: &mut CircuitBuilder, name_fn: N) -> Self + where + NR: Into, + N: FnOnce() -> NR, + { + let name: String = name_fn().into(); + Self { + value: Word::create(cb, || format!("{}_value", name)), + shift: Word::create(cb, || format!("{}_shift", name)), + carry: Word::create(cb, || format!("{}_carry", name)), + } + } +} + +impl FixedRotateRightOperation { + pub fn populate(&mut self, record: &mut LkMultiplicity, input: u32, rotation: usize) -> u32 { + let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); + let expected = input.rotate_right(rotation as u32); + + // Compute some constants with respect to the rotation needed for the rotation. + let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + let carry_multiplier = F::from_canonical_u32(Self::carry_multiplier(rotation)); + + // Perform the byte shift. 
+ let input_bytes_rotated = Word([ + input_bytes[nb_bytes_to_shift % WORD_SIZE], + input_bytes[(1 + nb_bytes_to_shift) % WORD_SIZE], + input_bytes[(2 + nb_bytes_to_shift) % WORD_SIZE], + input_bytes[(3 + nb_bytes_to_shift) % WORD_SIZE], + ]); + + // For each byte, calculate the shift and carry. If it's not the first byte, calculate the + // new byte value using the current shifted byte and the last carry. + let mut first_shift = F::ZERO; + let mut last_carry = F::ZERO; + for i in (0..WORD_SIZE).rev() { + let b = input_bytes_rotated[i].to_string().parse::().unwrap(); + let c = nb_bits_to_shift as u8; + + let (shift, carry) = shr_carry(b, c); + record.lookup_shr_byte(shift as u64, carry as u64, nb_bits_to_shift as u64); + + self.shift[i] = F::from_canonical_u8(shift); + self.carry[i] = F::from_canonical_u8(carry); + + if i == WORD_SIZE - 1 { + first_shift = self.shift[i]; + } else { + self.value[i] = self.shift[i] + last_carry * carry_multiplier; + } + + last_carry = self.carry[i]; + } + + // For the first byte, we didn't know the last carry so compute the rotated byte here. + self.value[WORD_SIZE - 1] = first_shift + last_carry * carry_multiplier; + + // Check that the value is correct. + assert_eq!(self.value.to_u32(), expected); + + expected + } +} + +impl FixedRotateRightOperation { + pub fn eval( + &self, + builder: &mut CircuitBuilder, + input: Word>>, + rotation: usize, + ) -> Result<(), CircuitBuilderError> + where + E: ExtensionField, + Expr: ToExpr>, + Expression: From, + { + // Compute some constants with respect to the rotation needed for the rotation. + let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + let carry_multiplier = Self::carry_multiplier(rotation); + + // Perform the byte shift. 
+ let input_bytes_rotated = Word([ + input[nb_bytes_to_shift % WORD_SIZE].expr(), + input[(1 + nb_bytes_to_shift) % WORD_SIZE].expr(), + input[(2 + nb_bytes_to_shift) % WORD_SIZE].expr(), + input[(3 + nb_bytes_to_shift) % WORD_SIZE].expr(), + ]); + + // For each byte, calculate the shift and carry. If it's not the first byte, calculate the + // new byte value using the current shifted byte and the last carry. + let mut first_shift = Expression::::ZERO; + let mut last_carry = Expression::::ZERO; + for i in (0..WORD_SIZE).rev() { + builder.lookup_shr_byte( + input_bytes_rotated[i].expr(), + nb_bits_to_shift, + self.shift[i].expr(), + self.carry[i].expr(), + )?; + + if i == WORD_SIZE - 1 { + first_shift = self.shift[i].expr(); + } else { + builder.require_equal( + || "fixed rotate right value calculation", + self.value[i].expr(), + self.shift[i].expr() + last_carry * carry_multiplier, + )?; + } + + last_carry = self.carry[i].expr(); + } + + // For the first byte, we didn't know the last carry so compute the rotated byte here. 
+ builder.require_equal( + || "fixed rotate right first value calculation", + self.value[WORD_SIZE - 1].expr(), + first_shift + last_carry * carry_multiplier, + ) + } + + pub const fn nb_bytes_to_shift(rotation: usize) -> usize { + rotation / 8 + } + + pub const fn nb_bits_to_shift(rotation: usize) -> usize { + rotation % 8 + } + + pub const fn carry_multiplier(rotation: usize) -> u32 { + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + 1 << (8 - nb_bits_to_shift) + } +} diff --git a/ceno_zkvm/src/gadgets/fixed_shift_right.rs b/ceno_zkvm/src/gadgets/fixed_shift_right.rs new file mode 100644 index 000000000..d7827e2dc --- /dev/null +++ b/ceno_zkvm/src/gadgets/fixed_shift_right.rs @@ -0,0 +1,191 @@ +// The fixed_shift_right gadget is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +use ceno_emul::WORD_SIZE; +use derive::AlignedBorrow; +use ff_ext::SmallField; +use gkr_iop::error::CircuitBuilderError; +use multilinear_extensions::{Expression, ToExpr, WitIn}; + +use crate::{ + circuit_builder::CircuitBuilder, + gadgets::{util::shr_carry, word::Word}, + witness::LkMultiplicity, +}; + +/// A set of columns needed to compute `>>` of a word with a fixed offset R. +/// +/// Note that we decompose shifts into a byte shift and a bit shift. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct FixedShiftRightOperation { + /// The output value. + pub value: Word, + + /// The shift output of `shrcarry` on each byte of a word. + pub shift: Word, + + /// The carry output of `shrcarry` on each byte of a word. + pub carry: Word, +} + +impl FixedShiftRightOperation { + pub fn create(cb: &mut CircuitBuilder, name_fn: N) -> Self + where + NR: Into, + N: FnOnce() -> NR, + { + let name: String = name_fn().into(); + Self { + value: Word::create(cb, || format!("{}_value", name)), + shift: Word::create(cb, || format!("{}_shift", name)), + carry: Word::create(cb, || format!("{}_carry", name)), + } + } +} + +impl FixedShiftRightOperation { + pub fn populate(&mut self, record: &mut LkMultiplicity, input: u32, rotation: usize) -> u32 { + let input_bytes = input.to_le_bytes().map(F::from_canonical_u8); + let expected = input >> rotation; + + // Compute some constants with respect to the rotation needed for the rotation. + let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + let carry_multiplier = F::from_canonical_u32(Self::carry_multiplier(rotation)); + + // Perform the byte shift. + let mut word = [F::ZERO; WORD_SIZE]; + for i in 0..WORD_SIZE { + if i + nb_bytes_to_shift < WORD_SIZE { + word[i] = input_bytes[(i + nb_bytes_to_shift) % WORD_SIZE]; + } + } + let input_bytes_rotated = Word(word); + + // For each byte, calculate the shift and carry. 
If it's not the first byte, calculate the + // new byte value using the current shifted byte and the last carry. + let mut first_shift = F::ZERO; + let mut last_carry = F::ZERO; + for i in (0..WORD_SIZE).rev() { + let b = input_bytes_rotated[i].to_string().parse::().unwrap(); + let c = nb_bits_to_shift as u8; + let (shift, carry) = shr_carry(b, c); + + record.lookup_shr_byte(shift as u64, carry as u64, nb_bits_to_shift as u64); + + self.shift[i] = F::from_canonical_u8(shift); + self.carry[i] = F::from_canonical_u8(carry); + + if i == WORD_SIZE - 1 { + first_shift = self.shift[i]; + } else { + self.value[i] = self.shift[i] + last_carry * carry_multiplier; + } + + last_carry = self.carry[i]; + } + + // For the first byte, we don't move over the carry as this is a shift, not a rotate. + self.value[WORD_SIZE - 1] = first_shift; + + // Assert the answer is correct. + assert_eq!(self.value.to_u32(), expected); + + expected + } +} + +impl FixedShiftRightOperation { + pub fn eval( + &self, + builder: &mut CircuitBuilder, + input: Word>>, + rotation: usize, + ) -> Result<(), CircuitBuilderError> + where + E: ff_ext::ExtensionField, + Expr: multilinear_extensions::ToExpr>, + Expression: From, + { + // Compute some constants with respect to the rotation needed for the rotation. + let nb_bytes_to_shift = Self::nb_bytes_to_shift(rotation); + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + let carry_multiplier = Self::carry_multiplier(rotation); + + // Perform the byte shift. + let input_bytes_rotated = Word(std::array::from_fn(|i| { + if i + nb_bytes_to_shift < WORD_SIZE { + input[(i + nb_bytes_to_shift) % WORD_SIZE].expr() + } else { + Expression::::ZERO + } + })); + + // For each byte, calculate the shift and carry. If it's not the first byte, calculate the + // new byte value using the current shifted byte and the last carry. 
+ let mut first_shift = Expression::::ZERO; + let mut last_carry = Expression::::ZERO; + for i in (0..WORD_SIZE).rev() { + builder.lookup_shr_byte( + input_bytes_rotated[i].expr(), + nb_bits_to_shift, + self.shift[i].expr(), + self.carry[i].expr(), + )?; + + if i == WORD_SIZE - 1 { + first_shift = self.shift[i].expr(); + } else { + builder.require_equal( + || "fixed shift right value calculation", + self.value[i].expr(), + self.shift[i].expr() + last_carry * carry_multiplier, + )?; + } + + last_carry = self.carry[i].expr(); + } + + // For the first byte, we don't move over the carry as this is a shift, not a rotate. + builder.require_equal( + || "fixed shift right first value calculation", + self.value[WORD_SIZE - 1].expr(), + first_shift, + ) + } + + pub const fn nb_bytes_to_shift(rotation: usize) -> usize { + rotation / 8 + } + + pub const fn nb_bits_to_shift(rotation: usize) -> usize { + rotation % 8 + } + + pub const fn carry_multiplier(rotation: usize) -> u32 { + let nb_bits_to_shift = Self::nb_bits_to_shift(rotation); + 1 << (8 - nb_bits_to_shift) + } +} diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 89b3210a1..349b7cf64 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,5 +1,8 @@ +mod add4; mod div; mod field; +mod fixed_rotate_right; +mod fixed_shift_right; mod is_lt; mod is_zero; mod poseidon2; @@ -8,9 +11,14 @@ mod signed_ext; mod signed_limbs; mod util; mod util_expr; +mod word; +mod xor; +pub use add4::*; pub use div::DivConfig; pub use field::*; +pub use fixed_rotate_right::*; +pub use fixed_shift_right::FixedShiftRightOperation; pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; @@ -20,3 +28,5 @@ pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; +pub use word::*; +pub use xor::*; diff --git 
a/ceno_zkvm/src/gadgets/util.rs b/ceno_zkvm/src/gadgets/util.rs index 0ea4245ed..121ab2937 100644 --- a/ceno_zkvm/src/gadgets/util.rs +++ b/ceno_zkvm/src/gadgets/util.rs @@ -97,3 +97,15 @@ pub fn split_u16_limbs_to_u8_limbs(slice: &[F]) -> (Vec, Vec (u8, u8) { + let c_mod = rotation & 0x7; + if c_mod != 0 { + let res = input >> c_mod; + let carry = (input << (8 - c_mod)) >> (8 - c_mod); + (res, carry) + } else { + (input, 0u8) + } +} diff --git a/ceno_zkvm/src/gadgets/word.rs b/ceno_zkvm/src/gadgets/word.rs new file mode 100644 index 000000000..a11a80587 --- /dev/null +++ b/ceno_zkvm/src/gadgets/word.rs @@ -0,0 +1,152 @@ +// The word struct is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +use std::ops::{Index, IndexMut}; + +use std::array::IntoIter; + +use arrayref::array_ref; +use ceno_emul::WORD_SIZE; +use derive::AlignedBorrow; +use ff_ext::{ExtensionField, SmallField}; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use serde::{Deserialize, Serialize}; + +use crate::circuit_builder::CircuitBuilder; + +/// An array of four bytes to represent a 32-bit value. +/// +/// We use the generic type `T` to represent the different representations of a byte, ranging from +/// a `u8` to an `Expression` or `AB::Expr`. +#[derive( + AlignedBorrow, Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize, +)] +#[repr(C)] +pub struct Word(pub [T; WORD_SIZE]); + +impl Word { + /// Creates a word in the circuit. + pub fn create(cb: &mut CircuitBuilder, name_fn: N) -> Self + where + NR: Into, + N: FnOnce() -> NR, + { + let name: String = name_fn().into(); + Word(std::array::from_fn(|i| { + cb.create_witin(|| format!("{}[{}]", name, i)) + })) + } +} + +impl Word { + /// Converts a word to a u32. + pub fn to_u32(&self) -> u32 { + u32::from_le_bytes(self.0.map(|x| x.to_string().parse::().unwrap())) + } +} + +impl Word { + /// Applies `f` to each element of the word. + pub fn map(self, f: F) -> Word + where + F: FnMut(Expr) -> S, + { + Word(self.0.map(f)) + } + + /// Extends a variable to a word. + pub fn extend(var: Expr) -> Word> + where + Expr: ToExpr>, + { + Word([ + Expression::::ZERO + var.expr(), + Expression::::ZERO, + Expression::::ZERO, + Expression::::ZERO, + ]) + } + + pub fn zero() -> Word> { + Word([ + Expression::::ZERO, + Expression::::ZERO, + Expression::::ZERO, + Expression::::ZERO, + ]) + } + + /// Reduces a word to a single variable. 
+ pub fn reduce(&self) -> Expression + where + E: ExtensionField, + Expr: ToExpr>, + Expression: From, + { + let base = [1, 1 << 8, 1 << 16, 1 << 24]; + self.0 + .iter() + .enumerate() + .map(|(i, x)| x.expr() * base[i]) + .sum() + } +} + +impl Index for Word { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl IndexMut for Word { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl From for Word { + fn from(value: u32) -> Self { + Word(value.to_le_bytes().map(F::from_canonical_u8)) + } +} + +impl IntoIterator for Word { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl FromIterator for Word { + fn from_iter>(iter: I) -> Self { + let elements = iter.into_iter().take(WORD_SIZE).collect_vec(); + + Word(array_ref![elements, 0, WORD_SIZE].clone()) + } +} diff --git a/ceno_zkvm/src/gadgets/xor.rs b/ceno_zkvm/src/gadgets/xor.rs new file mode 100644 index 000000000..d2157aa10 --- /dev/null +++ b/ceno_zkvm/src/gadgets/xor.rs @@ -0,0 +1,92 @@ +// The xor gadget is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+ +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +use ceno_emul::WORD_SIZE; +use derive::AlignedBorrow; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::error::CircuitBuilderError; +use itertools::izip; +use multilinear_extensions::{Expression, ToExpr, WitIn}; + +use crate::{circuit_builder::CircuitBuilder, gadgets::word::Word, witness::LkMultiplicity}; + +/// A set of columns needed to compute the xor of two words. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct XorOperation { + /// The result of `x ^ y`. + pub value: Word, +} + +impl XorOperation { + /// Creates an xor operation in the circuit. 
+ pub fn create(cb: &mut CircuitBuilder, name_fn: N) -> Self + where + NR: Into, + N: FnOnce() -> NR, + { + let name: String = name_fn().into(); + Self { + value: Word::create(cb, || format!("{}_value", name)), + } + } +} + +impl XorOperation { + pub fn populate(&mut self, record: &mut LkMultiplicity, x: u32, y: u32) -> u32 { + let expected = x ^ y; + let x_bytes = x.to_le_bytes(); + let y_bytes = y.to_le_bytes(); + for i in 0..WORD_SIZE { + let xor = x_bytes[i] ^ y_bytes[i]; + self.value[i] = F::from_canonical_u8(xor); + + record.lookup_xor_byte(x_bytes[i] as u64, y_bytes[i] as u64); + } + expected + } +} + +impl XorOperation { + #[allow(unused_variables)] + pub fn eval( + &self, + builder: &mut CircuitBuilder, + a: Word>>, + b: Word>>, + ) -> Result<(), CircuitBuilderError> + where + E: ExtensionField, + Expr: ToExpr>, + Expression: From, + { + izip!(a.0, b.0) + .enumerate() + .map(|(i, (a_byte, b_byte))| { + builder.lookup_xor_byte(a_byte.expr(), b_byte.expr(), self.value[i].expr()) + }) + .collect::, _>>()?; + Ok(()) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 8f1300bad..84a836bf3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -3,6 +3,7 @@ mod fptower_fp2_add; mod fptower_fp2_mul; mod halt; mod keccak; +mod sha_extend; mod uint256; mod weierstrass_add; mod weierstrass_decompress; @@ -12,6 +13,7 @@ pub use fptower_fp::{FpAddInstruction, FpMulInstruction}; pub use fptower_fp2_add::Fp2AddInstruction; pub use fptower_fp2_mul::Fp2MulInstruction; pub use keccak::KeccakInstruction; +pub use sha_extend::ShaExtendInstruction; pub use uint256::{Secp256k1InvInstruction, Secp256r1InvInstruction, Uint256MulInstruction}; pub use weierstrass_add::WeierstrassAddAssignInstruction; pub use weierstrass_decompress::WeierstrassDecompressInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs 
b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs new file mode 100644 index 000000000..7f8513dc1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -0,0 +1,312 @@ +use std::{array, marker::PhantomData}; + +use ceno_emul::{Change, InsnKind, Platform, SHA_EXTEND, StepRecord, WORD_SIZE, WriteOp}; +use ff_ext::{ExtensionField, FieldInto}; +use gkr_iop::{ + ProtocolBuilder, ProtocolWitnessGenerator, + gkr::{GKRCircuit, layer::Layer}, + utils::lk_multiplicity::Multiplicity, +}; +use itertools::{Itertools, izip}; +use multilinear_extensions::{ToExpr, WitIn, util::max_usable_threads}; +use p3::matrix::Matrix; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; + +use crate::{ + chip_handler::general::InstFetch, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + constants::{LIMB_BITS, LIMB_MASK, MEM_BITS, UINT_LIMBS, UInt}, + ecall_base::OpFixedRS, + insn_base::{MemAddr, StateInOut, WriteMEM}, + }, + }, + precompiles::{ShaExtendInstance, ShaExtendLayout, ShaExtendTrace, ShaExtendWitInstance}, + structs::ProgramParams, + tables::{InsnRecord, RMMCollections}, + witness::LkMultiplicity, +}; + +#[derive(Debug)] +pub struct EcallShaExtendConfig { + pub layout: ShaExtendLayout, + vm_state: StateInOut, + ecall_id: OpFixedRS, + state_ptr: (OpFixedRS, MemAddr), + old_value: [WitIn; UINT_LIMBS], + mem_rw: Vec, +} + +/// ShaExtendInstruction can handle any instruction and produce its side-effects. 
+pub struct ShaExtendInstruction(PhantomData); + +impl Instruction for ShaExtendInstruction { + type InstructionConfig = EcallShaExtendConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } + + fn name() -> String { + "Ecall_ShaExtend".to_string() + } + + fn construct_circuit( + _circuit_builder: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result { + unimplemented!() + } + + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + _param: &ProgramParams, + ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { + let vm_state = StateInOut::construct_circuit(cb, false)?; + + let ecall_id = OpFixedRS::<_, { Platform::reg_ecall() }, false>::construct_circuit( + cb, + UInt::from_const_unchecked(vec![ + SHA_EXTEND & LIMB_MASK, + (SHA_EXTEND >> LIMB_BITS) & LIMB_MASK, + ]) + .register_expr(), + vm_state.ts, + )?; + + let state_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; + let state_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + cb, + state_ptr_value.uint_unaligned().register_expr(), + vm_state.ts, + )?; + + // fetch + cb.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + InsnKind::ECALL.into(), + None, + 0.into(), + 0.into(), + 0.into(), + #[cfg(feature = "u16limb_circuit")] + 0.into(), + ))?; + + let mut layout = + as gkr_iop::ProtocolBuilder>::build_layer_logic(cb, ())?; + + let old_value = + array::from_fn(|i| cb.create_witin(|| format!("sha256 extend old_mem_value_{}", i))); + let offset = [-2, -7, -15, -16]; + let mut mem_rw = izip!(offset, &layout.input32_exprs) + .map(|(offset, val_before)| { + WriteMEM::construct_circuit( + cb, + state_ptr.prev_value.as_ref().unwrap().value() + offset * WORD_SIZE as i32, + val_before.clone(), + val_before.clone(), + vm_state.ts, + ) + }) + .collect::, ZKVMError>>()?; + + mem_rw.push(WriteMEM::construct_circuit( + cb, + state_ptr.prev_value.as_ref().unwrap().value(), + [old_value[0].expr(), old_value[1].expr()], 
+ layout.output32_expr.clone(), + vm_state.ts, + )?); + + let (out_evals, mut chip) = layout.finalize(cb); + + let layer = Layer::from_circuit_builder(cb, Self::name(), layout.n_challenges, out_evals); + chip.add_layer(layer); + + let circuit = chip.gkr_circuit(); + + Ok(( + EcallShaExtendConfig { + layout, + vm_state, + ecall_id, + state_ptr: (state_ptr, state_ptr_value), + old_value, + mem_rw, + }, + circuit, + )) + } + + fn generate_fixed_traces( + config: &Self::InstructionConfig, + num_fixed: usize, + ) -> Option> { + let fixed = config.layout.fixed_witness_group(); + assert_eq!(fixed.width(), num_fixed); + Some(fixed) + } + + fn assign_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _instance: &mut [::BaseField], + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + unreachable!("we override logic in assign_instances") + } + + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let mut lk_multiplicity = LkMultiplicity::default(); + let num_structural_witin = config.layout.n_structural_witin.max(num_structural_witin); + if steps.is_empty() { + return Ok(( + [ + RowMajorMatrix::new(0, num_witin, InstancePaddingStrategy::Default), + RowMajorMatrix::new(0, num_structural_witin, InstancePaddingStrategy::Default), + ], + lk_multiplicity.into_finalize_result(), + )); + } + + let num_instances = steps.len(); + let nthreads = max_usable_threads(); + let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + + let mut raw_witin = RowMajorMatrix::::new( + num_instances, + num_witin, + InstancePaddingStrategy::Default, + ); + let mut raw_structural_witin = RowMajorMatrix::::new( + num_instances, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + let raw_witin_iter = 
raw_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(steps.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + + instances + .chunks_mut(num_witin) + .zip_eq(steps) + .map(|(instance, step)| { + let ops = step.syscall().expect("syscall step"); + + // vm_state + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; + + config.ecall_id.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + step.cycle(), + &WriteOp::new_register_op( + Platform::reg_ecall(), + Change::new(SHA_EXTEND, SHA_EXTEND), + step.rs1().unwrap().previous_cycle, + ), + )?; + + // assign state_ptr + config.state_ptr.1.assign_instance( + instance, + &mut lk_multiplicity, + ops.reg_ops[0].value.before, + )?; + config.state_ptr.0.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + step.cycle(), + &ops.reg_ops[0], + )?; + + let write_op = ops.mem_ops.last().expect("sha_extend write op"); + set_val!( + instance, + config.old_value[0], + (write_op.value.before & LIMB_MASK) as u64 + ); + set_val!( + instance, + config.old_value[1], + (write_op.value.before >> LIMB_BITS) as u64 + ); + + // assign mem_rw + for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { + writer.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + step.cycle(), + op, + )?; + } + // fetch + lk_multiplicity.fetch(step.pc().before.0); + Ok(()) + }) + .collect::>() + }) + .collect::>()?; + + let instances = steps + .iter() + .map(|step| -> ShaExtendInstance { + let ops = step.syscall().expect("syscall step"); + let w_i_minus_2 = ops.mem_ops[0].value.before; + let w_i_minus_7 = ops.mem_ops[1].value.before; + let w_i_minus_15 = ops.mem_ops[2].value.before; + let w_i_minus_16 = ops.mem_ops[3].value.before; + ShaExtendInstance { + witin: ShaExtendWitInstance { + w_i_minus_2, + w_i_minus_7, + 
w_i_minus_15, + w_i_minus_16, + }, + } + }) + .collect_vec(); + + config.layout.phase1_witness_group( + ShaExtendTrace { instances }, + [&mut raw_witin, &mut raw_structural_witin], + &mut lk_multiplicity, + ); + + raw_witin.padding_by_strategy(); + raw_structural_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structural_witin], + lk_multiplicity.into_finalize_result(), + )) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 60c27b984..ea21c4dcb 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -22,7 +22,7 @@ use crate::{ ecall::{ Fp2AddInstruction, Fp2MulInstruction, FpAddInstruction, FpMulInstruction, KeccakInstruction, Secp256k1InvInstruction, Secp256r1InvInstruction, - Uint256MulInstruction, WeierstrassAddAssignInstruction, + ShaExtendInstruction, Uint256MulInstruction, WeierstrassAddAssignInstruction, WeierstrassDecompressInstruction, WeierstrassDoubleAssignInstruction, }, logic::{AndInstruction, OrInstruction, XorInstruction}, @@ -135,6 +135,7 @@ pub struct Rv32imConfig { // Ecall Opcodes pub halt_config: as Instruction>::InstructionConfig, pub keccak_config: as Instruction>::InstructionConfig, + pub sha_extend_config: as Instruction>::InstructionConfig, pub bn254_add_config: > as Instruction>::InstructionConfig, pub bn254_double_config: @@ -380,6 +381,7 @@ impl Rv32imConfig { .is_none() ); let bn254_add_config = register_ecall_circuit!(WeierstrassAddAssignInstruction>, ecall_cells_map); + let sha_extend_config = register_ecall_circuit!(ShaExtendInstruction, ecall_cells_map); let bn254_double_config = register_ecall_circuit!(WeierstrassDoubleAssignInstruction>, ecall_cells_map); let bn254_fp_add_config = register_ecall_circuit!(FpAddInstruction, ecall_cells_map); @@ -467,6 +469,7 @@ impl Rv32imConfig { // ecall opcodes halt_config, keccak_config, + sha_extend_config, bn254_add_config, bn254_double_config, bn254_fp_add_config, @@ 
-560,6 +563,7 @@ impl Rv32imConfig { // system fixed.register_opcode_circuit::>(cs, &self.halt_config); fixed.register_opcode_circuit::>(cs, &self.keccak_config); + fixed.register_opcode_circuit::>(cs, &self.sha_extend_config); fixed.register_opcode_circuit::>>( cs, &self.bn254_add_config, @@ -669,6 +673,7 @@ impl Rv32imConfig { Secp256r1ScalarInvertSpec::CODE ); log_ecall!("uint256_mul_records", Uint256MulSpec::CODE); + log_ecall!("sha_extend_records", Sha256ExtendSpec::CODE); macro_rules! assign_opcode { ($instruction:ty, $config:ident) => {{ @@ -824,6 +829,11 @@ impl Rv32imConfig { uint256_mul_config, Uint256MulSpec::CODE ); + assign_ecall!( + ShaExtendInstruction, + sha_extend_config, + Sha256ExtendSpec::CODE + ); Ok(()) } @@ -968,20 +978,15 @@ impl InstructionDispatchCtx { } /// Fake version of what is missing in Rv32imConfig, for some tests. pub struct DummyExtraConfig { - sha256_extend_config: - as Instruction>::InstructionConfig, phantom_log_pc_cycle: as Instruction>::InstructionConfig, } impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let sha256_extend_config = - cs.register_opcode_circuit::>(); let phantom_log_pc_cycle = cs.register_opcode_circuit::>(); Self { - sha256_extend_config, phantom_log_pc_cycle, } } @@ -991,10 +996,6 @@ impl DummyExtraConfig { cs: &ZKVMConstraintSystem, fixed: &mut ZKVMFixedTraces, ) { - fixed.register_opcode_circuit::>( - cs, - &self.sha256_extend_config, - ); fixed.register_opcode_circuit::>( cs, &self.phantom_log_pc_cycle, @@ -1008,19 +1009,9 @@ impl DummyExtraConfig { instrunction_dispatch_ctx: &InstructionDispatchCtx, witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { - let sha256_extend_records = instrunction_dispatch_ctx - .records_for_ecall_code(Sha256ExtendSpec::CODE) - .unwrap_or(&[]); let phantom_log_pc_cycle_records = instrunction_dispatch_ctx .records_for_ecall_code(LogPcCycleSpec::CODE) .unwrap_or(&[]); - - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - 
&self.sha256_extend_config, - sha256_extend_records, - )?; witness.assign_opcode_circuit::>( cs, shard_ctx, @@ -1106,9 +1097,12 @@ impl Rv32imConfig { .ecall_cells_map .get(&Uint256MulInstruction::::name()) .expect("unable to find name"), + Sha256ExtendSpec::CODE => *self + .ecall_cells_map + .get(&ShaExtendInstruction::::name()) + .expect("unable to find name"), // phantom LogPcCycleSpec::CODE => 0, - ceno_emul::SHA_EXTEND => 0, _ => panic!("unknown ecall code {code:#x}"), } } diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index e1b259501..78c8bfe95 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -368,7 +368,8 @@ fn output32_layer( let mut keccak_output32_iter = out_evals.iter().map(|x| EvalExpression::Single(*x)); // process keccak output - let sel_type = SelectorType::OrderedSparse32 { + let sel_type = SelectorType::OrderedSparse { + num_vars: 5, indices: vec![CYCLIC_POW2_5[ROUNDS - 1] as usize], expression: layer.sel.expr(), }; @@ -623,7 +624,8 @@ fn keccak_first_layer( // process keccak output let mut out_eval_iter = input32_out_evals.iter().map(|x| EvalExpression::Single(*x)); - let sel_type = SelectorType::OrderedSparse32 { + let sel_type = SelectorType::OrderedSparse { + num_vars: 5, indices: vec![CYCLIC_POW2_5[0] as usize], expression: layer.sel_keccak_out.expr(), }; diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index 08427120a..ba7c04308 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -126,10 +126,9 @@ impl FpOpLayout { let eq = cb.create_placeholder_structural_witin(|| "fp_op_structural_witin"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + 
sel_all: sel.clone(), }; let input32_exprs: [GenericArray,

::WordsFieldElement>; 2] = @@ -233,10 +232,10 @@ impl ProtocolBuilder for FpOpLayout { self.n_structural_witin = cb.cs.num_structural_witin as usize; self.n_challenges = 0; - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index b865e1567..76d32b31a 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -135,10 +135,9 @@ impl Fp2AddSubAssignLayout { let eq = cb.create_placeholder_structural_witin(|| "fp2_addsub_structural_witin"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; let input32_exprs: [GenericArray,

::WordsCurvePoint>; 2] = @@ -257,10 +256,10 @@ impl ProtocolBuilder for Fp2AddSubAssignLayo self.n_structural_witin = cb.cs.num_structural_witin as usize; self.n_challenges = 0; - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index e9c0d42f9..c9160e6d9 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -139,10 +139,9 @@ impl Fp2MulAssignLayout { let eq = cb.create_placeholder_structural_witin(|| "fp2_mul_structural_witin"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; let input32_exprs: [GenericArray,

::WordsCurvePoint>; 2] = @@ -286,10 +285,10 @@ impl ProtocolBuilder for Fp2MulAssignLayout< self.n_structural_witin = cb.cs.num_structural_witin as usize; self.n_challenges = 0; - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 39a32d841..52391267a 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -3,9 +3,7 @@ use ff_ext::ExtensionField; use gkr_iop::{ OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - circuit_builder::{ - CircuitBuilder, ConstraintSystem, RotationParams, expansion_expr, rotation_split, - }, + circuit_builder::{CircuitBuilder, ConstraintSystem, expansion_expr, rotation_split}, cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{ @@ -192,8 +190,8 @@ impl KeccakLayout { wits, // fixed, [ - sel_mem_read, - sel_mem_write, + sel_first, + sel_last, eq_zero, eq_rotation_left, eq_rotation_right, @@ -231,22 +229,21 @@ impl KeccakLayout { eq_rotation, }, selector_type_layout: SelectorTypeLayout { - sel_mem_read: SelectorType::OrderedSparse32 { + sel_first: Some(SelectorType::OrderedSparse { + num_vars: 5, indices: vec![CYCLIC_POW2_5[0] as usize], - expression: sel_mem_read.expr(), - }, - sel_mem_write: SelectorType::OrderedSparse32 { + 
expression: sel_first.expr(), + }), + sel_last: Some(SelectorType::OrderedSparse { + num_vars: 5, indices: vec![CYCLIC_POW2_5[ROUNDS - 1] as usize], - expression: sel_mem_write.expr(), - }, - sel_lookup: SelectorType::OrderedSparse32 { + expression: sel_last.expr(), + }), + sel_all: SelectorType::OrderedSparse { + num_vars: 5, indices: checked_indices.clone(), expression: eq_zero.expr(), }, - sel_zero: SelectorType::OrderedSparse32 { - indices: checked_indices, - expression: eq_zero.expr(), - }, }, input32_exprs: array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))), output32_exprs: array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))), @@ -506,15 +503,13 @@ impl ProtocolBuilder for KeccakLayout { // rotation constrain: rotation(keccak_input8).next() == keccak_output8 izip!(keccak_input8, keccak_output8) .for_each(|(input, output)| system.rotate_and_assert_eq(input.expr(), output.expr())); - system.set_rotation_params(RotationParams { - rotation_eqs: Some([ - layout.layer_exprs.eq_rotation_left.expr(), - layout.layer_exprs.eq_rotation_right.expr(), - layout.layer_exprs.eq_rotation.expr(), - ]), - rotation_cyclic_group_log2: ROUNDS_CEIL_LOG2, - rotation_cyclic_subgroup_size: ROUNDS - 1, - }); + system.set_rotation_params( + layout.layer_exprs.eq_rotation_left.expr(), + layout.layer_exprs.eq_rotation_right.expr(), + layout.layer_exprs.eq_rotation.expr(), + ROUNDS_CEIL_LOG2, + ROUNDS - 1, + ); Ok(layout) } @@ -525,10 +520,10 @@ impl ProtocolBuilder for KeccakLayout { self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_first.clone().unwrap()); + cb.cs.w_selector = 
Some(self.selector_type_layout.sel_last.clone().unwrap()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); @@ -688,30 +683,25 @@ where let bh = BooleanHypercube::new(ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); - let (mut sel_mem_read_iter, sel_mem_read_structural_witin) = ( - self.selector_type_layout - .sel_mem_read - .sparse32_indices() - .iter(), - self.selector_type_layout.sel_mem_read.selector_expr().id(), - ); - let (mut sel_mem_write_iter, sel_mem_write_structural_witin) = ( - self.selector_type_layout - .sel_mem_write - .sparse32_indices() - .iter(), - self.selector_type_layout.sel_mem_write.selector_expr().id(), + let Some(sel_first) = self.selector_type_layout.sel_first.as_ref() else { + panic!("sel_first must be Some"); + }; + let (mut sel_first_iter, sel_first_structural_witin) = ( + sel_first.sparse_indices().iter(), + sel_first.selector_expr().id(), ); - let (mut sel_lookup_iter, sel_lookup_structural_witin) = ( - self.selector_type_layout - .sel_lookup - .sparse32_indices() - .iter(), - self.selector_type_layout.sel_lookup.selector_expr().id(), + + let Some(sel_last) = self.selector_type_layout.sel_last.as_ref() else { + panic!("sel_last must be Some"); + }; + let (mut sel_last_iter, sel_last_structural_witin) = ( + sel_last.sparse_indices().iter(), + sel_last.selector_expr().id(), ); - let (mut sel_zero_iter, sel_zero_structural_witin) = ( - self.selector_type_layout.sel_zero.sparse32_indices().iter(), - self.selector_type_layout.sel_zero.selector_expr().id(), + + let (mut sel_all_iter, sel_all_structural_witin) = ( + self.selector_type_layout.sel_all.sparse_indices().iter(), + self.selector_type_layout.sel_all.selector_expr().id(), ); #[allow(clippy::needless_range_loop)] @@ -721,24 +711,19 @@ where &mut wits[round_index as usize * 
self.n_committed..][..self.n_committed]; // set selector - if let Some(index) = sel_mem_read_iter.next() { - structural_wits - [index * self.n_structural_witin + sel_mem_read_structural_witin] = - E::BaseField::ONE; - } - if let Some(index) = sel_mem_write_iter.next() { + if let Some(index) = sel_first_iter.next() { structural_wits - [index * self.n_structural_witin + sel_mem_write_structural_witin] = + [index * self.n_structural_witin + sel_first_structural_witin] = E::BaseField::ONE; } - if let Some(index) = sel_lookup_iter.next() { + if let Some(index) = sel_last_iter.next() { structural_wits - [index * self.n_structural_witin + sel_lookup_structural_witin] = + [index * self.n_structural_witin + sel_last_structural_witin] = E::BaseField::ONE; } - if let Some(index) = sel_zero_iter.next() { + if let Some(index) = sel_all_iter.next() { structural_wits - [index * self.n_structural_witin + sel_zero_structural_witin] = + [index * self.n_structural_witin + sel_all_structural_witin] = E::BaseField::ONE; } diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 609d25616..3d9a6e545 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -1,6 +1,7 @@ mod bitwise_keccakf; mod fptower; mod lookup_keccakf; +mod sha256; mod uint256; mod utils; mod weierstrass; @@ -23,6 +24,9 @@ pub use fptower::{ fp2_mul::{Fp2MulAssignLayout, Fp2MulInstance, Fp2MulTrace}, }; use gkr_iop::selector::SelectorType; +pub use sha256::{ + SHA_EXTEND_ROUNDS, ShaExtendInstance, ShaExtendLayout, ShaExtendTrace, ShaExtendWitInstance, +}; pub use uint256::{ Uint256InvLayout, Uint256InvSpec, Uint256InvTrace, Uint256MulInstance, Uint256MulLayout, Uint256MulTrace, run_uint256_mul, setup_uint256mul_gkr_circuit as setup_uint256_mul_circuit, @@ -46,8 +50,7 @@ pub use weierstrass::{ #[derive(Clone, Debug)] pub struct SelectorTypeLayout { - pub sel_mem_read: SelectorType, - pub sel_mem_write: SelectorType, - pub sel_lookup: SelectorType, - pub 
sel_zero: SelectorType, + pub sel_first: Option>, + pub sel_last: Option>, + pub sel_all: SelectorType, } diff --git a/ceno_zkvm/src/precompiles/sha256.rs b/ceno_zkvm/src/precompiles/sha256.rs new file mode 100644 index 000000000..f25b961a2 --- /dev/null +++ b/ceno_zkvm/src/precompiles/sha256.rs @@ -0,0 +1,4 @@ +mod extend; + +#[allow(unused_imports)] +pub use extend::*; diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs new file mode 100644 index 000000000..235e37b95 --- /dev/null +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -0,0 +1,585 @@ +// The crate sha extend circuit is modified from succinctlabs/sp1 under MIT license + +// The MIT License (MIT) + +// Copyright (c) 2023 Succinct Labs + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +use std::{array, borrow::BorrowMut, mem::size_of}; + +use derive::AlignedBorrow; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::{ + OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, + circuit_builder::CircuitBuilder, error::CircuitBuilderError, selector::SelectorType, +}; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use p3::field::{FieldAlgebra, TwoAdicField}; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + prelude::ParallelSlice, +}; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + chip_handler::MemoryExpr, + gadgets::{ + Add4Operation, FixedRotateRightOperation, FixedShiftRightOperation, Word, XorOperation, + }, + precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + witness::LkMultiplicity, +}; + +pub const SHA_EXTEND_ROUNDS: usize = 48; + +#[derive(Clone, Debug, AlignedBorrow)] +#[repr(C)] +pub struct ShaExtendWitCols { + /// Inputs to `s0`. + pub w_i_minus_15: Word, + pub w_i_minus_15_rr_7: FixedRotateRightOperation, + pub w_i_minus_15_rr_18: FixedRotateRightOperation, + pub w_i_minus_15_rs_3: FixedShiftRightOperation, + pub s0_intermediate: XorOperation, + + /// `s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3)`. + pub s0: XorOperation, + + /// Inputs to `s1`. + pub w_i_minus_2: Word, + pub w_i_minus_2_rr_17: FixedRotateRightOperation, + pub w_i_minus_2_rr_19: FixedRotateRightOperation, + pub w_i_minus_2_rs_10: FixedShiftRightOperation, + pub s1_intermediate: XorOperation, + + /// `s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10)`. + pub s1: XorOperation, + + /// Inputs to `s2`. + pub w_i_minus_16: Word, + pub w_i_minus_7: Word, + + /// `w[i] := w[i-16] + s0 + w[i-7] + s1`. 
+ pub s2: Add4Operation, +} + +impl ShaExtendWitCols { + #[allow(clippy::too_many_arguments)] + pub fn populate(&mut self, instance: &ShaExtendWitInstance, blu: &mut LkMultiplicity) { + // `s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift + // 3)`. + let w_i_minus_15 = instance.w_i_minus_15; + self.w_i_minus_15 = Word::from(w_i_minus_15); + let w_i_minus_15_rr_7 = self.w_i_minus_15_rr_7.populate(blu, w_i_minus_15, 7); + let w_i_minus_15_rr_18 = self.w_i_minus_15_rr_18.populate(blu, w_i_minus_15, 18); + let w_i_minus_15_rs_3 = self.w_i_minus_15_rs_3.populate(blu, w_i_minus_15, 3); + let s0_intermediate = + self.s0_intermediate + .populate(blu, w_i_minus_15_rr_7, w_i_minus_15_rr_18); + let s0 = self.s0.populate(blu, s0_intermediate, w_i_minus_15_rs_3); + + // `s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift + // 10)`. + let w_i_minus_2 = instance.w_i_minus_2; + self.w_i_minus_2 = Word::from(w_i_minus_2); + let w_i_minus_2_rr_17 = self.w_i_minus_2_rr_17.populate(blu, w_i_minus_2, 17); + let w_i_minus_2_rr_19 = self.w_i_minus_2_rr_19.populate(blu, w_i_minus_2, 19); + let w_i_minus_2_rs_10 = self.w_i_minus_2_rs_10.populate(blu, w_i_minus_2, 10); + let s1_intermediate = + self.s1_intermediate + .populate(blu, w_i_minus_2_rr_17, w_i_minus_2_rr_19); + let s1 = self.s1.populate(blu, s1_intermediate, w_i_minus_2_rs_10); + + // Compute `s2`. 
+ let w_i_minus_7 = instance.w_i_minus_7; + let w_i_minus_16 = instance.w_i_minus_16; + self.w_i_minus_7 = Word::from(w_i_minus_7); + self.w_i_minus_16 = Word::from(w_i_minus_16); + self.s2.populate(blu, w_i_minus_16, s0, w_i_minus_7, s1); + } +} + +#[derive(Clone, Debug)] +#[repr(C)] +pub struct ShaExtendLayer { + pub wits: ShaExtendWitCols, +} + +#[derive(Clone, Debug)] +pub struct ShaExtendLayout { + pub layer_exprs: ShaExtendLayer, + pub selector_type_layout: SelectorTypeLayout, + pub input32_exprs: [MemoryExpr; 4], + pub output32_expr: MemoryExpr, + pub n_fixed: usize, + pub n_committed: usize, + pub n_structural_witin: usize, + pub n_challenges: usize, +} + +impl ShaExtendLayout { + fn new(cb: &mut CircuitBuilder) -> Self { + let wits = ShaExtendWitCols { + w_i_minus_15: Word::create(cb, || "ShaExtendLayer::w_i_minus_15"), + w_i_minus_15_rr_7: FixedRotateRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_15_rr_7", + ), + w_i_minus_15_rr_18: FixedRotateRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_15_rr_18", + ), + w_i_minus_15_rs_3: FixedShiftRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_15_rs_3", + ), + s0_intermediate: XorOperation::create(cb, || "ShaExtendLayer::s0_intermediate"), + s0: XorOperation::create(cb, || "ShaExtendLayer::s0"), + w_i_minus_2: Word::create(cb, || "ShaExtendLayer::w_i_minus_2"), + w_i_minus_2_rr_17: FixedRotateRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_2_rr_17", + ), + w_i_minus_2_rr_19: FixedRotateRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_2_rr_19", + ), + w_i_minus_2_rs_10: FixedShiftRightOperation::create( + cb, + || "ShaExtendLayer::w_i_minus_2_rs_10", + ), + s1_intermediate: XorOperation::create(cb, || "ShaExtendLayer::s1_intermediate"), + s1: XorOperation::create(cb, || "ShaExtendLayer::s1"), + w_i_minus_16: Word::create(cb, || "ShaExtendLayer::w_i_minus_16"), + w_i_minus_7: Word::create(cb, || "ShaExtendLayer::w_i_minus_7"), + s2: 
Add4Operation::create(cb, || "ShaExtendLayer::s2"), + }; + + let sel_all = cb.create_placeholder_structural_witin(|| "sha_extend_sel_all"); + + let selector_type_layout = SelectorTypeLayout { + sel_first: None, + sel_last: None, + sel_all: SelectorType::::Prefix(sel_all.expr()), + }; + + let input32_exprs: [MemoryExpr; 4] = + array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))); + let output32_expr: MemoryExpr = array::from_fn(|_| Expression::WitIn(0)); + + Self { + layer_exprs: ShaExtendLayer { wits }, + selector_type_layout, + input32_exprs, + output32_expr, + n_fixed: 0, + n_committed: 0, + n_structural_witin: 6, + n_challenges: 0, + } + } +} + +impl ProtocolBuilder for ShaExtendLayout { + type Params = (); + + fn build_layer_logic( + cb: &mut CircuitBuilder, + _params: Self::Params, + ) -> Result { + let mut layout = Self::new(cb); + let wits = &layout.layer_exprs.wits; + + // Compute `s0`. + // w[i-15] rightrotate 7. + wits.w_i_minus_15_rr_7.eval(cb, wits.w_i_minus_15, 7)?; + // w[i-15] rightrotate 18. + wits.w_i_minus_15_rr_18.eval(cb, wits.w_i_minus_15, 18)?; + // w[i-15] rightshift 3. + wits.w_i_minus_15_rs_3.eval(cb, wits.w_i_minus_15, 3)?; + // (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) + wits.s0_intermediate.eval( + cb, + wits.w_i_minus_15_rr_7.value, + wits.w_i_minus_15_rr_18.value, + )?; + // s0 := (w[i-15] rightrotate 7) xor (w[i-15] rightrotate 18) xor (w[i-15] rightshift 3) + wits.s0 + .eval(cb, wits.s0_intermediate.value, wits.w_i_minus_15_rs_3.value)?; + + // Compute `s1`. + // w[i-2] rightrotate 17. + wits.w_i_minus_2_rr_17.eval(cb, wits.w_i_minus_2, 17)?; + // w[i-2] rightrotate 19. + wits.w_i_minus_2_rr_19.eval(cb, wits.w_i_minus_2, 19)?; + // w[i-2] rightshift 10. 
+ wits.w_i_minus_2_rs_10.eval(cb, wits.w_i_minus_2, 10)?; + // (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) + wits.s1_intermediate.eval( + cb, + wits.w_i_minus_2_rr_17.value, + wits.w_i_minus_2_rr_19.value, + )?; + // s1 := (w[i-2] rightrotate 17) xor (w[i-2] rightrotate 19) xor (w[i-2] rightshift 10) + wits.s1 + .eval(cb, wits.s1_intermediate.value, wits.w_i_minus_2_rs_10.value)?; + + // s2 (or w[i]):= w[i-16] + s0 + w[i-7] + s1. + wits.s2.eval( + cb, + wits.w_i_minus_16, + wits.s0.value, + wits.w_i_minus_7, + wits.s1.value, + )?; + + let mut input32_exprs = Vec::with_capacity(4); + for w in [ + &wits.w_i_minus_2, + &wits.w_i_minus_7, + &wits.w_i_minus_15, + &wits.w_i_minus_16, + ] { + merge_u8_slice_to_u16_limbs_pairs_and_extend(&w.0, &mut input32_exprs); + } + layout.input32_exprs = input32_exprs.try_into().unwrap(); + let mut output32_expr = Vec::with_capacity(1); + merge_u8_slice_to_u16_limbs_pairs_and_extend::(&wits.s2.value.0, &mut output32_expr); + layout.output32_expr = output32_expr.pop().unwrap(); + + Ok(layout) + } + + fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { + self.n_fixed = cb.cs.num_fixed; + self.n_committed = cb.cs.num_witin as usize; + self.n_structural_witin = cb.cs.num_structural_witin as usize; + self.n_challenges = 0; + + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + ( + [ + (0..r_len).collect_vec(), + (r_len..r_len + w_len).collect_vec(), + (r_len + w_len..r_len + w_len + lk_len).collect_vec(), + (0..zero_len).collect_vec(), + ], + 
Chip::new_from_cb(cb, self.n_challenges), + ) + } +} + +#[derive(Clone, Debug)] +pub struct ShaExtendWitInstance { + pub w_i_minus_2: u32, + pub w_i_minus_7: u32, + pub w_i_minus_15: u32, + pub w_i_minus_16: u32, +} + +#[derive(Clone, Debug)] +pub struct ShaExtendInstance { + pub witin: ShaExtendWitInstance, +} + +#[derive(Clone, Debug, Default)] +pub struct ShaExtendTrace { + pub instances: Vec, +} + +impl ProtocolWitnessGenerator for ShaExtendLayout { + type Trace = ShaExtendTrace; + + fn fixed_witness_group(&self) -> RowMajorMatrix { + RowMajorMatrix::new(0, 0, InstancePaddingStrategy::Default) + } + + fn phase1_witness_group( + &self, + phase1: Self::Trace, + wits: [&mut RowMajorMatrix; 2], + lk_multiplicity: &mut LkMultiplicity, + ) { + let (wits_start, num_wit_cols) = ( + self.layer_exprs.wits.w_i_minus_15.0[0].id as usize, + size_of::>(), + ); + let [wits, structural_wits] = wits; + let num_instances = wits.num_instances(); + let nthreads = max_usable_threads(); + let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); + let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); + raw_witin_iter + .zip_eq(raw_structural_wits_iter) + .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) + .for_each(|((rows, eqs), instances)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + rows.chunks_mut(self.n_committed) + .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + .zip_eq(instances.iter()) + .for_each(|((rows, eqs), phase1_instance)| { + let sel_all_structural_witin = + self.selector_type_layout.sel_all.selector_expr().id(); + eqs[sel_all_structural_witin] = E::BaseField::ONE; + + let cols: &mut ShaExtendWitCols = + rows[wits_start..][..num_wit_cols].borrow_mut(); + cols.populate(&phase1_instance.witin, &mut lk_multiplicity); + }); + }); + } +} + +#[cfg(test)] +mod tests { + use crate::circuit_builder::ConstraintSystem; + + use 
super::*; + + use std::sync::Arc; + + use ceno_emul::{SHA_EXTEND_WORDS, WORD_SIZE}; + use ff_ext::BabyBearExt4; + use gkr_iop::{ + cpu::{CpuBackend, CpuProver}, + gkr::{GKRProverOutput, layer::Layer}, + selector::SelectorContext, + }; + use itertools::Itertools; + use mpcs::BasefoldDefault; + use multilinear_extensions::{mle::PointAndEval, util::ceil_log2}; + use p3::{babybear::BabyBear, matrix::Matrix}; + use rand::{RngCore, SeedableRng, rngs::StdRng}; + use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + use sumcheck::util::optimal_sumcheck_threads; + use transcript::{BasicTranscript, Transcript}; + use witness::next_pow2_instance_padding; + + fn test_sha_extend_helper(num_instances: usize) { + type E = BabyBearExt4; + type Pcs = BasefoldDefault; + + let mut cs = ConstraintSystem::::new(|| "sha_extend_test"); + let mut cb = CircuitBuilder::::new(&mut cs); + let mut layout = + ShaExtendLayout::::build_layer_logic(&mut cb, ()).expect("build_layer_logic failed"); + let (out_evals, mut chip) = layout.finalize(&mut cb); + let layer = Layer::from_circuit_builder( + &cb, + "sha_extend".to_string(), + layout.n_challenges, + out_evals, + ); + chip.add_layer(layer); + let gkr_circuit = chip.gkr_circuit(); + + let mut rng = StdRng::seed_from_u64(1); + let mut instances = Vec::new(); + let mut expected_outputs = Vec::new(); + + for _ in 0..num_instances { + let mut words = [0u32; SHA_EXTEND_WORDS]; + for word in words.iter_mut().take(16) { + *word = rng.next_u32(); + } + sha_extend(&mut words); + for j in 16..SHA_EXTEND_ROUNDS + 16 { + instances.push(ShaExtendInstance { + witin: ShaExtendWitInstance { + w_i_minus_2: words[j - 2], + w_i_minus_7: words[j - 7], + w_i_minus_15: words[j - 15], + w_i_minus_16: words[j - 16], + }, + }); + } + expected_outputs.push(words[16..].to_vec()); + } + + let num_instances = num_instances * SHA_EXTEND_ROUNDS; + let mut phase1 = RowMajorMatrix::new( + num_instances, + layout.n_committed, + InstancePaddingStrategy::Default, + ); + 
let mut structural = RowMajorMatrix::new( + num_instances, + layout.n_structural_witin, + InstancePaddingStrategy::Default, + ); + let mut lk_multiplicity = LkMultiplicity::default(); + layout.phase1_witness_group( + ShaExtendTrace { instances }, + [&mut phase1, &mut structural], + &mut lk_multiplicity, + ); + + let out_index = layout.layer_exprs.wits.s2.value.0[0].id as usize; + for (instance_idx, expected_output) in expected_outputs.iter().enumerate() { + for (round_idx, expected_word_u32) in + expected_output.iter().take(SHA_EXTEND_ROUNDS).enumerate() + { + let row_idx = instance_idx * SHA_EXTEND_ROUNDS + round_idx; + let output_word: [_; WORD_SIZE] = phase1.row_slice(row_idx) + [out_index..out_index + 4] + .to_vec() + .try_into() + .unwrap(); + let expected_word = Word::::from(*expected_word_u32); + assert_eq!( + output_word, expected_word.0, + "mismatch at instance {}, round {}", + instance_idx, round_idx + ); + } + } + + let num_instances_rounds = next_pow2_instance_padding(num_instances); + let log2_num_instance_rounds = ceil_log2(num_instances_rounds); + let num_threads = optimal_sumcheck_threads(log2_num_instance_rounds); + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + let challenges = [ + prover_transcript.read_challenge().elements, + prover_transcript.read_challenge().elements, + ]; + + let phase1_witness_group = phase1.to_mles().into_iter().map(Arc::new).collect_vec(); + let structural_witness = structural.to_mles().into_iter().map(Arc::new).collect_vec(); + let fixed = layout + .fixed_witness_group() + .to_mles() + .into_iter() + .map(Arc::new) + .collect_vec(); + + let (gkr_witness, gkr_output) = + crate::scheme::utils::gkr_witness::, CpuProver<_>>( + &gkr_circuit, + &phase1_witness_group, + &structural_witness, + &fixed, + &[], + &[], + &challenges, + ); + + let out_evals = { + let mut point = Vec::with_capacity(log2_num_instance_rounds); + point.extend( + prover_transcript + .sample_vec(log2_num_instance_rounds) + .to_vec(), + 
); + + let out_evals = gkr_output + .0 + .par_iter() + .map(|wit| { + let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); + PointAndEval { + point: point.clone(), + eval: wit.evaluate(&point), + } + }) + .collect::>(); + + if out_evals.is_empty() { + vec![PointAndEval { + point: point[point.len() - log2_num_instance_rounds..point.len()].to_vec(), + eval: E::ZERO, + }] + } else { + out_evals + } + }; + + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 1]; + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove::, CpuProver<_>>( + num_threads, + log2_num_instance_rounds, + gkr_witness, + &out_evals, + &[], + &challenges, + &mut prover_transcript, + &selector_ctxs, + ) + .expect("sha extend prove failed"); + + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + let challenges = [ + verifier_transcript.read_challenge().elements, + verifier_transcript.read_challenge().elements, + ]; + let mut point = Vec::with_capacity(log2_num_instance_rounds); + point.extend( + verifier_transcript + .sample_vec(log2_num_instance_rounds) + .to_vec(), + ); + + gkr_circuit + .verify( + log2_num_instance_rounds, + gkr_proof, + &out_evals, + &[], + &[], + &challenges, + &mut verifier_transcript, + &selector_ctxs, + ) + .expect("sha extend verify failed"); + } + + #[test] + fn test_sha_extend() { + test_sha_extend_helper(4); + } + + #[test] + fn test_sha_extend_non_pow2() { + test_sha_extend_helper(5); + } + + fn sha_extend(w: &mut [u32; SHA_EXTEND_WORDS]) { + for i in 16..SHA_EXTEND_WORDS { + let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); + let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + .wrapping_add(s0) + .wrapping_add(w[i - 7]) + .wrapping_add(s1); + } + } +} diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index fcded1e80..e59e97c55 100644 --- 
a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -130,10 +130,9 @@ impl Uint256MulLayout { let eq = cb.create_placeholder_structural_witin(|| "uint256_mul_structural_witin"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; // Default expression, will be updated in build_layer_logic @@ -289,10 +288,10 @@ impl ProtocolBuilder for Uint256MulLayout { self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); @@ -433,10 +432,9 @@ impl Uint256InvLayout { let eq = cb.create_placeholder_structural_witin(|| "uint256_mul_structural_witin"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; // Default expression, will be updated in build_layer_logic @@ -541,10 +539,10 @@ impl ProtocolBuilder for Uint256InvL self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - 
cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 29155df42..012dcab80 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -135,10 +135,9 @@ impl WeierstrassAddAssignLayout { let eq = cb.create_placeholder_structural_witin(|| "weierstrass_add_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; // Default expression, will be updated in build_layer_logic @@ -334,10 +333,10 @@ impl ProtocolBuilder self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = 
Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 92edb98c3..d6400a2d7 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -153,10 +153,9 @@ impl let eq = cb.create_placeholder_structural_witin(|| "weierstrass_decompress_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; let input32_exprs: GenericArray< @@ -344,10 +343,10 @@ impl ProtocolBuild self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index aaaafa748..686baa397 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -137,10 +137,9 @@ impl let eq = 
cb.create_placeholder_structural_witin(|| "weierstrass_double_eq"); let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { - sel_mem_read: sel.clone(), - sel_mem_write: sel.clone(), - sel_lookup: sel.clone(), - sel_zero: sel.clone(), + sel_first: None, + sel_last: None, + sel_all: sel.clone(), }; let input32_exprs: GenericArray< @@ -362,10 +361,10 @@ impl ProtocolBuild self.n_challenges = 0; // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_mem_read.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_mem_write.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_lookup.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_zero.clone()); + cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); + cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); diff --git a/examples/examples/sha_extend_syscall.rs b/examples/examples/sha_extend_syscall.rs index 0ccd586eb..16d4290c2 100644 --- a/examples/examples/sha_extend_syscall.rs +++ b/examples/examples/sha_extend_syscall.rs @@ -5,7 +5,7 @@ use std::array; use ceno_syscall::syscall_sha256_extend; fn main() { - let mut words: [u32; 64] = array::from_fn(|i| i as u32); + let mut words = Box::new(array::from_fn(|i| i as u32)); let expected = [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, @@ -18,6 +18,8 @@ fn main() { 634956631, ]; - syscall_sha256_extend(&mut words); - assert_eq!(words, expected); + for i in 16..64 { + syscall_sha256_extend(&mut words[i]); + } + assert_eq!(*words, expected); } diff --git a/examples/examples/syscalls.rs b/examples/examples/syscalls.rs index ebbd35de7..4594214eb 100644 --- 
a/examples/examples/syscalls.rs +++ b/examples/examples/syscalls.rs @@ -152,7 +152,9 @@ pub fn test_syscalls() { 3334668051, 3166416553, 634956631, ]; - syscall_sha256_extend(&mut words); + for word in &mut words[16..64] { + syscall_sha256_extend(word); + } assert_eq!(words, expected); } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 088f1f42c..d84cb4d55 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -1149,6 +1149,31 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.logic_u8(LookupTable::Ltu, a, b, c) } + /// Assert that `a >> b = (shift, carry)`, where `a` is an 8-bit unsigned integer, `shift` and `carry` are 8-bit unsigned integers. b is in [0, 8). + pub fn lookup_shr_byte( + &mut self, + a: Expression, + b: usize, + shift: Expression, + carry: Expression, + ) -> Result<(), CircuitBuilderError> { + self.assert_double_u8( + || "lookup_shr_byte shift range check", + shift.expr(), + shift.expr() * (1 << b), + )?; + self.assert_double_u8( + || "lookup_shr_byte carry range check", + carry.expr(), + carry.expr() * (1 << (8 - b)), + )?; + self.require_equal( + || "lookup_shr_byte a == shift << b + carry", + a, + shift * (1 << b) + carry, + ) + } + // Assert that `2^b = c` and that `b` is a 5-bit unsigned integer. 
pub fn lookup_pow2( &mut self, @@ -1305,7 +1330,19 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { Ok(()) } - pub fn set_rotation_params(&mut self, params: RotationParams) { + pub fn set_rotation_params( + &mut self, + eq_rotation_left: Expression, + eq_rotation_right: Expression, + eq_rotation: Expression, + rotation_cyclic_group_log2: usize, + rotation_cyclic_subgroup_size: usize, + ) { + let params = RotationParams { + rotation_eqs: Some([eq_rotation_left, eq_rotation_right, eq_rotation]), + rotation_cyclic_group_log2, + rotation_cyclic_subgroup_size, + }; assert!(self.cs.rotation_params.is_none()); self.cs.rotation_params = Some(params); } diff --git a/gkr_iop/src/gkr/booleanhypercube.rs b/gkr_iop/src/gkr/booleanhypercube.rs index d6ba4c7c0..8bd89f59d 100644 --- a/gkr_iop/src/gkr/booleanhypercube.rs +++ b/gkr_iop/src/gkr/booleanhypercube.rs @@ -2,8 +2,6 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::Point; -const BH_MAX_NUM_VAR: usize = 5; - pub struct BooleanHypercube { num_vars: usize, } @@ -46,10 +44,80 @@ pub const CYCLIC_POW2_5: [u64; 32] = [ #[allow(dead_code)] const CYCLIC_POW2_5_MODULUS: u8 = 0b100101; // X^5 + X^2 + 1 +// 2^6-1 cyclic group +pub const CYCLIC_POW2_6: [u64; 64] = [ + 0b000001, // 0 = decimal 1 + 0b000010, // 1 = decimal 2 + 0b000100, // 2 = decimal 4 + 0b001000, // 3 = decimal 8 + 0b010000, // 4 = decimal 16 + 0b100000, // 5 = decimal 32 + 0b000011, // 6 = decimal 3 + 0b000110, // 7 = decimal 6 + 0b001100, // 8 = decimal 12 + 0b011000, // 9 = decimal 24 + 0b110000, // 10 = decimal 48 + 0b100011, // 11 = decimal 35 + 0b000101, // 12 = decimal 5 + 0b001010, // 13 = decimal 10 + 0b010100, // 14 = decimal 20 + 0b101000, // 15 = decimal 40 + 0b010011, // 16 = decimal 19 + 0b100110, // 17 = decimal 38 + 0b001111, // 18 = decimal 15 + 0b011110, // 19 = decimal 30 + 0b111100, // 20 = decimal 60 + 0b111011, // 21 = decimal 59 + 0b110101, // 22 = decimal 53 + 0b101001, // 23 = decimal 41 + 
0b010001, // 24 = decimal 17 + 0b100010, // 25 = decimal 34 + 0b000111, // 26 = decimal 7 + 0b001110, // 27 = decimal 14 + 0b011100, // 28 = decimal 28 + 0b111000, // 29 = decimal 56 + 0b110011, // 30 = decimal 51 + 0b100101, // 31 = decimal 37 + 0b001001, // 32 = decimal 9 + 0b010010, // 33 = decimal 18 + 0b100100, // 34 = decimal 36 + 0b001011, // 35 = decimal 11 + 0b010110, // 36 = decimal 22 + 0b101100, // 37 = decimal 44 + 0b011011, // 38 = decimal 27 + 0b110110, // 39 = decimal 54 + 0b101111, // 40 = decimal 47 + 0b011101, // 41 = decimal 29 + 0b111010, // 42 = decimal 58 + 0b110111, // 43 = decimal 55 + 0b101101, // 44 = decimal 45 + 0b011001, // 45 = decimal 25 + 0b110010, // 46 = decimal 50 + 0b100111, // 47 = decimal 39 + 0b001101, // 48 = decimal 13 + 0b011010, // 49 = decimal 26 + 0b110100, // 50 = decimal 52 + 0b101011, // 51 = decimal 43 + 0b010101, // 52 = decimal 21 + 0b101010, // 53 = decimal 42 + 0b010111, // 54 = decimal 23 + 0b101110, // 55 = decimal 46 + 0b011111, // 56 = decimal 31 + 0b111110, // 57 = decimal 62 + 0b111111, // 58 = decimal 63 + 0b111101, // 59 = decimal 61 + 0b111001, // 60 = decimal 57 + 0b110001, // 61 = decimal 49 + 0b100001, // 62 = decimal 33 + 0b000001, // 63 = decimal 1 +]; +#[allow(dead_code)] +const CYCLIC_POW2_6_MODULUS: u8 = 0b1000011; // X^6 + X + 1 + impl BooleanHypercube { // giving num_vars, cyclic group size is 2^num_vars - 1, as excluding 0 pub fn new(num_vars: usize) -> Self { - assert!(num_vars <= BH_MAX_NUM_VAR); + assert!(num_vars == 5 || num_vars == 6); Self { num_vars } } @@ -72,6 +140,23 @@ impl BooleanHypercube { .take(point.len()) .collect_vec(), ), + 6 => ( + // derive from CYCLIC_POW2_6_MODULUS + // left: (0, r0, r1, r2, r3, r4, r6, r7, ....) + std::iter::once(E::ZERO) + .chain(point[..5].iter().copied()) + .chain(point[6..].iter().copied()) + .take(point.len()) + .collect_vec(), + // right: (1, 1 - r0, r1, r2, r3, r4, r6, r7, ....) 
+ std::iter::once(E::ONE) + .chain(std::iter::once(E::ONE - point[0])) + .chain(std::iter::once(point[1])) + .chain(point[2..5].iter().copied()) + .chain(point[6..].iter().copied()) + .take(point.len()) + .collect_vec(), + ), num_vars => unimplemented!("not support {num_vars}"), } } @@ -86,6 +171,9 @@ impl BooleanHypercube { // rotated_eval = (1-r4) * left_eval + r4 * right_eval // right_eval = (rotated_eval - (1-r4) * left_eval) / r4 5 => (rotated_eval - (E::ONE - point[4]) * left_eval) / point[4], + // rotated_eval = (1-r5) * left_eval + r5 * right_eval + // right_eval = (rotated_eval - (1-r5) * left_eval) / r5 + 6 => (rotated_eval - (E::ONE - point[5]) * left_eval) / point[5], num_vars => unimplemented!("not support {num_vars}"), } } @@ -93,19 +181,22 @@ impl BooleanHypercube { impl IntoIterator for &BooleanHypercube { type Item = u64; - type IntoIter = std::array::IntoIter; + type IntoIter = std::iter::Copied>; fn into_iter(self) -> Self::IntoIter { match self.num_vars { - 5 => CYCLIC_POW2_5.into_iter(), - _ => unimplemented!(), + 5 => CYCLIC_POW2_5.as_slice().iter().copied(), + 6 => CYCLIC_POW2_6.as_slice().iter().copied(), + _ => panic!("not support {}", self.num_vars), } } } #[cfg(test)] mod tests { - use crate::gkr::booleanhypercube::CYCLIC_POW2_5_MODULUS; + use std::collections::HashSet; + + use crate::gkr::booleanhypercube::{CYCLIC_POW2_5_MODULUS, CYCLIC_POW2_6_MODULUS}; #[test] fn test_generate_f_31_cyclic_group_element() { @@ -124,8 +215,28 @@ mod tests { powers.push(current); } - for &_val in powers.iter() { - // println!("0b{:05b}, // {} = decimal {} ", val, i, val); + let set = powers.iter().cloned().collect::>(); + assert_eq!(set.len(), 31); // all elements are unique + } + + #[test] + fn test_generate_f_63_cyclic_group_element() { + let _x = 0b000010; // generator x = X + let mut powers = Vec::with_capacity(63); + powers.push(1); // x^0 = 1 + + let mut current = 1u8; + + for _ in 1..64 { + current <<= 1; // multiply by x (shift left) + if current & 
0b1000000 != 0 { + // degree 6 overflow + current ^= CYCLIC_POW2_6_MODULUS; // reduce modulo polynomial + } + powers.push(current); } + + let set = powers.iter().cloned().collect::>(); + assert_eq!(set.len(), 63); // all elements are unique } } diff --git a/gkr_iop/src/gkr/layer/gpu/utils.rs b/gkr_iop/src/gkr/layer/gpu/utils.rs index e7bd95050..dc0942074 100644 --- a/gkr_iop/src/gkr/layer/gpu/utils.rs +++ b/gkr_iop/src/gkr/layer/gpu/utils.rs @@ -78,6 +78,9 @@ pub fn build_eq_x_r_with_sel_gpu( SelectorType::OrderedSparse32 { indices, .. } => { (selector_ctx.num_instances, true, indices.clone()) } + SelectorType::OrderedSparse64 { .. } => { + unimplemented!("OrderedSparse64 is not supported in GPU selector path") + } SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(), }; diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 55d63e2d4..fadb6e690 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -542,7 +542,7 @@ pub fn extend_exprs_with_rotation( SelectorType::None => zero_check_expr, SelectorType::Whole(sel) | SelectorType::Prefix(sel) - | SelectorType::OrderedSparse32 { + | SelectorType::OrderedSparse { expression: sel, .. } | SelectorType::QuarkBinaryTreeLessThan(sel) => match_expr(sel) * zero_check_expr, diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 703386417..857dbd588 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -2,7 +2,7 @@ use std::iter::repeat_n; use rayon::iter::IndexedParallelIterator; -use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_than}; +use crate::utils::eq_eval_less_or_equal_than; use ff_ext::ExtensionField; use multilinear_extensions::{ Expression, WitnessId, @@ -48,8 +48,10 @@ pub enum SelectorType { /// Select part of the instances, other parts padded with a field element. 
Prefix(Expression), /// selector activates on the specified `indices`, which are assumed to be in ascending order. - /// each index corresponds to a position within a fixed-size chunk (e.g., size 32), - OrderedSparse32 { + /// each index corresponds to a position within a fixed-size chunk of `1 << num_vars` entries + /// (e.g., size 32 for `num_vars = 5`, size 64 for `num_vars = 6`). + OrderedSparse { + num_vars: usize, indices: Vec<usize>, expression: Expression, }, @@ -90,19 +92,20 @@ impl SelectorType { .into_mle(), ) } - SelectorType::OrderedSparse32 { + SelectorType::OrderedSparse { + num_vars, indices, expression: _, } => { - assert_eq!(ceil_log2(ctx.num_instances) + 5, ctx.num_vars); + assert_eq!(ceil_log2(ctx.num_instances) + num_vars, ctx.num_vars); Some( - (0..(1 << (ctx.num_vars - 5))) + (0..(1 << (ctx.num_vars - num_vars))) .into_par_iter() .flat_map(|chunk_index| { if chunk_index >= ctx.num_instances { - vec![E::ZERO; 32] + vec![E::ZERO; 1 << num_vars] } else { - let mut chunk = vec![E::ZERO; 32]; + let mut chunk = vec![E::ZERO; 1 << num_vars]; let mut indices_iter = indices.iter().copied(); let mut next_keep = indices_iter.next(); @@ -125,7 +128,6 @@ impl SelectorType { } } - /// Compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) pub fn compute( &self, out_point: &Point, @@ -152,15 +154,17 @@ impl SelectorType { sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } - // compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) - SelectorType::OrderedSparse32 { indices, .. } => { + // compute true and false mle eq(1; b[..num_vars]) * sel(y; b[num_vars..]), and eq(1; b[..num_vars]) * (eq() - sel(y; b[num_vars..])) + SelectorType::OrderedSparse { + indices, num_vars, ..
+ } => { assert_eq!( out_point.len(), - next_pow2_instance_padding(ctx.num_instances).ilog2() as usize + 5 + next_pow2_instance_padding(ctx.num_instances).ilog2() as usize + num_vars ); let mut sel = build_eq_x_r_vec(out_point); - sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) + sel.par_chunks_exact_mut(1 << num_vars) .enumerate() .for_each(|(chunk_index, chunk)| { if chunk_index >= ctx.num_instances { @@ -281,21 +285,22 @@ impl SelectorType { (expression, sel) } } - // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) - SelectorType::OrderedSparse32 { + // evaluate true mle \sum_{round} (eq(CYCLIC_POW2[round]; b[..num_vars])) * sel(y; b[num_vars..]) + SelectorType::OrderedSparse { + num_vars, indices, expression, } => { - let out_subgroup_eq = build_eq_x_r_vec(&out_point[..5]); - let in_subgroup_eq = build_eq_x_r_vec(&in_point[..5]); + let out_subgroup_eq = build_eq_x_r_vec(&out_point[..*num_vars]); + let in_subgroup_eq = build_eq_x_r_vec(&in_point[..*num_vars]); let mut eval = E::ZERO; for index in indices { eval += out_subgroup_eq[*index] * in_subgroup_eq[*index]; } let sel = eq_eval_less_or_equal_than( ctx.num_instances - 1, - &out_point[5..], - &in_point[5..], + &out_point[*num_vars..], + &in_point[*num_vars..], ); (expression, eval * sel) } @@ -358,16 +363,16 @@ impl SelectorType { } - /// return ordered indices of OrderedSparse32 - pub fn sparse32_indices(&self) -> &[usize] { + /// return ordered indices of OrderedSparse + pub fn sparse_indices(&self) -> &[usize] { match self { - Self::OrderedSparse32 { indices, .. } => indices, + Self::OrderedSparse { indices, .. } => indices, _ => panic!("invalid calling on non sparse type"), } } pub fn selector_expr(&self) -> &Expression { match self { - Self::OrderedSparse32 { expression, .. } + Self::OrderedSparse { expression, ..
} | Self::Whole(expression) | Self::Prefix(expression) => expression, e => unimplemented!("no selector expression in {:?}", e), diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index 5d1970297..e1c8d7453 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -16,6 +16,7 @@ use rayon::{ use crate::gkr::booleanhypercube::BooleanHypercube; +/// Given an mle defined over a boolean hypercube, left-rotate its base mle by one step according to the given boolean hypercube. pub fn rotation_next_base_mle<'a, E: ExtensionField>( bh: &BooleanHypercube, mle: &ArcMultilinearExtension<'a, E>, diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 55ff71f60..7dded4e70 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -202,6 +202,17 @@ impl LkMultiplicity { self.increment(LookupTable::DoubleU8, (a << 8) + b); } + pub fn assert_bytes(&mut self, vs: &[u8]) { + // process in pairs + for pair in vs.chunks(2) { + match pair { + [a, b] => self.assert_double_u8(*a as u64, *b as u64), + [a] => self.assert_double_u8(*a as u64, 0), + _ => {} + } + } + } + /// assert slices of field elements within range #[inline] pub fn assert_byte_fields(&mut self, vs: &[F]) { @@ -240,6 +251,11 @@ impl LkMultiplicity { self.logic_u8::(a, b) } + pub fn lookup_shr_byte(&mut self, shift: u64, carry: u64, nb_bits_to_shift: u64) { + self.assert_double_u8(shift, shift << nb_bits_to_shift); + self.assert_double_u8(carry, carry << (8 - nb_bits_to_shift)); + } + pub fn lookup_pow2(&mut self, v: u64) { self.logic_u8::(2, v) } diff --git a/guest_libs/sha2/src/sha256.rs b/guest_libs/sha2/src/sha256.rs index ab0d5ccc2..5cbde64b5 100644 --- a/guest_libs/sha2/src/sha256.rs +++ b/guest_libs/sha2/src/sha256.rs @@ -18,8 +18,10 @@ fn compress_u32(state: &mut [u32; 8], block: [u32; 16]) { let mut w = [0; 64]; w[..16].copy_from_slice(&block); - // Replace extend with a syscall - syscall_sha256_extend(&mut w); + // Replace 
extend with syscalls per round + for word in &mut w[16..64] { + syscall_sha256_extend(word); + } for i in 0..64 { let s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);