Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 47 additions & 22 deletions ceno_emul/src/syscalls/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use crate::{Change, EmuContext, Platform, Tracer, VMState, Word, WriteOp, utils::MemoryView};
use crate::{
ByteAddr, Change, EmuContext, Platform, Tracer, VMState, WORD_SIZE, Word, WordAddr, WriteOp,
utils::MemoryView,
};

use super::{SyscallEffects, SyscallSpec, SyscallWitness};

pub const SHA_EXTEND_WORDS: usize = 64; // u64 cells
pub const SHA_EXTEND_ROUND_MEM_OPS: usize = 5;

pub struct Sha256ExtendSpec;

impl SyscallSpec for Sha256ExtendSpec {
const NAME: &'static str = "SHA256_EXTEND";

const REG_OPS_COUNT: usize = 1;
const MEM_OPS_COUNT: usize = SHA_EXTEND_WORDS;
const MEM_OPS_COUNT: usize = SHA_EXTEND_ROUND_MEM_OPS;
const CODE: u32 = ceno_syscall::SHA_EXTEND;
}

Expand All @@ -29,37 +33,58 @@ impl From<ShaExtendWords> for [Word; SHA_EXTEND_WORDS] {
}
}

/// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/core/machine/src/syscall/precompiles/sha256/extend/mod.rs#L22
pub fn sha_extend(w: &mut [u32]) {
for i in 16..64 {
let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
w[i] = w[i - 16]
.wrapping_add(s0)
.wrapping_add(w[i - 7])
.wrapping_add(s1);
}
}

