diff --git a/ceno_emul/src/chunked_vec.rs b/ceno_emul/src/chunked_vec.rs deleted file mode 100644 index 7355aa937..000000000 --- a/ceno_emul/src/chunked_vec.rs +++ /dev/null @@ -1,99 +0,0 @@ -use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use std::ops::{Index, IndexMut}; - -/// A growable vector divided into fixed-size chunks. -/// -/// This structure behaves similarly to a `Vec`, but allocates memory -/// in discrete chunks of a fixed size rather than continuously. -/// It is especially useful when the total number of elements is large -/// or not known in advance, as it avoids repeated reallocations. -/// -/// Conceptually, it can be seen as a "DENSE" map-like container where accessed -/// keys are comparable indices and values are stored in chunked segments. -/// -/// This layout is more cache-friendly when keys are accessed in increasing order. -#[derive(Default, Debug, Clone)] -pub struct ChunkedVec { - chunks: Vec>, - chunk_size: usize, - len: usize, -} - -impl ChunkedVec { - /// create a new ChunkedVec with a given chunk size. - pub fn new(chunk_size: usize) -> Self { - assert!(chunk_size > 0, "chunk_size must be > 0"); - Self { - chunks: Vec::new(), - chunk_size, - len: 0, - } - } - - /// get the current number of elements. - pub fn len(&self) -> usize { - self.len - } - - /// returns true if the vector is empty. - pub fn is_empty(&self) -> bool { - self.len == 0 - } - - /// access element by index (immutable). - pub fn get(&self, index: usize) -> Option<&T> { - if index >= self.len { - return None; - } - let chunk_idx = index / self.chunk_size; - let within_idx = index % self.chunk_size; - self.chunks.get(chunk_idx)?.get(within_idx) - } - - /// access element by index (mutable). - /// get mutable reference to element at index, auto-creating chunks as needed - pub fn get_or_create(&mut self, index: usize) -> &mut T { - let chunk_idx = index / self.chunk_size; - let within_idx = index % self.chunk_size; - - // Ensure enough chunks exist - if chunk_idx >= self.chunks.len() { - let to_create = chunk_idx + 1 - self.chunks.len(); - - // Use rayon to create all missing chunks in parallel - let mut new_chunks: Vec> = (0..to_create) - .map(|_| { - (0..self.chunk_size) - .into_par_iter() - .map(|_| Default::default()) - .collect::>() - }) - .collect(); - - self.chunks.append(&mut new_chunks); - } - - let chunk = &mut self.chunks[chunk_idx]; - - // Update the overall length - if index >= self.len { - self.len = index + 1; - } - - &mut chunk[within_idx] - } -} - -impl Index for ChunkedVec { - type Output = T; - - fn index(&self, index: usize) -> &Self::Output { - self.get(index).expect("index out of bounds") - } -} - -impl IndexMut for ChunkedVec { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - self.get_or_create(index) - } -} diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index c5d901d40..1e49ced29 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,8 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ Change, FullTracer, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, PreflightTracer, - ReadOp, StepRecord, Tracer, WriteOp, + PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepRecord, Tracer, + WriteOp, }; mod vm_state; @@ -54,7 +55,5 @@ pub use syscalls::{ pub mod utils; -pub mod test_utils; - -mod chunked_vec; pub mod host_utils; +pub mod test_utils; diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 22ac309af..64b7aa746 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,12 +1,12 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, - chunked_vec::ChunkedVec, dense_addr_space::DenseAddrSpace, encode_rv32, syscalls::{SyscallEffects, SyscallWitness}, }; use ceno_rt::WORD_SIZE; +use rustc_hash::FxHashMap; use smallvec::SmallVec; use std::{collections::BTreeMap, fmt, mem, sync::Arc}; @@ -39,9 +39,17 @@ pub struct StepRecord { syscall: Option, } +pub trait StepCellExtractor { + fn cells_for_kind(&self, kind: InsnKind, rs1_value: Option) -> u64; + + #[inline(always)] + fn extract_cells(&self, step: &StepRecord) -> u64 { + self.cells_for_kind(step.insn().kind, step.rs1().map(|op| op.value)) + } +} + pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; -pub type NextCycleAccess = ChunkedVec; -const ACCESSED_CHUNK_SIZE: usize = 1 << 20; +pub type NextCycleAccess = FxHashMap; fn init_mmio_min_max_access( platform: &Platform, @@ -79,6 +87,7 @@ fn init_mmio_min_max_access( pub trait Tracer { type Record; + type Config; const SUBCYCLE_RS1: Cycle = 0; const SUBCYCLE_RS2: Cycle = 1; @@ -86,14 +95,18 @@ pub trait Tracer { const SUBCYCLE_MEM: Cycle = 3; const SUBCYCLES_PER_INSN: Cycle = 4; - fn new(platform: &Platform) -> Self; + fn new(platform: &Platform, config: Self::Config) -> Self; - fn with_next_accesses(platform: &Platform, next_accesses: Option>) -> Self + fn with_next_accesses( + platform: &Platform, + config: Self::Config, + next_accesses: Option>, + ) -> Self where Self: Sized, { let _ = next_accesses; - Self::new(platform) + Self::new(platform, config) } fn advance(&mut self) -> Self::Record; @@ -235,6 +248,111 @@ impl LatestAccesses { } } +#[derive(Clone, Debug)] +pub struct ShardPlanBuilder { + shard_cycle_boundaries: Vec, + max_cell_per_shard: u64, + target_cell_first_shard: u64, + max_cycle_per_shard: Cycle, + current_shard_start_cycle: Cycle, + cur_cells: u64, + cur_cycle_in_shard: Cycle, + shard_id: usize, + finalized: bool, +} + +impl ShardPlanBuilder { + pub fn new(max_cell_per_shard: u64, max_cycle_per_shard: Cycle) -> Self { + let initial_cycle = FullTracer::SUBCYCLES_PER_INSN; + ShardPlanBuilder { + shard_cycle_boundaries: vec![initial_cycle], + max_cell_per_shard, + target_cell_first_shard: max_cell_per_shard, + max_cycle_per_shard, + current_shard_start_cycle: initial_cycle, + cur_cells: 0, + cur_cycle_in_shard: 0, + shard_id: 0, + finalized: false, + } + } + + pub fn current_shard_start_cycle(&self) -> Cycle { + self.current_shard_start_cycle + } + + pub fn shard_cycle_boundaries(&self) -> &[Cycle] { + &self.shard_cycle_boundaries + } + + pub fn max_cycle(&self) -> Cycle { + assert!(self.finalized, "shard plan not finalized yet"); + *self + .shard_cycle_boundaries + .last() + .expect("shard boundaries must contain at least one entry") + } + + pub fn into_cycle_boundaries(self) -> Vec { + assert!(self.finalized, "shard plan not finalized yet"); + self.shard_cycle_boundaries + } + + pub fn observe_step(&mut self, step_cycle: Cycle, step_cells: u64) { + assert!( + !self.finalized, + "shard plan cannot be extended after finalization" + ); + let target = if self.shard_id == 0 { + self.target_cell_first_shard + } else { + self.max_cell_per_shard + }; + + // always include step in current shard to simplify overall logic + self.cur_cells = self.cur_cells.saturating_add(step_cells); + self.cur_cycle_in_shard = self + .cur_cycle_in_shard + .saturating_add(FullTracer::SUBCYCLES_PER_INSN); + + let cycle_limit_hit = self.cur_cycle_in_shard >= self.max_cycle_per_shard; + let should_split = self.cur_cells >= target || cycle_limit_hit; + if should_split { + assert!( + self.cur_cells > 0 || self.cur_cycle_in_shard > 0, + "shard split before accumulating any steps" + ); + let next_shard_cycle = step_cycle + FullTracer::SUBCYCLES_PER_INSN; + self.push_boundary(next_shard_cycle); + self.shard_id += 1; + self.current_shard_start_cycle = next_shard_cycle; + self.cur_cells = 0; + self.cur_cycle_in_shard = 0; + } + } + + pub fn finalize(&mut self, max_cycle: Cycle) { + assert!( + !self.finalized, + "shard plan cannot be finalized multiple times" + ); + self.push_boundary(max_cycle); + self.finalized = true; + } + + fn push_boundary(&mut self, cycle: Cycle) { + if self + .shard_cycle_boundaries + .last() + .copied() + .unwrap_or_default() + != cycle + { + self.shard_cycle_boundaries.push(cycle); + } + } +} + #[cfg(any(test, debug_assertions))] pub struct LatestAccessIter<'a> { accesses: &'a LatestAccesses, @@ -760,14 +878,102 @@ impl FullTracer { } } -#[derive(Debug)] pub struct PreflightTracer { cycle: Cycle, pc: Change, + last_kind: InsnKind, + last_rs1: Option, mmio_min_max_access: Option>, latest_accesses: LatestAccesses, next_accesses: NextCycleAccess, register_reads_tracked: u8, + planner: Option, + current_shard_start_cycle: Cycle, + config: PreflightTracerConfig, +} + +#[derive(Clone)] +pub struct PreflightTracerConfig { + record_next_accesses: bool, + max_cell_per_shard: u64, + max_cycle_per_shard: Cycle, + step_cell_extractor: Option>, +} + +impl fmt::Debug for PreflightTracer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PreflightTracer") + .field("cycle", &self.cycle) + .field("pc", &self.pc) + .field("last_kind", &self.last_kind) + .field("last_rs1", &self.last_rs1) + .field("mmio_min_max_access", &self.mmio_min_max_access) + .field("latest_accesses", &self.latest_accesses) + .field("next_accesses", &self.next_accesses) + .field("register_reads_tracked", &self.register_reads_tracked) + .field("planner", &self.planner) + .field("current_shard_start_cycle", &self.current_shard_start_cycle) + .field("config", &self.config) + .finish() + } +} + +impl fmt::Debug for PreflightTracerConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PreflightTracerConfig") + .field("record_next_accesses", &self.record_next_accesses) + .field("max_cell_per_shard", &self.max_cell_per_shard) + .field("max_cycle_per_shard", &self.max_cycle_per_shard) + .field("step_cell_extractor", &self.step_cell_extractor.is_some()) + .finish() + } +} + +impl PreflightTracerConfig { + pub fn new( + record_next_accesses: bool, + max_cell_per_shard: u64, + max_cycle_per_shard: Cycle, + ) -> Self { + Self { + record_next_accesses, + max_cell_per_shard, + max_cycle_per_shard, + step_cell_extractor: None, + } + } + + pub fn record_next_accesses(&self) -> bool { + self.record_next_accesses + } + + pub fn max_cell_per_shard(&self) -> u64 { + self.max_cell_per_shard + } + + pub fn max_cycle_per_shard(&self) -> Cycle { + self.max_cycle_per_shard + } + + pub fn with_step_cell_extractor(mut self, extractor: Arc) -> Self { + self.step_cell_extractor = Some(extractor); + self + } + + pub fn step_cell_extractor(&self) -> Option> { + self.step_cell_extractor.clone() + } +} + +impl Default for PreflightTracerConfig { + fn default() -> Self { + Self { + record_next_accesses: true, + max_cell_per_shard: u64::MAX, + max_cycle_per_shard: Cycle::MAX, + step_cell_extractor: None, + } + } } impl PreflightTracer { @@ -777,19 +983,52 @@ impl PreflightTracer { pub const SUBCYCLE_MEM: Cycle = ::SUBCYCLE_MEM; pub const SUBCYCLES_PER_INSN: Cycle = ::SUBCYCLES_PER_INSN; - pub fn new(platform: &Platform) -> Self { + pub fn last_insn_kind(&self) -> InsnKind { + self.last_kind + } + + pub fn last_rs1_value(&self) -> Option { + self.last_rs1 + } + + pub fn new(platform: &Platform, config: PreflightTracerConfig) -> Self { + let mut planner_cycle_limit = config.max_cycle_per_shard(); + if planner_cycle_limit != Cycle::MAX { + // Observe-step already accounts for the current instruction, so shrink the + // limit by one instruction to keep shard boundaries aligned with callers. + planner_cycle_limit = planner_cycle_limit.saturating_sub(Self::SUBCYCLES_PER_INSN); + } + let max_cell_per_shard = config.max_cell_per_shard(); let mut tracer = PreflightTracer { cycle: ::SUBCYCLES_PER_INSN, pc: Default::default(), + last_kind: InsnKind::INVALID, + last_rs1: None, mmio_min_max_access: Some(init_mmio_min_max_access(platform)), latest_accesses: LatestAccesses::new(platform), - next_accesses: NextCycleAccess::new(ACCESSED_CHUNK_SIZE), + next_accesses: FxHashMap::default(), register_reads_tracked: 0, + planner: Some(ShardPlanBuilder::new( + max_cell_per_shard, + planner_cycle_limit, + )), + current_shard_start_cycle: ::SUBCYCLES_PER_INSN, + config, }; tracer.reset_register_tracking(); tracer } + pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess) { + let Some(mut planner) = self.planner else { + panic!("shard planner missing") + }; + if !planner.finalized { + planner.finalize(self.cycle); + } + (planner, self.next_accesses) + } + #[inline(always)] fn update_mmio_bounds(&mut self, addr: WordAddr) { if let Some((_, (_, end_addr, min_addr, max_addr))) = self @@ -817,13 +1056,25 @@ impl PreflightTracer { impl Tracer for PreflightTracer { type Record = (); + type Config = PreflightTracerConfig; - fn new(platform: &Platform) -> Self { - PreflightTracer::new(platform) + fn new(platform: &Platform, config: Self::Config) -> Self { + PreflightTracer::new(platform, config) } #[inline(always)] fn advance(&mut self) -> Self::Record { + if let Some(planner) = self.planner.as_mut() { + // compute whether next step should bump the cycle + let step_cells = self + .config + .step_cell_extractor + .as_ref() + .map(|extractor| extractor.cells_for_kind(self.last_kind, self.last_rs1)) + .unwrap_or(0); + planner.observe_step(self.cycle, step_cells); + self.current_shard_start_cycle = planner.current_shard_start_cycle(); + } self.cycle += Self::SUBCYCLES_PER_INSN; self.reset_register_tracking(); } @@ -838,8 +1089,10 @@ impl Tracer for PreflightTracer { } #[inline(always)] - fn fetch(&mut self, pc: WordAddr, _value: Instruction) { + fn fetch(&mut self, pc: WordAddr, value: Instruction) { self.pc.before = pc.baddr(); + self.last_kind = value.kind; + self.last_rs1 = None; } #[inline(always)] @@ -849,7 +1102,7 @@ impl Tracer for PreflightTracer { fn track_mmu_maxtouch_after(&mut self) {} #[inline(always)] - fn load_register(&mut self, idx: RegIdx, _value: Word) { + fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); let subcycle = match self.register_reads_tracked { 0 => Self::SUBCYCLE_RS1, @@ -857,6 +1110,9 @@ impl Tracer for PreflightTracer { _ => unimplemented!("Only two register reads are supported"), }; self.register_reads_tracked += 1; + if matches!(self.last_kind, InsnKind::ECALL) && idx == Platform::reg_ecall() { + self.last_rs1 = Some(value); + } self.track_access(addr, subcycle); } @@ -886,9 +1142,12 @@ impl Tracer for PreflightTracer { fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { let cur_cycle = self.cycle + subcycle; let prev_cycle = self.latest_accesses.track(addr, cur_cycle); - self.next_accesses - .get_or_create(prev_cycle as usize) - .push((addr, cur_cycle)); + if self.config.record_next_accesses && prev_cycle < self.current_shard_start_cycle { + self.next_accesses + .entry(prev_cycle) + .or_default() + .push((addr, cur_cycle)); + } prev_cycle } @@ -937,8 +1196,9 @@ impl Tracer for PreflightTracer { impl Tracer for FullTracer { type Record = StepRecord; + type Config = (); - fn new(platform: &Platform) -> Self { + fn new(platform: &Platform, _config: Self::Config) -> Self { FullTracer::new(platform) } diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index b7e66688f..514d6435b 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -6,7 +6,7 @@ use crate::{ platform::Platform, rv32im::{Instruction, TrapCause}, syscalls::{SyscallEffects, handle_syscall}, - tracer::{Change, FullTracer, NextCycleAccess, Tracer}, + tracer::{Change, FullTracer, Tracer}, }; use anyhow::{Result, anyhow}; use std::{iter::from_fn, ops::Deref, sync::Arc}; @@ -46,14 +46,17 @@ impl VMState { /// 32 architectural registers + 1 register RD_NULL for dark writes to x0. pub const REG_COUNT: usize = VM_REG_COUNT; - pub fn new_with_tracer(platform: Platform, program: Arc) -> Self { - Self::new_with_tracer_and_next_accesses(platform, program, None) + pub fn new_with_tracer(platform: Platform, program: Arc) -> Self + where + T::Config: Default, + { + Self::new_with_tracer_config(platform, program, T::Config::default()) } - pub fn new_with_tracer_and_next_accesses( + pub fn new_with_tracer_config( platform: Platform, program: Arc, - next_accesses: Option>, + config: T::Config, ) -> Self { let pc = program.entry; @@ -67,7 +70,7 @@ impl VMState { ), registers: [0; VM_REG_COUNT], halt_state: None, - tracer: T::with_next_accesses(&platform, next_accesses), + tracer: T::new(&platform, config), }; for (&addr, &value) in &program.image { @@ -77,7 +80,10 @@ impl VMState { vm } - pub fn new_from_elf_with_tracer(platform: Platform, elf: &[u8]) -> Result { + pub fn new_from_elf_with_tracer(platform: Platform, elf: &[u8]) -> Result + where + T::Config: Default, + { let program = Arc::new(Program::load_elf(elf, u32::MAX)?); let platform = Platform { prog_data: Arc::new(program.image.keys().copied().collect()), diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8813bbec1..8fb599f27 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1,6 +1,8 @@ use crate::{ error::ZKVMError, - instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, + instructions::riscv::{ + DummyExtraConfig, InstructionDispatchBuilder, MemPadder, MmuConfig, Rv32imConfig, + }, scheme::{ PublicValues, ZKVMProof, constants::SEPTIC_EXTENSION_DEGREE, @@ -22,8 +24,8 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, IterAddresses, NextCycleAccess, - Platform, PreflightTracer, Program, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, - WordAddr, host_utils::read_all_messages, + Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, StepRecord, + Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; use either::Either; @@ -55,8 +57,6 @@ pub const DEFAULT_MAX_CELLS_PER_SHARDS: u64 = (1 << 30) * 16 / 4 / 2; pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 29; pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20; // define a relative small number to make first shard handle much less instruction -pub const DEFAULT_MAX_CELL_FIRST_SHARD: u64 = 1 << 20; - /// The polynomial commitment scheme kind #[derive( Default, @@ -117,6 +117,7 @@ pub struct EmulationResult<'a> { pub final_mem_state: FinalMemState, pub pi: PublicValues, pub shard_ctx_builder: ShardContextBuilder, + pub shard_cycle_boundaries: Arc>, pub executed_steps: usize, pub phantom: PhantomData<&'a ()>, // pub shard_ctxs: Vec>, @@ -371,19 +372,17 @@ impl<'a> ShardContext<'a> { /// then `find_future_next_access(0xabc, 4)` returns `8`. #[inline(always)] pub fn find_future_next_access(&self, cycle: Cycle, addr: WordAddr) -> Option { - self.addr_future_accesses - .get(cycle as usize) - .and_then(|res| { - if res.len() == 1 { - Some(res[0].1) - } else if res.len() > 1 { - res.iter() - .find(|(m_addr, _)| *m_addr == addr) - .map(|(_, cycle)| *cycle) - } else { - None - } - }) + self.addr_future_accesses.get(&cycle).and_then(|res| { + if res.len() == 1 && res[0].0 == addr { + Some(res[0].1) + } else if res.len() > 1 { + res.iter() + .find(|(m_addr, _)| *m_addr == addr) + .map(|(_, cycle)| *cycle) + } else { + None + } + }) } #[inline(always)] @@ -571,24 +570,44 @@ impl<'a> ShardContext<'a> { } } -pub trait StepCellExtractor { - fn extract_cells(&self, step: &StepRecord) -> u64; +#[derive(Clone, Copy, Debug, Default)] +pub struct ShardStepSummary { + pub step_count: usize, + pub first_cycle: Cycle, + pub last_cycle: Cycle, + pub first_pc_before: Addr, + pub last_pc_after: Addr, + pub first_heap_before: Addr, + pub last_heap_after: Addr, + pub first_hint_before: Addr, + pub last_hint_after: Addr, +} + +impl ShardStepSummary { + fn update(&mut self, step: &StepRecord) { + if self.step_count == 0 { + self.first_cycle = step.cycle(); + self.first_pc_before = step.pc().before.0; + self.first_heap_before = step.heap_maxtouch_addr.before.0; + self.first_hint_before = step.hint_maxtouch_addr.before.0; + } + self.step_count += 1; + self.last_cycle = step.cycle(); + self.last_pc_after = step.pc().after.0; + self.last_heap_after = step.heap_maxtouch_addr.after.0; + self.last_hint_after = step.hint_maxtouch_addr.after.0; + } } pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, - cur_cells: u64, - cur_acc_cycle: Cycle, - max_cell_per_shard: u64, - max_cycle_per_shard: Cycle, - target_cell_first_shard: u64, prev_shard_cycle_range: Vec, prev_shard_heap_range: Vec, prev_shard_hint_range: Vec, - // holds the first step for the next shard once the current shard hits its limit - pending_step: Option, platform: Platform, + shard_cycle_boundaries: Arc>, + max_cycle: Cycle, } impl Default for ShardContextBuilder { @@ -596,148 +615,116 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), - cur_cells: 0, - cur_acc_cycle: 0, - max_cell_per_shard: 0, - max_cycle_per_shard: 0, - target_cell_first_shard: 0, prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], - pending_step: None, platform: CENO_PLATFORM.clone(), + shard_cycle_boundaries: Arc::new(vec![FullTracer::SUBCYCLES_PER_INSN]), + max_cycle: 0, } } } impl ShardContextBuilder { - /// set max_cell_per_shard == u64::MAX if target for single shard - pub fn new( + pub fn from_plan( multi_prover: &MultiProver, platform: Platform, + shard_cycle_boundaries: Arc>, + max_cycle: Cycle, addr_future_accesses: NextCycleAccess, ) -> Self { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); ShardContextBuilder { cur_shard_id: 0, - cur_cells: 0, - cur_acc_cycle: 0, - max_cell_per_shard: multi_prover.max_cell_per_shard, - max_cycle_per_shard: multi_prover.max_cycle_per_shard, - target_cell_first_shard: { - if multi_prover.max_cell_per_shard == u64::MAX { - u64::MAX - } else { - multi_prover.max_cell_per_shard - } - }, addr_future_accesses: Arc::new(addr_future_accesses), prev_shard_cycle_range: vec![0], prev_shard_heap_range: vec![0], prev_shard_hint_range: vec![0], - pending_step: None, platform, + shard_cycle_boundaries, + max_cycle, } } + pub fn shard_cycle_boundaries(&self) -> Arc> { + self.shard_cycle_boundaries.clone() + } + + pub fn total_shards(&self) -> usize { + self.shard_cycle_boundaries.len().saturating_sub(1) + } + pub fn position_next_shard<'a>( &mut self, steps_iter: &mut impl Iterator, - step_cell_extractor: impl StepCellExtractor, - steps: &mut Vec, - ) -> Option> { - steps.clear(); - let target_cost_current_shard = if self.cur_shard_id == 0 { - self.target_cell_first_shard - } else { - self.max_cell_per_shard - }; - loop { - let step = if let Some(step) = self.pending_step.take() { - step - } else { - match steps_iter.next() { - Some(step) => step, - None => break, - } - }; - let next_cells = self.cur_cells + step_cell_extractor.extract_cells(&step); - let next_cycle = self.cur_acc_cycle + FullTracer::SUBCYCLES_PER_INSN; - if next_cells >= target_cost_current_shard || next_cycle >= self.max_cycle_per_shard { - assert!( - !steps.is_empty(), - "empty record match when splitting shards" - ); - self.pending_step = Some(step); + mut on_step: impl FnMut(StepRecord), + ) -> Option<(ShardContext<'a>, ShardStepSummary)> { + if self.cur_shard_id >= self.total_shards() { + return None; + } + let expected_end_cycle = self + .shard_cycle_boundaries + .get(self.cur_shard_id + 1) + .copied() + .expect("missing shard boundary for shard"); + let mut summary = ShardStepSummary::default(); + for step in steps_iter.by_ref() { + summary.update(&step); + on_step(step); + if summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN == expected_end_cycle { break; } - self.cur_cells = next_cells; - self.cur_acc_cycle = next_cycle; - steps.push(step); } - if steps.is_empty() { + if summary.step_count == 0 { return None; } + assert_eq!( + summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN, + expected_end_cycle, + "shard {} did not end on expected boundary", + self.cur_shard_id + ); + if self.cur_shard_id > 0 { assert_eq!( - steps.first().map(|step| step.cycle()).unwrap_or_default(), + summary.first_cycle, self.prev_shard_cycle_range .last() .copied() .unwrap_or(FullTracer::SUBCYCLES_PER_INSN) ); assert_eq!( - steps - .first() - .map(|step| step.heap_maxtouch_addr.before) - .unwrap_or_default(), + summary.first_heap_before, self.prev_shard_heap_range .last() .copied() .unwrap_or(self.platform.heap.start) - .into() ); assert_eq!( - steps - .first() - .map(|step| step.hint_maxtouch_addr.before) - .unwrap_or_default(), + summary.first_hint_before, self.prev_shard_hint_range .last() .copied() .unwrap_or(self.platform.hints.start) - .into() ); } let shard_ctx = ShardContext { shard_id: self.cur_shard_id, - cur_shard_cycle_range: steps.first().map(|step| step.cycle() as usize).unwrap() - ..(steps.last().unwrap().cycle() + FullTracer::SUBCYCLES_PER_INSN) as usize, + num_shards: self.total_shards(), + max_cycle: self.max_cycle, + cur_shard_cycle_range: summary.first_cycle as usize + ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), platform: self.platform.clone(), - shard_heap_addr_range: steps - .first() - .map(|step| step.heap_maxtouch_addr.before.0) - .unwrap_or_default() - ..steps - .last() - .map(|step| step.heap_maxtouch_addr.after.0) - .unwrap_or_default(), - shard_hint_addr_range: steps - .first() - .map(|step| step.hint_maxtouch_addr.before.0) - .unwrap_or_default() - ..steps - .last() - .map(|step| step.hint_maxtouch_addr.after.0) - .unwrap_or_default(), + shard_heap_addr_range: summary.first_heap_before..summary.last_heap_after, + shard_hint_addr_range: summary.first_hint_before..summary.last_hint_after, ..Default::default() }; self.prev_shard_cycle_range @@ -746,11 +733,9 @@ impl ShardContextBuilder { .push(shard_ctx.shard_heap_addr_range.end); self.prev_shard_hint_range .push(shard_ctx.shard_hint_addr_range.end); - self.cur_cells = 0; - self.cur_acc_cycle = 0; self.cur_shard_id += 1; - Some(shard_ctx) + Some((shard_ctx, summary)) } } @@ -807,6 +792,7 @@ pub fn emulate_program<'a>( init_mem_state: &InitMemState, platform: &Platform, multi_prover: &MultiProver, + step_cell_extractor: Arc, ) -> EmulationResult<'a> { let InitMemState { mem: mem_init, @@ -817,20 +803,40 @@ pub fn emulate_program<'a>( heap: _, } = init_mem_state; - let mut vm: VMState = VMState::new_with_tracer(platform.clone(), program); + let tracer_config = PreflightTracerConfig::new( + true, + multi_prover.max_cell_per_shard, + multi_prover.max_cycle_per_shard, + ) + .with_step_cell_extractor(step_cell_extractor); + let mut vm: VMState = info_span!("[ceno] emulator.new-preflight-tracer") + .in_scope(|| { + VMState::new_with_tracer_config(platform.clone(), program.clone(), tracer_config) + }); - for record in chain!(hints_init, io_init) { - vm.init_memory(record.addr.into(), record.value); - } + info_span!("[ceno] emulator.init_mem").in_scope(|| { + for record in chain!(hints_init, io_init) { + vm.init_memory(record.addr.into(), record.value); + } + }); - let exit_code = info_span!("[ceno] emulator.preflight-execute").in_scope(|| { - vm.iter_until_halt() - .take(max_steps) - .try_for_each(|step| step.map(|_| ())) - .unwrap_or_else(|err| panic!("emulator trapped before halt: {err}")); + let exit_code = info_span!("[ceno] preflight-execute").in_scope(|| { + let mut steps = 0usize; + loop { + if steps >= max_steps { + break; + } + match vm.next_step_record() { + Ok(Some(_)) => { + steps += 1; + } + Ok(None) => break, + Err(err) => panic!("emulator trapped before halt: {err}"), + } + } vm.halted_state().map(|halt_state| halt_state.exit_code) }); - + let max_cycle = vm.tracer().cycle(); if platform.is_debug { let all_messages = read_all_messages(&vm) .iter() @@ -994,16 +1000,28 @@ pub fn emulate_program<'a>( ); } - let shard_ctx_builder = ShardContextBuilder::new( + let tracer = vm.take_tracer(); + let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); + let shard_ctx_builder = ShardContextBuilder::from_plan( multi_prover, platform.clone(), - vm.take_tracer().into_next_accesses(), + shard_cycle_boundaries.clone(), + max_cycle, + next_accesses, + ); + tracing::info!( + "num_shards: {}, max_cycle {}, shard_cycle_boundaries {:?}", + shard_ctx_builder.total_shards(), + max_cycle, + shard_cycle_boundaries.as_ref() ); EmulationResult { pi, exit_code, shard_ctx_builder, + shard_cycle_boundaries: shard_cycle_boundaries.clone(), executed_steps: insts, final_mem_state: FinalMemState { reg: reg_final, @@ -1123,7 +1141,8 @@ pub fn init_static_addrs(program: &Program) -> Vec { pub struct ConstraintSystemConfig { pub zkvm_cs: ZKVMConstraintSystem, - pub config: Rv32imConfig, + pub config: Arc>, + pub inst_dispatch_builder: InstructionDispatchBuilder, pub mmu_config: MmuConfig, pub dummy_config: DummyExtraConfig, pub prog_config: ProgramTableConfig, @@ -1134,14 +1153,15 @@ pub fn construct_configs( ) -> ConstraintSystemConfig { let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); - let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let (config, inst_dispatch_builder) = Rv32imConfig::::construct_circuits(&mut zkvm_cs); let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); let prog_config = zkvm_cs.register_table_circuit::>(); zkvm_cs.register_global_state::(); ConstraintSystemConfig { zkvm_cs, - config, + config: Arc::new(config), + inst_dispatch_builder, mmu_config, dummy_config, prog_config, @@ -1195,6 +1215,7 @@ pub fn generate_witness<'a, E: ExtensionField>( "execution trace must contain at least one step" ); + let mut instrunction_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); let pi_template = emul_result.pi.clone(); let mut step_iter = StepReplay::new( platform.clone(), @@ -1202,29 +1223,28 @@ pub fn generate_witness<'a, E: ExtensionField>( init_mem_state, emul_result.executed_steps, ); - let mut shard_steps = Vec::new(); - std::iter::from_fn(move || { info_span!( "[ceno] app_prove.generate_witness", shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let mut shard_ctx = match shard_ctx_builder.position_next_shard( + instrunction_dispatch_ctx.begin_shard(); + let (mut shard_ctx, shard_summary) = match shard_ctx_builder.position_next_shard( &mut step_iter, - &system_config.config, - &mut shard_steps, + |step| instrunction_dispatch_ctx.ingest_step(step), ) { - Some(ctx) => ctx, + Some(result) => result, None => return None, }; let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); tracing::debug!( - "{}th shard collect {} steps, heap_addr_range {:x} - {:x}, hint_addr_range {:x} - {:x}", + "{}th shard collect {} steps, cycles range {:?}, heap_addr_range {:x} - {:x}, hint_addr_range {:x} - {:x}", shard_ctx.shard_id, - shard_steps.len(), + shard_summary.step_count, + shard_ctx.cur_shard_cycle_range, shard_ctx.shard_heap_addr_range.start, shard_ctx.shard_heap_addr_range.end, shard_ctx.shard_hint_addr_range.start, @@ -1232,15 +1252,14 @@ pub fn generate_witness<'a, E: ExtensionField>( ); let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let last_step = shard_steps.last().expect("shard must contain steps"); - let current_shard_end_cycle = - last_step.cycle() + FullTracer::SUBCYCLES_PER_INSN - current_shard_offset_cycle; + let current_shard_end_cycle = shard_summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN + - current_shard_offset_cycle; let current_shard_init_pc = if shard_ctx.is_first_shard() { program.entry } else { - shard_steps.first().unwrap().pc().before.0 + shard_summary.first_pc_before }; - let current_shard_end_pc = last_step.pc().after.0; + let current_shard_end_pc = shard_summary.last_pc_after; pi.init_pc = current_shard_init_pc; pi.init_cycle = FullTracer::SUBCYCLES_PER_INSN; @@ -1267,13 +1286,13 @@ pub fn generate_witness<'a, E: ExtensionField>( } let time = std::time::Instant::now(); - let dummy_records = system_config + system_config .config .assign_opcode_circuit( &system_config.zkvm_cs, &mut shard_ctx, + &mut instrunction_dispatch_ctx, &mut zkvm_witness, - &shard_steps, ) .unwrap(); tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); @@ -1283,8 +1302,8 @@ pub fn generate_witness<'a, E: ExtensionField>( .assign_opcode_circuit( &system_config.zkvm_cs, &mut shard_ctx, + &instrunction_dispatch_ctx, &mut zkvm_witness, - dummy_records, ) .unwrap(); tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); @@ -1375,7 +1394,6 @@ pub fn generate_witness<'a, E: ExtensionField>( "assign_dynamic_init_table_circuit finish in {:?}", time.elapsed() ); - let time = std::time::Instant::now(); system_config .mmu_config @@ -1655,12 +1673,16 @@ pub fn run_e2e_with_checkpoint< // Emulate program let start = std::time::Instant::now(); + let raw_step_cell_extractor = + Arc::clone(&prover.pk.program_ctx.as_ref().unwrap().system_config.config); + let step_cell_extractor: Arc = raw_step_cell_extractor; let emul_result = emulate_program( prover.pk.program_ctx.as_ref().unwrap().program.clone(), max_steps, &init_full_mem, &prover.pk.program_ctx.as_ref().unwrap().platform, &prover.pk.program_ctx.as_ref().unwrap().multi_prover, + step_cell_extractor, ); tracing::debug!("emulate done in {:?}", start.elapsed()); @@ -1745,12 +1767,15 @@ pub fn run_e2e_proof< ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); // Emulate program + let raw_step_cell_extractor = Arc::clone(&ctx.system_config.config); + let step_cell_extractor: Arc = raw_step_cell_extractor; let emul_result = emulate_program( ctx.program.clone(), max_steps, init_full_mem, &ctx.platform, &ctx.multi_prover, + step_cell_extractor, ); create_proofs_streaming( emul_result, @@ -2049,17 +2074,10 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { - use crate::e2e::{MultiProver, ShardContextBuilder, StepCellExtractor}; + use crate::e2e::{MultiProver, ShardContextBuilder}; use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepRecord}; use itertools::Itertools; - - struct UniformStepExtractor; - - impl StepCellExtractor for &UniformStepExtractor { - fn extract_cells(&self, _step: &StepRecord) -> u64 { - 1 - } - } + use std::sync::Arc; #[test] fn test_single_prover_shard_ctx() { @@ -2087,23 +2105,42 @@ mod tests { executed_instruction: usize, expected_shard: usize, ) { - let mut shard_ctx_builder = ShardContextBuilder::new( + let steps = (0..executed_instruction) + .map(|i| { + StepRecord::new_ecall_any(FullTracer::SUBCYCLES_PER_INSN * (i + 1) as u64, 0.into()) + }) + .collect_vec(); + let max_cycle = steps + .last() + .map(|step| step.cycle() + FullTracer::SUBCYCLES_PER_INSN) + .unwrap_or(FullTracer::SUBCYCLES_PER_INSN); + let shard_cycle_boundaries = { + let mut boundaries = vec![FullTracer::SUBCYCLES_PER_INSN]; + let mut cur_cycle_in_shard = 0; + for step in &steps { + let next_cycle = cur_cycle_in_shard + FullTracer::SUBCYCLES_PER_INSN; + if max_cycle_per_shard < Cycle::MAX && next_cycle >= max_cycle_per_shard { + boundaries.push(step.cycle()); + cur_cycle_in_shard = FullTracer::SUBCYCLES_PER_INSN; + } else { + cur_cycle_in_shard = next_cycle; + } + } + boundaries.push(max_cycle); + Arc::new(boundaries) + }; + let mut shard_ctx_builder = ShardContextBuilder::from_plan( &MultiProver::new(0, 1, u64::MAX, max_cycle_per_shard), CENO_PLATFORM.clone(), + shard_cycle_boundaries, + max_cycle, NextCycleAccess::default(), ); - - let mut steps_iter = (0..executed_instruction).map(|i| { - StepRecord::new_ecall_any(FullTracer::SUBCYCLES_PER_INSN * (i + 1) as u64, 0.into()) - }); - let mut steps = Vec::new(); - + let mut steps_iter = steps.into_iter(); let shard_ctx = std::iter::from_fn(|| { - shard_ctx_builder.position_next_shard( - &mut steps_iter, - &UniformStepExtractor {}, - &mut steps, - ) + shard_ctx_builder + .position_next_shard(&mut steps_iter, |_| {}) + .map(|(ctx, _)| ctx) }) .collect_vec(); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 6afd64b45..df2c24ff9 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -23,11 +23,14 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; + type InsnType: Clone + Copy; fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } + fn inst_kinds() -> &'static [Self::InsnType]; + fn name() -> String; /// construct circuit and manipulate circuit builder, then return the respective config @@ -98,7 +101,7 @@ pub trait Instruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // TODO: selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 69c656148..c77b707b4 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -2,7 +2,7 @@ use ceno_emul::InsnKind; mod rv32im; pub use rv32im::{ - DummyExtraConfig, Rv32imConfig, + DummyExtraConfig, InstructionDispatchBuilder, InstructionDispatchCtx, Rv32imConfig, mmu::{MemPadder, MmuConfig}, }; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 1bc0768d2..260245931 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -34,6 +34,11 @@ pub type SubInstruction = ArithInstruction; impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -190,7 +195,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 7afb65d7f..a1c1d4403 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -67,7 +67,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 11d93242c..909171986 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -11,7 +11,7 @@ use crate::{ tables::InsnRecord, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; @@ -27,6 +27,11 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ADDI] + } fn name() -> String { format!("{:?}", Self::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 8ed175d58..027483d1e 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -11,7 +11,7 @@ use crate::{ utils::{imm_sign_extend, imm_sign_extend_circuit}, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -32,6 +32,11 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ADDI] + } fn name() -> String { format!("{:?}", Self::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 1e984546d..ce7c64a95 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -37,6 +37,11 @@ pub struct AuipcInstruction(PhantomData); impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::AUIPC] + } fn name() -> String { format!("{:?}", InsnKind::AUIPC) @@ -245,7 +250,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 2c97a12ee..3622dad73 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -35,12 +35,17 @@ pub struct BranchConfig { } impl Instruction for BranchCircuit { + type InstructionConfig = BranchConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } + fn name() -> String { format!("{:?}", I::INST_KIND) } - type InstructionConfig = BranchConfig; - fn construct_circuit( circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index a6aa1edc4..85ef6914b 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -39,6 +39,11 @@ pub struct BranchConfig { impl Instruction for BranchCircuit { type InstructionConfig = BranchConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 286a60432..67f098ff0 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -41,7 +41,7 @@ fn impl_opcode_beq(take_branch: bool, a: u32, b: u32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), insn_code, @@ -83,7 +83,7 @@ fn impl_opcode_bne(take_branch: bool, a: u32, b: u32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), insn_code, @@ -127,7 +127,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -172,7 +172,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -224,7 +224,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -276,7 +276,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index dda09370a..85718ad24 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -64,7 +64,7 @@ mod test { scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; - use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; + use ceno_emul::{Change, StepRecord, encode_rv32}; #[cfg(feature = "u16limb_circuit")] use ff_ext::BabyBearExt4 as BE; use ff_ext::{ExtensionField, GoldilocksExt2 as GE}; @@ -84,7 +84,6 @@ mod test { fn output(config: Self::InstructionConfig) -> UInt; // the correct/expected value for given parameters fn correct(dividend: Self::NumType, divisor: Self::NumType) -> Self::NumType; - const INSN_KIND: InsnKind; } impl TestInstance for DivInstruction { @@ -102,7 +101,6 @@ mod test { dividend.wrapping_div(divisor) } } - const INSN_KIND: InsnKind = InsnKind::DIV; } impl TestInstance for RemInstruction { @@ -120,7 +118,6 @@ mod test { dividend.wrapping_rem(divisor) } } - const INSN_KIND: InsnKind = InsnKind::REM; } impl TestInstance for DivuInstruction { @@ -138,7 +135,6 @@ mod test { dividend / divisor } } - const INSN_KIND: InsnKind = InsnKind::DIVU; } impl TestInstance for RemuInstruction { @@ -156,10 +152,12 @@ mod test { dividend % divisor } } - const INSN_KIND: InsnKind = InsnKind::REMU; } - fn verify + TestInstance>( + fn verify< + E: ExtensionField, + Insn: Instruction + TestInstance, + >( name: &str, dividend: >::NumType, divisor: >::NumType, @@ -176,14 +174,18 @@ mod test { .unwrap() .unwrap(); let outcome = Insn::correct(dividend, divisor); - let insn_code = encode_rv32(Insn::INSN_KIND, 2, 3, 4, 0); + let insn_kind = Insn::inst_kinds() + .first() + .copied() + .expect("instruction must declare at least one InsnKind"); + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); // values assignment let ([raw_witin, _], lkm) = Insn::assign_instances( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -222,7 +224,10 @@ mod test { } // shortcut to verify given pair produces correct output - fn verify_positive + TestInstance>( + fn verify_positive< + E: ExtensionField, + Insn: Instruction + TestInstance, + >( name: &str, dividend: >::NumType, divisor: >::NumType, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs index 99a73a8a4..da754fd98 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs @@ -118,6 +118,11 @@ pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index f062ea949..eb1a5d0f9 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -48,6 +48,11 @@ pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 1df279dd9..9d7ad95a9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -1,17 +1,13 @@ -use std::marker::PhantomData; - -use ceno_emul::{InsnCategory, InsnFormat, InsnKind, StepRecord}; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use super::super::{ - RIVInstruction, constants::UInt, insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, }; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, e2e::ShardContext, - error::ZKVMError, instructions::Instruction, structs::ProgramParams, tables::InsnRecord, - uint::Value, witness::LkMultiplicity, + error::ZKVMError, tables::InsnRecord, uint::Value, witness::LkMultiplicity, }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; @@ -19,66 +15,6 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; -/// DummyInstruction can handle any instruction and produce its side-effects. -pub struct DummyInstruction(PhantomData<(E, I)>); - -impl Instruction for DummyInstruction { - type InstructionConfig = DummyConfig; - - fn name() -> String { - format!("{:?}_DUMMY", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - let kind = I::INST_KIND; - let format = InsnFormat::from(kind); - let category = InsnCategory::from(kind); - - // ECALL can do everything. - let is_ecall = matches!(kind, InsnKind::ECALL); - - // Regular instructions do what is implied by their format. - let (with_rs1, with_rs2, with_rd) = match format { - _ if is_ecall => (true, true, true), - InsnFormat::R => (true, true, true), - InsnFormat::I => (true, false, true), - InsnFormat::S => (true, true, false), - InsnFormat::B => (true, true, false), - InsnFormat::U => (false, false, true), - InsnFormat::J => (false, false, true), - }; - let with_mem_write = matches!(category, InsnCategory::Store) || is_ecall; - let with_mem_read = matches!(category, InsnCategory::Load); - let branching = matches!(category, InsnCategory::Branch) - || matches!(kind, InsnKind::JAL | InsnKind::JALR) - || is_ecall; - - DummyConfig::construct_circuit( - circuit_builder, - I::INST_KIND, - with_rs1, - with_rs2, - with_rd, - with_mem_write, - with_mem_read, - branching, - ) - } - - fn assign_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config.assign_instance(instance, shard_ctx, lk_multiplicity, step) - } -} - #[derive(Debug)] pub struct MemAddrVal { mem_addr: WitIn, diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 8c7a9852d..3ae516e9c 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -29,6 +29,11 @@ pub struct LargeEcallDummy(PhantomData<(E, S)>); impl Instruction for LargeEcallDummy { type InstructionConfig = LargeEcallConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { S::NAME.to_owned() diff --git a/ceno_zkvm/src/instructions/riscv/dummy/mod.rs b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs index c016ff643..3874bc7e1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs @@ -1,16 +1,7 @@ -//! Dummy instruction circuits for testing. -//! Support instructions that don’t have a complete implementation yet. -//! It connects all the state together (register writes, etc), but does not verify the values. -//! -//! Usage: -//! Specify an instruction with `trait RIVInstruction` and define a `DummyInstruction` like so: -//! -//! use ceno_zkvm::instructions::riscv::{arith::AddOp, dummy::DummyInstruction}; -//! -//! type AddDummy = DummyInstruction; +//! Helper dummy circuits for testing and large ECALLs. mod dummy_circuit; -pub use dummy_circuit::{DummyConfig, DummyInstruction}; +pub use dummy_circuit::DummyConfig; mod dummy_ecall; pub use dummy_ecall::LargeEcallDummy; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6e068a07f..8cccc8d49 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -1,66 +1,22 @@ -use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, encode_rv32}; +use ceno_emul::KeccakSpec; use ff_ext::GoldilocksExt2; -use super::*; +use super::LargeEcallDummy; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::{ - Instruction, - riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, - }, - scheme::mock_prover::{MOCK_PC_START, MockProver}, + instructions::Instruction, + scheme::mock_prover::MockProver, structs::ProgramParams, }; -type AddDummy = DummyInstruction; -type BeqDummy = DummyInstruction; - -#[test] -fn test_dummy_ecall() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "ecall_dummy", - |cb| { - let config = EcallDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let step = StepRecord::new_ecall_any(4, MOCK_PC_START); - let insn_code = step.insn(); - let (raw_witin, lkm) = EcallDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&step], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - #[test] -fn test_dummy_keccak() { +fn test_large_ecall_dummy_keccak() { type KeccakDummy = LargeEcallDummy; let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "keccak_dummy", - |cb| { - let config = KeccakDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); + let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let (step, program) = ceno_emul::test_utils::keccak_step(); let (raw_witin, lkm) = KeccakDummy::assign_instances( @@ -68,80 +24,9 @@ fn test_dummy_keccak() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&step], + &[step], ) .unwrap(); MockProver::assert_satisfied_raw(&cb, raw_witin, &program, None, Some(lkm)); } - -#[test] -fn test_dummy_r() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "add_dummy", - |cb| { - let config = AddDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); - let (raw_witin, lkm) = AddDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( - 3, - MOCK_PC_START, - insn_code, - 11, - 0xfffffffe, - Change::new(0, 11_u32.wrapping_add(0xfffffffe)), - 0, - )], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - -#[test] -fn test_dummy_b() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "beq_dummy", - |cb| { - let config = BeqDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); - let (raw_witin, lkm) = BeqDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( - 3, - Change::new(MOCK_PC_START, MOCK_PC_START + 8_usize), - insn_code, - 0xbead1010, - 0xbead1010, - 0, - )], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index ce4427386..8f1300bad 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -17,14 +17,4 @@ pub use weierstrass_add::WeierstrassAddAssignInstruction; pub use weierstrass_decompress::WeierstrassDecompressInstruction; pub use weierstrass_double::WeierstrassDoubleAssignInstruction; -use ceno_emul::InsnKind; pub use halt::HaltInstruction; - -use super::{RIVInstruction, dummy::DummyInstruction}; - -pub struct EcallOp; -impl RIVInstruction for EcallOp { - const INST_KIND: InsnKind = InsnKind::ECALL; -} -/// Unsafe. A dummy ecall circuit that ignores unimplemented functions. -pub type EcallDummy = DummyInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 66f47b59b..7eeba1ab0 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -77,6 +77,11 @@ impl Instruction for FpAddInstruction { type InstructionConfig = EcallFpOpConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_FpAdd".to_string() @@ -120,7 +125,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -140,6 +145,11 @@ impl Instruction for FpMulInstruction { type InstructionConfig = EcallFpOpConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_FpMul".to_string() @@ -183,7 +193,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -300,7 +310,7 @@ fn assign_fp_op_instances( shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], syscall_code: u32, op: FieldOperation, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index c4a7ca0d2..6552b6241 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -68,6 +68,11 @@ impl Instruction for Fp2AddInstruction { type InstructionConfig = EcallFp2AddConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Fp2Add".to_string() @@ -111,7 +116,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp2_add_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) } @@ -219,7 +224,7 @@ fn assign_fp2_add_instances, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index a3d7b63d5..709e9734d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -67,6 +67,11 @@ impl Instruction for Fp2MulInstruction { type InstructionConfig = EcallFp2MulConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Fp2Mul".to_string() @@ -110,7 +115,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp2_mul_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) } @@ -217,7 +222,7 @@ fn assign_fp2_mul_instances, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index d30d7a97c..5ba6df208 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -14,7 +14,7 @@ use crate::{ structs::{ProgramParams, RAMType}, witness::LkMultiplicity, }; -use ceno_emul::{FullTracer as Tracer, StepRecord}; +use ceno_emul::{FullTracer as Tracer, InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -31,6 +31,11 @@ pub struct HaltInstruction(PhantomData); impl Instruction for HaltInstruction { type InstructionConfig = HaltConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "ECALL_HALT".into() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0fabff2b..51568b56a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -54,6 +54,11 @@ pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { type InstructionConfig = EcallKeccakConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Keccak".to_string() @@ -169,7 +174,7 @@ impl Instruction for KeccakInstruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index b305e7e58..8b38df7e1 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -68,6 +68,11 @@ pub struct Uint256MulInstruction(PhantomData); impl Instruction for Uint256MulInstruction { type InstructionConfig = EcallUint256MulConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Uint256Mul".to_string() @@ -221,7 +226,7 @@ impl Instruction for Uint256MulInstruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = UINT256_MUL; @@ -417,6 +422,11 @@ pub struct EcallUint256InvConfig { impl Instruction for Uint256InvInstruction { type InstructionConfig = EcallUint256InvConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { Spec::name() @@ -536,7 +546,7 @@ impl Instruction for Uint256InvInstr shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = Spec::syscall(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index e960190f3..05a91cd97 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -59,6 +59,11 @@ impl Instruction for WeierstrassAddAssignInstruction { type InstructionConfig = EcallWeierstrassAddAssignConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassAddAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -221,7 +226,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_ADD, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 67af36f4d..6d9a7470b 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -66,6 +66,11 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDecompressConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassDecompress_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -222,7 +227,7 @@ impl Instruction, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DECOMPRESS, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 19bb8cf69..4b9a2aeb6 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -61,6 +61,11 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDoubleAssignConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassDoubleAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -193,7 +198,7 @@ impl Instruction, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DOUBLE, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index c8abc77ac..14566b477 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -37,6 +37,11 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JAL] + } fn name() -> String { format!("{:?}", InsnKind::JAL) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 545adf275..a766ea795 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -42,6 +42,11 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JAL] + } fn name() -> String { format!("{:?}", InsnKind::JAL) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 77f6ad1f8..2331c2f82 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -37,6 +37,11 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JALR] + } fn name() -> String { format!("{:?}", InsnKind::JALR) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7f23ac9b6..7c51728ac 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -42,6 +42,11 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JALR] + } fn name() -> String { format!("{:?}", InsnKind::JALR) diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 51bc63cd8..355dad511 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -46,7 +46,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_j_instruction( + &[StepRecord::new_j_instruction( 4, Change::new(MOCK_PC_START, new_pc), insn_code, @@ -122,7 +122,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 4, Change::new(MOCK_PC_START, new_pc), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 5a2d8e404..4d2cf6db8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -29,6 +29,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 5fcb17a62..6bade9c0f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -35,7 +35,7 @@ fn test_opcode_and() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -78,7 +78,7 @@ fn test_opcode_or() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -121,7 +121,7 @@ fn test_opcode_xor() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 3ab2a6df5..fea2b03df 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -24,6 +24,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -232,7 +237,7 @@ mod test { &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index b48af7f5f..14c2adeb0 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -31,6 +31,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 70afdfbe2..3a003777b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -74,7 +74,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index e863d8de0..93d24c4ef 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -34,6 +34,11 @@ pub struct LuiInstruction(PhantomData); impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::LUI] + } fn name() -> String { format!("{:?}", InsnKind::LUI) @@ -159,7 +164,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 41fbf0059..818e8902a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -38,6 +38,11 @@ pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 812e4020a..5a9ed40eb 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -42,6 +42,11 @@ pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index cb512975b..a1bd7a812 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -42,6 +42,11 @@ impl Instruction for StoreInstruction { type InstructionConfig = StoreConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 3fb7692f8..f6b0fa153 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -106,7 +106,7 @@ fn impl_opcode_store { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index a94f63e74..f3bddff1b 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -38,6 +38,11 @@ pub struct MulhConfig { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 01dad68dc..60c27b984 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -9,7 +9,7 @@ use crate::instructions::riscv::lui::LuiInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::tables::PowTableCircuit; use crate::{ - e2e::{ShardContext, StepCellExtractor}, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -47,12 +47,12 @@ use ceno_emul::{ InsnKind::{self, *}, KeccakSpec, LogPcCycleSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Secp256r1AddSpec, Secp256r1DoubleSpec, - Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepRecord, SyscallSpec, Uint256MulSpec, + Secp256r1ScalarInvertSpec, Sha256ExtendSpec, StepCellExtractor, StepRecord, SyscallSpec, + Uint256MulSpec, Word, }; use dummy::LargeEcallDummy; -use ecall::EcallDummy; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::Itertools; use mulh::{MulInstruction, MulhInstruction, MulhsuInstruction}; use shift::SraInstruction; use slt::{SltInstruction, SltuInstruction}; @@ -64,8 +64,9 @@ use sp1_curves::weierstrass::{ secp256r1::Secp256r1, }; use std::{ + any::{TypeId, type_name}, cmp::Reverse, - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; @@ -179,15 +180,78 @@ pub struct Rv32imConfig { pub ecall_cells_map: HashMap, } +#[derive(Clone)] +pub struct InstructionDispatchBuilder { + record_buffer_count: usize, + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, +} + +impl InstructionDispatchBuilder { + fn new() -> Self { + Self { + record_buffer_count: 0, + insn_to_record_buffer: vec![None; InsnKind::COUNT], + type_to_record_buffer: HashMap::new(), + } + } + + fn register_instruction_kinds + 'static>( + &mut self, + kinds: &[InsnKind], + ) { + assert!( + kinds.iter().all(|kind| *kind != InsnKind::ECALL), + "ecall dispatch via function code" + ); + let record_buffer_index = self.record_buffer_count; + self.record_buffer_count += 1; + for &kind in kinds { + if let Some(existing) = self.insn_to_record_buffer[kind as usize] { + panic!( + "Instruction kind {:?} registered multiple times: existing buffer {}, new buffer {} (instruction type: {})", + kind, + existing, + record_buffer_index, + type_name::() + ); + } + self.insn_to_record_buffer[kind as usize] = Some(record_buffer_index); + } + assert!( + self.type_to_record_buffer + .insert(TypeId::of::(), record_buffer_index) + .is_none(), + "Instruction circuit {} registered more than once", + type_name::() + ); + } + + pub fn to_dispatch_ctx(&self) -> InstructionDispatchCtx { + InstructionDispatchCtx::new( + self.record_buffer_count, + self.insn_to_record_buffer.clone(), + self.type_to_record_buffer.clone(), + ) + } +} + const KECCAK_CELL_BLOWUP_FACTOR: u64 = 2; impl Rv32imConfig { - pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { + pub fn construct_circuits( + cs: &mut ZKVMConstraintSystem, + ) -> (Self, InstructionDispatchBuilder) { let mut inst_cells_map = vec![0; InsnKind::COUNT]; let mut ecall_cells_map = HashMap::new(); + let mut inst_dispatch_builder = InstructionDispatchBuilder::new(); + macro_rules! register_opcode_circuit { ($insn_kind:ident, $instruction:ty, $inst_cells_map:ident) => {{ + inst_dispatch_builder.register_instruction_kinds::( + <$instruction as Instruction>::inst_kinds(), + ); let config = cs.register_opcode_circuit::<$instruction>(); // update estimated cell @@ -347,7 +411,7 @@ impl Rv32imConfig { #[cfg(not(feature = "u16limb_circuit"))] let pow_config = cs.register_table_circuit::>(); - Self { + let config = Self { // alu opcodes add_config, sub_config, @@ -428,7 +492,9 @@ impl Rv32imConfig { pow_config, inst_cells_map, ecall_cells_map, - } + }; + + (config, inst_dispatch_builder) } pub fn generate_fixed_traces( @@ -559,317 +625,207 @@ impl Rv32imConfig { fixed.register_table_circuit::>(cs, &self.pow_config, &()); } - pub fn assign_opcode_circuit<'a>( + pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, + instrunction_dispatch_ctx: &mut InstructionDispatchCtx, witness: &mut ZKVMWitnesses, - steps: &'a [StepRecord], - ) -> Result, ZKVMError> { - let mut all_records: BTreeMap> = InsnKind::iter() - .map(|insn_kind| (insn_kind, Vec::new())) - .collect(); - let mut halt_records = Vec::new(); - let mut keccak_records = Vec::new(); - let mut bn254_add_records = Vec::new(); - let mut bn254_double_records = Vec::new(); - let mut bn254_fp_add_records = Vec::new(); - let mut bn254_fp_mul_records = Vec::new(); - let mut bn254_fp2_add_records = Vec::new(); - let mut bn254_fp2_mul_records = Vec::new(); - let mut secp256k1_add_records = Vec::new(); - let mut secp256k1_double_records = Vec::new(); - let mut secp256k1_decompress_records = Vec::new(); - let mut uint256_mul_records = Vec::new(); - let mut secp256k1_scalar_invert_records = Vec::new(); - let mut secp256r1_add_records = Vec::new(); - let mut secp256r1_double_records = Vec::new(); - let mut secp256r1_scalar_invert_records = Vec::new(); - steps.iter().for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - halt_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254FpAddSpec::CODE => { - bn254_fp_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254FpMulSpec::CODE => { - bn254_fp_mul_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254Fp2AddSpec::CODE => { - bn254_fp2_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254Fp2MulSpec::CODE => { - bn254_fp2_mul_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256r1AddSpec::CODE => { - secp256r1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256r1DoubleSpec::CODE => { - secp256r1_double_records.push(record); - } - InsnKind::ECALL - if record.rs1().unwrap().value == Secp256k1ScalarInvertSpec::CODE => - { - secp256k1_scalar_invert_records.push(record); - } - InsnKind::ECALL - if record.rs1().unwrap().value == Secp256r1ScalarInvertSpec::CODE => - { - secp256r1_scalar_invert_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => { - secp256k1_decompress_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Uint256MulSpec::CODE => { - uint256_mul_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } - } - }); + ) -> Result<(), ZKVMError> { + instrunction_dispatch_ctx.trace_opcode_stats(); - for (insn_kind, (_, records)) in - izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) - { - tracing::debug!("tracer generated {:?} {} records", insn_kind, records.len()); + macro_rules! log_ecall { + ($desc:literal, $code:expr) => { + tracing::debug!( + "tracer generated {} {} records", + $desc, + instrunction_dispatch_ctx.count_ecall_code($code) + ); + }; } - tracing::debug!("tracer generated HALT {} records", halt_records.len()); - tracing::debug!("tracer generated KECCAK {} records", keccak_records.len()); - tracing::debug!( - "tracer generated bn254_add_records {} records", - bn254_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_double_records {} records", - bn254_double_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp_add_records {} records", - bn254_fp_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp_mul_records {} records", - bn254_fp_mul_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp2_add_records {} records", - bn254_fp2_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp2_mul_records {} records", - bn254_fp2_mul_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_add_records {} records", - secp256k1_add_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_double_records {} records", - secp256k1_double_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_scalar_invert_records {} records", - secp256k1_scalar_invert_records.len() + + log_ecall!("HALT", ECALL_HALT); + log_ecall!("KECCAK", KeccakSpec::CODE); + log_ecall!("bn254_add_records", Bn254AddSpec::CODE); + log_ecall!("bn254_double_records", Bn254DoubleSpec::CODE); + log_ecall!("bn254_fp_add_records", Bn254FpAddSpec::CODE); + log_ecall!("bn254_fp_mul_records", Bn254FpMulSpec::CODE); + log_ecall!("bn254_fp2_add_records", Bn254Fp2AddSpec::CODE); + log_ecall!("bn254_fp2_mul_records", Bn254Fp2MulSpec::CODE); + log_ecall!("secp256k1_add_records", Secp256k1AddSpec::CODE); + log_ecall!("secp256k1_double_records", Secp256k1DoubleSpec::CODE); + log_ecall!( + "secp256k1_scalar_invert_records", + Secp256k1ScalarInvertSpec::CODE ); - tracing::debug!( - "tracer generated secp256k1_decompress_records {} records", - secp256k1_decompress_records.len() + log_ecall!( + "secp256k1_decompress_records", + Secp256k1DecompressSpec::CODE ); - tracing::debug!( - "tracer generated uint256_mul_records {} records", - uint256_mul_records.len() + log_ecall!("secp256r1_add_records", Secp256r1AddSpec::CODE); + log_ecall!("secp256r1_double_records", Secp256r1DoubleSpec::CODE); + log_ecall!( + "secp256r1_scalar_invert_records", + Secp256r1ScalarInvertSpec::CODE ); + log_ecall!("uint256_mul_records", Uint256MulSpec::CODE); macro_rules! assign_opcode { - ($insn_kind:ident,$instruction:ty,$config:ident) => { + ($instruction:ty, $config:ident) => {{ + let records = instrunction_dispatch_ctx + .records_for_kinds::() + .unwrap_or(&[]); witness.assign_opcode_circuit::<$instruction>( cs, shard_ctx, &self.$config, - all_records.remove(&($insn_kind)).unwrap(), + records, )?; - }; + }}; } + + macro_rules! assign_ecall { + ($instruction:ty, $config:ident, $code:expr) => {{ + let records = instrunction_dispatch_ctx + .records_for_ecall_code($code) + .unwrap_or(&[]); + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + records, + )?; + }}; + } + // alu - assign_opcode!(ADD, AddInstruction, add_config); - assign_opcode!(SUB, SubInstruction, sub_config); - assign_opcode!(AND, AndInstruction, and_config); - assign_opcode!(OR, OrInstruction, or_config); - assign_opcode!(XOR, XorInstruction, xor_config); - assign_opcode!(SLL, SllInstruction, sll_config); - assign_opcode!(SRL, SrlInstruction, srl_config); - assign_opcode!(SRA, SraInstruction, sra_config); - assign_opcode!(SLT, SltInstruction, slt_config); - assign_opcode!(SLTU, SltuInstruction, sltu_config); - assign_opcode!(MUL, MulInstruction, mul_config); - assign_opcode!(MULH, MulhInstruction, mulh_config); - assign_opcode!(MULHSU, MulhsuInstruction, mulhsu_config); - assign_opcode!(MULHU, MulhuInstruction, mulhu_config); - assign_opcode!(DIVU, DivuInstruction, divu_config); - assign_opcode!(REMU, RemuInstruction, remu_config); - assign_opcode!(DIV, DivInstruction, div_config); - assign_opcode!(REM, RemInstruction, rem_config); + assign_opcode!(AddInstruction, add_config); + assign_opcode!(SubInstruction, sub_config); + assign_opcode!(AndInstruction, and_config); + assign_opcode!(OrInstruction, or_config); + assign_opcode!(XorInstruction, xor_config); + assign_opcode!(SllInstruction, sll_config); + assign_opcode!(SrlInstruction, srl_config); + assign_opcode!(SraInstruction, sra_config); + assign_opcode!(SltInstruction, slt_config); + assign_opcode!(SltuInstruction, sltu_config); + assign_opcode!(MulInstruction, mul_config); + assign_opcode!(MulhInstruction, mulh_config); + assign_opcode!(MulhsuInstruction, mulhsu_config); + assign_opcode!(MulhuInstruction, mulhu_config); + assign_opcode!(DivuInstruction, divu_config); + assign_opcode!(RemuInstruction, remu_config); + assign_opcode!(DivInstruction, div_config); + assign_opcode!(RemInstruction, rem_config); // alu with imm - assign_opcode!(ADDI, AddiInstruction, addi_config); - assign_opcode!(ANDI, AndiInstruction, andi_config); - assign_opcode!(ORI, OriInstruction, ori_config); - assign_opcode!(XORI, XoriInstruction, xori_config); - assign_opcode!(SLLI, SlliInstruction, slli_config); - assign_opcode!(SRLI, SrliInstruction, srli_config); - assign_opcode!(SRAI, SraiInstruction, srai_config); - assign_opcode!(SLTI, SltiInstruction, slti_config); - assign_opcode!(SLTIU, SltiuInstruction, sltiu_config); + assign_opcode!(AddiInstruction, addi_config); + assign_opcode!(AndiInstruction, andi_config); + assign_opcode!(OriInstruction, ori_config); + assign_opcode!(XoriInstruction, xori_config); + assign_opcode!(SlliInstruction, slli_config); + assign_opcode!(SrliInstruction, srli_config); + assign_opcode!(SraiInstruction, srai_config); + assign_opcode!(SltiInstruction, slti_config); + assign_opcode!(SltiuInstruction, sltiu_config); #[cfg(feature = "u16limb_circuit")] - assign_opcode!(LUI, LuiInstruction, lui_config); + assign_opcode!(LuiInstruction, lui_config); #[cfg(feature = "u16limb_circuit")] - assign_opcode!(AUIPC, AuipcInstruction, auipc_config); + assign_opcode!(AuipcInstruction, auipc_config); // branching - assign_opcode!(BEQ, BeqInstruction, beq_config); - assign_opcode!(BNE, BneInstruction, bne_config); - assign_opcode!(BLT, BltInstruction, blt_config); - assign_opcode!(BLTU, BltuInstruction, bltu_config); - assign_opcode!(BGE, BgeInstruction, bge_config); - assign_opcode!(BGEU, BgeuInstruction, bgeu_config); + assign_opcode!(BeqInstruction, beq_config); + assign_opcode!(BneInstruction, bne_config); + assign_opcode!(BltInstruction, blt_config); + assign_opcode!(BltuInstruction, bltu_config); + assign_opcode!(BgeInstruction, bge_config); + assign_opcode!(BgeuInstruction, bgeu_config); // jump - assign_opcode!(JAL, JalInstruction, jal_config); - assign_opcode!(JALR, JalrInstruction, jalr_config); + assign_opcode!(JalInstruction, jal_config); + assign_opcode!(JalrInstruction, jalr_config); // memory - assign_opcode!(LW, LwInstruction, lw_config); - assign_opcode!(LB, LbInstruction, lb_config); - assign_opcode!(LBU, LbuInstruction, lbu_config); - assign_opcode!(LH, LhInstruction, lh_config); - assign_opcode!(LHU, LhuInstruction, lhu_config); - assign_opcode!(SW, SwInstruction, sw_config); - assign_opcode!(SH, ShInstruction, sh_config); - assign_opcode!(SB, SbInstruction, sb_config); + assign_opcode!(LwInstruction, lw_config); + assign_opcode!(LbInstruction, lb_config); + assign_opcode!(LbuInstruction, lbu_config); + assign_opcode!(LhInstruction, lh_config); + assign_opcode!(LhuInstruction, lhu_config); + assign_opcode!(SwInstruction, sw_config); + assign_opcode!(ShInstruction, sh_config); + assign_opcode!(SbInstruction, sb_config); // ecall / halt - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.halt_config, - halt_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.keccak_config, - keccak_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.bn254_add_config, - bn254_add_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.bn254_double_config, - bn254_double_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp_add_config, - bn254_fp_add_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp_mul_config, - bn254_fp_mul_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp2_add_config, - bn254_fp2_add_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp2_mul_config, - bn254_fp2_mul_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_add_config, - secp256k1_add_records, - )?; - witness - .assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_double_config, - secp256k1_double_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.secp256k1_scalar_invert, - secp256k1_scalar_invert_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_decompress_config, - secp256k1_decompress_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256r1_add_config, - secp256r1_add_records, - )?; - witness - .assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256r1_double_config, - secp256r1_double_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.secp256r1_scalar_invert, - secp256r1_scalar_invert_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.uint256_mul_config, - uint256_mul_records, - )?; - - assert_eq!( - all_records.keys().cloned().collect::>(), - // these are opcodes that haven't been implemented - [INVALID, ECALL].into_iter().collect::>(), + assign_ecall!(HaltInstruction, halt_config, ECALL_HALT); + assign_ecall!(KeccakInstruction, keccak_config, KeccakSpec::CODE); + assign_ecall!( + WeierstrassAddAssignInstruction>, + bn254_add_config, + Bn254AddSpec::CODE ); - Ok(GroupedSteps(all_records)) + assign_ecall!( + WeierstrassDoubleAssignInstruction>, + bn254_double_config, + Bn254DoubleSpec::CODE + ); + assign_ecall!( + FpAddInstruction, + bn254_fp_add_config, + Bn254FpAddSpec::CODE + ); + assign_ecall!( + FpMulInstruction, + bn254_fp_mul_config, + Bn254FpMulSpec::CODE + ); + assign_ecall!( + Fp2AddInstruction, + bn254_fp2_add_config, + Bn254Fp2AddSpec::CODE + ); + assign_ecall!( + Fp2MulInstruction, + bn254_fp2_mul_config, + Bn254Fp2MulSpec::CODE + ); + assign_ecall!( + WeierstrassAddAssignInstruction>, + secp256k1_add_config, + Secp256k1AddSpec::CODE + ); + assign_ecall!( + WeierstrassDoubleAssignInstruction>, + secp256k1_double_config, + Secp256k1DoubleSpec::CODE + ); + assign_ecall!( + Secp256k1InvInstruction, + secp256k1_scalar_invert, + Secp256k1ScalarInvertSpec::CODE + ); + assign_ecall!( + WeierstrassDecompressInstruction>, + secp256k1_decompress_config, + Secp256k1DecompressSpec::CODE + ); + assign_ecall!( + WeierstrassAddAssignInstruction>, + secp256r1_add_config, + Secp256r1AddSpec::CODE + ); + assign_ecall!( + WeierstrassDoubleAssignInstruction>, + secp256r1_double_config, + Secp256r1DoubleSpec::CODE + ); + assign_ecall!( + Secp256r1InvInstruction, + secp256r1_scalar_invert, + Secp256r1ScalarInvertSpec::CODE + ); + assign_ecall!( + Uint256MulInstruction, + uint256_mul_config, + Uint256MulSpec::CODE + ); + + Ok(()) } pub fn assign_table_circuit( @@ -898,29 +854,133 @@ impl Rv32imConfig { } } -/// Opaque type to pass unimplemented instructions from Rv32imConfig to DummyExtraConfig. -pub struct GroupedSteps<'a>(BTreeMap>); +pub struct InstructionDispatchCtx { + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, + insn_kinds: Vec, + circuit_record_buffers: Vec>, + fallback_record_buffers: Vec>, + ecall_record_buffers: BTreeMap>, +} + +impl InstructionDispatchCtx { + fn new( + record_buffer_count: usize, + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, + ) -> Self { + Self { + insn_to_record_buffer, + type_to_record_buffer, + insn_kinds: InsnKind::iter().collect(), + circuit_record_buffers: (0..record_buffer_count).map(|_| Vec::new()).collect(), + fallback_record_buffers: (0..InsnKind::COUNT).map(|_| Vec::new()).collect(), + ecall_record_buffers: BTreeMap::new(), + } + } + + pub fn begin_shard(&mut self) { + self.reset_record_buffers(); + } + + #[inline(always)] + pub fn ingest_step(&mut self, step: StepRecord) { + let kind = step.insn.kind; + if kind == InsnKind::ECALL { + let code = step + .rs1() + .expect("ecall requires rs1 to determine syscall code") + .value; + self.ecall_record_buffers + .entry(code) + .or_default() + .push(step); + } else if let Some(record_buffer_idx) = self.insn_to_record_buffer[kind as usize] { + self.circuit_record_buffers[record_buffer_idx].push(step); + } else { + self.fallback_record_buffers[kind as usize].push(step); + } + } + + fn reset_record_buffers(&mut self) { + for record_buffer in &mut self.circuit_record_buffers { + record_buffer.clear(); + } + for record_buffer in &mut self.fallback_record_buffers { + record_buffer.clear(); + } + for record_buffer in self.ecall_record_buffers.values_mut() { + record_buffer.clear(); + } + } + + fn trace_opcode_stats(&self) { + let mut counts = self + .insn_kinds + .iter() + .map(|kind| (*kind, self.count_kind(*kind))) + .collect_vec(); + counts.sort_by_key(|(_, count)| Reverse(*count)); + for (kind, count) in counts { + tracing::debug!("tracer generated {:?} {} records", kind, count); + } + } + + fn count_kind(&self, kind: InsnKind) -> usize { + if kind == InsnKind::ECALL { + return self + .ecall_record_buffers + .values() + .map(|record_buffer| record_buffer.len()) + .sum(); + } + if let Some(idx) = self.insn_to_record_buffer[kind as usize] { + self.circuit_record_buffers[idx].len() + } else { + self.fallback_record_buffers[kind as usize].len() + } + } + + fn count_ecall_code(&self, code: u32) -> usize { + self.ecall_record_buffers + .get(&code) + .map(|record_buffer| record_buffer.len()) + .unwrap_or_default() + } + + fn records_for_kinds + 'static>( + &self, + ) -> Option<&[StepRecord]> { + let record_buffer_id = self + .type_to_record_buffer + .get(&TypeId::of::()) + .expect("un-registered instruction circuit"); + self.circuit_record_buffers + .get(*record_buffer_id) + .map(|records| records.as_slice()) + } + fn records_for_ecall_code(&self, code: u32) -> Option<&[StepRecord]> { + self.ecall_record_buffers + .get(&code) + .map(|records| records.as_slice()) + } +} /// Fake version of what is missing in Rv32imConfig, for some tests. pub struct DummyExtraConfig { - ecall_config: as Instruction>::InstructionConfig, - sha256_extend_config: as Instruction>::InstructionConfig, - phantom_log_pc_cycle: as Instruction>::InstructionConfig, } impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let ecall_config = cs.register_opcode_circuit::>(); let sha256_extend_config = cs.register_opcode_circuit::>(); let phantom_log_pc_cycle = cs.register_opcode_circuit::>(); Self { - ecall_config, sha256_extend_config, phantom_log_pc_cycle, } @@ -931,7 +991,6 @@ impl DummyExtraConfig { cs: &ZKVMConstraintSystem, fixed: &mut ZKVMFixedTraces, ) { - fixed.register_opcode_circuit::>(cs, &self.ecall_config); fixed.register_opcode_circuit::>( cs, &self.sha256_extend_config, @@ -946,69 +1005,42 @@ impl DummyExtraConfig { &self, cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, + instrunction_dispatch_ctx: &InstructionDispatchCtx, witness: &mut ZKVMWitnesses, - steps: GroupedSteps, ) -> Result<(), ZKVMError> { - let mut steps = steps.0; - - let mut sha256_extend_steps = Vec::new(); - let mut bn254_fp_add_steps = Vec::new(); - let mut bn254_fp_mul_steps = Vec::new(); - let mut bn254_fp2_add_steps = Vec::new(); - let mut bn254_fp2_mul_steps = Vec::new(); - let mut phantom_log_pc_cycle_spec = Vec::new(); - let mut other_steps = Vec::new(); - - if let Some(ecall_steps) = steps.remove(&ECALL) { - for step in ecall_steps { - match step.rs1().unwrap().value { - Sha256ExtendSpec::CODE => sha256_extend_steps.push(step), - Bn254FpAddSpec::CODE => bn254_fp_add_steps.push(step), - Bn254FpMulSpec::CODE => bn254_fp_mul_steps.push(step), - Bn254Fp2AddSpec::CODE => bn254_fp2_add_steps.push(step), - Bn254Fp2MulSpec::CODE => bn254_fp2_mul_steps.push(step), - LogPcCycleSpec::CODE => phantom_log_pc_cycle_spec.push(step), - _ => other_steps.push(step), - } - } - } + let sha256_extend_records = instrunction_dispatch_ctx + .records_for_ecall_code(Sha256ExtendSpec::CODE) + .unwrap_or(&[]); + let phantom_log_pc_cycle_records = instrunction_dispatch_ctx + .records_for_ecall_code(LogPcCycleSpec::CODE) + .unwrap_or(&[]); witness.assign_opcode_circuit::>( cs, shard_ctx, &self.sha256_extend_config, - sha256_extend_steps, + sha256_extend_records, )?; witness.assign_opcode_circuit::>( cs, shard_ctx, &self.phantom_log_pc_cycle, - phantom_log_pc_cycle_spec, + phantom_log_pc_cycle_records, )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.ecall_config, - other_steps, - )?; - - let _ = steps.remove(&INVALID); - let keys: Vec<&InsnKind> = steps.keys().collect::>(); - assert!(steps.is_empty(), "unimplemented opcodes: {:?}", keys); Ok(()) } } -impl StepCellExtractor for &Rv32imConfig { +impl Rv32imConfig { #[inline(always)] - fn extract_cells(&self, record: &StepRecord) -> u64 { - let insn_kind = record.insn.kind; - if !matches!(insn_kind, InsnKind::ECALL) { - // quick match for opcode and return - return self.inst_cells_map[insn_kind as usize]; + pub fn cells_for(&self, kind: InsnKind, rs1_value: Option) -> u64 { + if !matches!(kind, InsnKind::ECALL) { + return self.inst_cells_map[kind as usize]; } + // deal with ecall logic - match record.rs1().unwrap().value { + let code = rs1_value.unwrap_or_default(); + match code { // ecall / halt ECALL_HALT => *self .ecall_cells_map @@ -1077,8 +1109,21 @@ impl StepCellExtractor for &Rv32imConfig { // phantom LogPcCycleSpec::CODE => 0, ceno_emul::SHA_EXTEND => 0, - // other type of ecalls are handled by dummy ecall instruction - _ => unreachable!("unknow match record {:?}", record), + _ => panic!("unknown ecall code {code:#x}"), } } } + +impl StepCellExtractor for &Rv32imConfig { + #[inline(always)] + fn cells_for_kind(&self, kind: InsnKind, rs1_value: Option) -> u64 { + self.cells_for(kind, rs1_value) + } +} + +impl StepCellExtractor for Rv32imConfig { + #[inline(always)] + fn cells_for_kind(&self, kind: InsnKind, rs1_value: Option) -> u64 { + self.cells_for(kind, rs1_value) + } +} diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 97665bbcf..ea082e3c6 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -177,7 +177,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index c1d83ce87..44ee44988 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -42,6 +42,11 @@ pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index fac05279e..310d17491 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -276,6 +276,11 @@ pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftRTypeConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -372,6 +377,11 @@ pub struct ShiftImmInstruction(PhantomData<(E, I)>); impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 998d2395e..d97a0b09e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -174,7 +174,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index a2fa8d032..9442d9805 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -40,6 +40,11 @@ pub struct ShiftImmConfig { impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 3707304e1..629354e41 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -76,7 +76,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs index b9b63acaf..ed49932d3 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -40,6 +40,11 @@ enum SetLessThanDependencies { impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index cd0b97ce4..d57aeb2cd 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -30,6 +30,11 @@ pub struct SetLessThanConfig { } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 801d928d2..620d6ff3d 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -189,7 +189,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 8b93f593c..e2df652b1 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -41,6 +41,11 @@ pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 914424247..b2449614e 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -41,6 +41,11 @@ pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 5c63c86fa..dfb6c35ef 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -16,7 +16,7 @@ use crate::{ witness::{LkMultiplicity, set_val}, }; use ceno_emul::{ - CENO_PLATFORM, + CENO_PLATFORM, InsnKind, InsnKind::{ADD, ECALL}, Platform, Program, StepRecord, VMState, encode_rv32, }; @@ -62,6 +62,11 @@ struct TestCircuit { impl Instruction for TestCircuit { type InstructionConfig = TestConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::INVALID] + } fn name() -> String { "TEST".into() @@ -141,12 +146,13 @@ fn test_rw_lk_expression_combination() { // generate mock witness let num_instances = 1 << 8; let mut zkvm_witness = ZKVMWitnesses::default(); + let steps = vec![StepRecord::default(); num_instances]; zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, &mut shard_ctx, &config, - vec![&StepRecord::default(); num_instances], + &steps, ) .unwrap(); @@ -329,7 +335,7 @@ fn test_single_add_instance_e2e() { .collect::>(); let mut add_records = vec![]; let mut halt_records = vec![]; - all_records.iter().for_each(|record| { + all_records.into_iter().for_each(|record| { let kind = record.insn().kind; match kind { ADD => add_records.push(record), @@ -357,7 +363,7 @@ fn test_single_add_instance_e2e() { &zkvm_cs, &mut shard_ctx, &add_config, - add_records, + &add_records, ) .unwrap(); zkvm_witness @@ -365,7 +371,7 @@ fn test_single_add_instance_e2e() { &zkvm_cs, &mut shard_ctx, &halt_config, - halt_records, + &halt_records, ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5ab4f9b61..76a2d9334 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -368,7 +368,7 @@ impl ZKVMWitnesses { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, - records: Vec<&StepRecord>, + records: &[StepRecord], ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none());