pub fn extend<T: Tracer>(vm: &VMState<T>) -> SyscallEffects {
let state_ptr = vm.peek_register(Platform::reg_arg0());

// Read the argument `state_ptr`.
let reg_ops = vec![WriteOp::new_register_op(
Platform::reg_arg0(),
Change::new(state_ptr, state_ptr),
0, // Cycle set later in finalize().
0,
)];

let mut state_view = MemoryView::<_, SHA_EXTEND_WORDS>::new(vm, state_ptr);
let mut sha_extend_words = ShaExtendWords::from(state_view.words());
sha_extend(&mut sha_extend_words.0);
let output_words: [Word; SHA_EXTEND_WORDS] = sha_extend_words.into();
let w_i_minus_2 = MemoryView::<_, 1>::new(vm, state_ptr - 2 * WORD_SIZE as u32).words()[0];
let w_i_minus_7 = MemoryView::<_, 1>::new(vm, state_ptr - 7 * WORD_SIZE as u32).words()[0];
let w_i_minus_15 = MemoryView::<_, 1>::new(vm, state_ptr - 15 * WORD_SIZE as u32).words()[0];
let w_i_minus_16 = MemoryView::<_, 1>::new(vm, state_ptr - 16 * WORD_SIZE as u32).words()[0];
let old_word = MemoryView::<_, 1>::new(vm, state_ptr).words()[0];

let s0 = w_i_minus_15.rotate_right(7) ^ w_i_minus_15.rotate_right(18) ^ (w_i_minus_15 >> 3);
let s1 = w_i_minus_2.rotate_right(17) ^ w_i_minus_2.rotate_right(19) ^ (w_i_minus_2 >> 10);
let new_word = w_i_minus_16
.wrapping_add(s0)
.wrapping_add(w_i_minus_7)
.wrapping_add(s1);

state_view.write(output_words);
let mem_ops = state_view.mem_ops().to_vec();
let base = ByteAddr::from(state_ptr).waddr();
let mem_ops = vec![
WriteOp {
addr: WordAddr(base.0 - 2),
value: Change::new(w_i_minus_2, w_i_minus_2),
previous_cycle: 0,
},
WriteOp {
addr: WordAddr(base.0 - 7),
value: Change::new(w_i_minus_7, w_i_minus_7),
previous_cycle: 0,
},
WriteOp {
addr: WordAddr(base.0 - 15),
value: Change::new(w_i_minus_15, w_i_minus_15),
previous_cycle: 0,
},
WriteOp {
addr: WordAddr(base.0 - 16),
value: Change::new(w_i_minus_16, w_i_minus_16),
previous_cycle: 0,
},
WriteOp {
addr: base,
value: Change::new(old_word, new_word),
previous_cycle: 0,
},
];

assert_eq!(mem_ops.len(), SHA_EXTEND_WORDS);
SyscallEffects {
witness: SyscallWitness::new(mem_ops, reg_ops),
next_pc: None,
Expand Down
48 changes: 26 additions & 22 deletions ceno_host/tests/test_elf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc};
use anyhow::Result;
use ceno_emul::{
BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind,
Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, SHA_EXTEND_WORDS,
StepRecord, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp,
Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord,
UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp,
host_utils::{read_all_messages, read_all_messages_as_words},
};
use ceno_host::CenoStdin;
Expand Down Expand Up @@ -466,19 +466,6 @@ fn test_secp256k1_ecrecover() -> Result<()> {
#[test]
fn test_sha256_extend() -> Result<()> {
let program_elf = ceno_examples::sha_extend_syscall;
let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?;

let steps = run(&mut state)?;
let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec();
assert_eq!(syscalls.len(), 1);

let witness = syscalls[0];
assert_eq!(witness.reg_ops.len(), 1);
assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0());

let state_ptr = witness.reg_ops[0].value.after;
assert_eq!(state_ptr, witness.reg_ops[0].value.before);
let state_ptr: WordAddr = state_ptr.into();

let expected = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200,
Expand All @@ -491,14 +478,31 @@ fn test_sha256_extend() -> Result<()> {
634956631,
];

assert_eq!(witness.mem_ops.len(), SHA_EXTEND_WORDS);
for round in 0..48 {
let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?;
let steps = run(&mut state)?;
let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec();
assert_eq!(syscalls.len(), 1);
let witness = syscalls[0];

for (i, write_op) in witness.mem_ops.iter().enumerate() {
assert_eq!(write_op.addr, state_ptr + i);
assert_eq!(write_op.value.after, expected[i]);
if i < 16 {
// sanity check: first 16 entries remain unchanged
assert_eq!(write_op.value.before, write_op.value.after);
assert_eq!(witness.reg_ops.len(), 1);
assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0());

let state_ptr = witness.reg_ops[0].value.before;
assert_eq!(state_ptr + WORD_SIZE as u32, witness.reg_ops[0].value.after);
let state_ptr: WordAddr = state_ptr.into();

assert_eq!(witness.mem_ops.len(), 5);

let offsets = [-2, -7, -15, -16, 0];
for (i, write_op) in witness.mem_ops.iter().enumerate() {
let mem_round_id = round + (16 + offsets[i]) as usize;
assert_eq!(write_op.addr, state_ptr + mem_round_id as u32);
if i < 4 {
assert_eq!(write_op.value.before, write_op.value.after);
} else {
assert_eq!(write_op.value.after, expected[mem_round_id]);
}
}
}

Expand Down
12 changes: 7 additions & 5 deletions ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1370,12 +1370,14 @@ pub fn evaluate_selector<C: Config>(

(expr, sel)
}
SelectorType::OrderedSparse32 {
SelectorType::OrderedSparse {
num_vars,
indices,
expression,
} => {
let out_point_slice = out_point.slice(builder, 0, 5);
let in_point_slice = in_point.slice(builder, 0, 5);
let num_vars = *num_vars;
let out_point_slice = out_point.slice(builder, 0, num_vars);
let in_point_slice = in_point.slice(builder, 0, num_vars);
let out_subgroup_eq = build_eq_x_r_vec_sequential(builder, &out_point_slice);
let in_subgroup_eq = build_eq_x_r_vec_sequential(builder, &in_point_slice);

Expand All @@ -1386,8 +1388,8 @@ pub fn evaluate_selector<C: Config>(
builder.assign(&eval, eval + out_val * in_val);
}

let out_point_slice = out_point.slice(builder, 5, out_point.len());
let in_point_slice = in_point.slice(builder, 5, in_point.len());
let out_point_slice = out_point.slice(builder, num_vars, out_point.len());
let in_point_slice = in_point.slice(builder, num_vars, in_point.len());
let n_bits = builder.get(&ctx.num_instances_bit_decomps, 0);

let sel =
Expand Down
1 change: 1 addition & 0 deletions ceno_zkvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ tracing.workspace = true
tracing-forest.workspace = true
tracing-subscriber.workspace = true

arrayref = "0.3.9"
bincode.workspace = true
cfg-if.workspace = true
clap.workspace = true
Expand Down
Loading
Loading