From e559fbbf285d29a919ee6d66f76df4b0f67a51f5 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 09:52:26 -0300 Subject: [PATCH 1/6] Enhance serialization support for proof types - Added `serde::Serialize` and `serde::Deserialize` derives to various structs including `Commitment`, `ShoutAddrPreProof`, and `BatchedTimeProof` for improved serialization capabilities. - Introduced `serde_helpers` module with custom serialization functions for handling non-trivially-serializable fields, specifically for `Vec<&'static [u8]>`. - Implemented verification methods in `FoldingSession` to allow proof verification using externally supplied step instance bundles without executing the program. This update enhances the flexibility and usability of proof serialization across the codebase. --- crates/neo-ajtai/src/types.rs | 2 +- crates/neo-fold/src/lib.rs | 3 + crates/neo-fold/src/riscv_shard.rs | 184 +++++++++++ crates/neo-fold/src/serde_helpers.rs | 61 ++++ crates/neo-fold/src/session.rs | 303 ++++++++++++++++++ crates/neo-fold/src/shard_proof_types.rs | 27 +- crates/neo-reductions/Cargo.toml | 6 +- .../src/engines/optimized_engine/common.rs | 2 +- .../src/engines/optimized_engine/mod.rs | 4 +- 9 files changed, 576 insertions(+), 16 deletions(-) create mode 100644 crates/neo-fold/src/serde_helpers.rs diff --git a/crates/neo-ajtai/src/types.rs b/crates/neo-ajtai/src/types.rs index 9f0b91e3..8fbaa37c 100644 --- a/crates/neo-ajtai/src/types.rs +++ b/crates/neo-ajtai/src/types.rs @@ -13,7 +13,7 @@ pub struct PP { } /// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d). 
-#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Commitment { pub d: usize, pub kappa: usize, diff --git a/crates/neo-fold/src/lib.rs b/crates/neo-fold/src/lib.rs index dae0b088..aee318a1 100644 --- a/crates/neo-fold/src/lib.rs +++ b/crates/neo-fold/src/lib.rs @@ -33,3 +33,6 @@ pub mod riscv_shard; pub mod output_binding; mod shard_proof_types; + +/// Serde helpers for proof serialization (e.g. `&'static [u8]` labels). +pub mod serde_helpers; diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 25c14283..47c4160b 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -373,6 +373,147 @@ impl Rv32B1 { self } + /// Build only the verification context (CCS + session) without executing the program or proving. + /// + /// This performs the same validation, program decoding, memory layout setup, CCS construction + /// (including shared-bus wiring), and session creation that `prove()` does -- but stops before + /// any RISC-V execution or folding. + /// + /// The returned [`Rv32B1Verifier`] can verify a `ShardProof` given the public MCS instances + /// (`mcss_public`) that were produced by the prover. + /// + /// **Cost:** circuit synthesis only (~ms). No RISC-V execution, no folding. 
+ pub fn build_verifier(self) -> Result { + // --- Input validation (same as prove) --- + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 MVP requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if self.chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + if self.ram_bytes == 0 { + return Err(PiCcsError::InvalidInput("ram_bytes must be non-zero".into())); + } + if self.program_base != 0 { + return Err(PiCcsError::InvalidInput( + "RV32 B1 MVP requires program_base == 0 (addresses are indices into PROG/RAM layouts)".into(), + )); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RV32 B1 runner does not support RVC)".into(), + )); + } + for (i, chunk) in self.program_bytes.chunks_exact(4).enumerate() { + let first_half = u16::from_le_bytes([chunk[0], chunk[1]]); + if (first_half & 0b11) != 0b11 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 runner does not support compressed instructions (RVC): found compressed encoding at word index {i}" + ))); + } + } + + // --- Program decoding + memory layouts (same as prove, minus VM/Twist init) --- + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + let mut initial_mem = initial_mem; + for (&addr, &value) in &self.ram_init { + let value = value as u32 as u64; + initial_mem.insert((neo_memory::riscv::lookups::RAM_ID.0, addr), F::from_u64(value)); + } + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); 
+ let mem_layouts = HashMap::from([ + ( + neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + // --- Shout tables (same as prove) --- + let mut shout_ops = match &self.shout_ops { + Some(ops) => ops.clone(), + None if self.shout_auto_minimal => infer_required_shout_opcodes(&program), + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let shout = RiscvShoutTables::new(self.xlen); + let mut table_specs: HashMap = HashMap::new(); + for op in shout_ops { + let table_id = shout.opcode_to_id(op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + // --- CCS + Session (same as prove) --- + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + + let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let empty_tables: HashMap> = HashMap::new(); + + // Build R1csCpu for CCS with shared-bus wiring (same as prove; no execution). 
+ let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + session.set_step_linking(rv32_b1_step_linking_config(&layout)); + + let ccs = cpu.ccs.clone(); + + // No execution, no fold_and_prove -- just the verification context. + Ok(Rv32B1Verifier { + session, + ccs, + _layout: layout, + ram_num_bits: d_ram, + output_claims: self.output_claims, + }) + } + pub fn prove(self) -> Result { if self.xlen != 32 { return Err(PiCcsError::InvalidInput(format!( @@ -569,6 +710,40 @@ impl Rv32B1 { } } +/// Verification context for RV32 B1 proofs. +/// +/// Created by [`Rv32B1::build_verifier`]. Contains the CCS structure and folding session +/// needed to verify a `ShardProof` without executing the RISC-V program. +pub struct Rv32B1Verifier { + session: FoldingSession, + ccs: CcsStructure, + _layout: Rv32B1Layout, + ram_num_bits: usize, + output_claims: ProgramIO, +} + +impl Rv32B1Verifier { + /// Verify a `ShardProof` using the provided public step instance bundles. + /// + /// `steps_public` must be the step instance bundles produced by the prover (via + /// `Rv32B1Run::steps_public()`). The verifier checks the folding proof + /// against these instances and the CCS structure -- no RISC-V execution + /// is performed. 
+ pub fn verify( + &self, + proof: &ShardProof, + steps_public: &[StepInstanceBundle], + ) -> Result { + if self.output_claims.is_empty() { + self.session.verify_with_external_steps(&self.ccs, steps_public, proof) + } else { + let ob_cfg = OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); + self.session + .verify_with_external_steps_and_output_binding(&self.ccs, steps_public, proof, &ob_cfg) + } + } +} + pub struct Rv32B1Run { session: FoldingSession, proof: ShardProof, @@ -623,6 +798,15 @@ impl Rv32B1Run { self.session.steps_public() } + /// Return the public MCS instances for inclusion in a proof package. + /// + /// These instances must be transmitted alongside the `ShardProof` so that + /// a standalone verifier (via [`Rv32B1Verifier::verify`]) can check the proof + /// without re-executing the RISC-V program. + pub fn mcss_public(&self) -> Vec> { + self.session.mcss_public() + } + pub fn final_boundary_state(&self) -> Result { let steps_public = self.steps_public(); let last = steps_public diff --git a/crates/neo-fold/src/serde_helpers.rs b/crates/neo-fold/src/serde_helpers.rs new file mode 100644 index 00000000..4ad22f16 --- /dev/null +++ b/crates/neo-fold/src/serde_helpers.rs @@ -0,0 +1,61 @@ +//! Serde helpers for proof types that contain non-trivially-serializable fields. +//! +//! In particular, `BatchedTimeProof` stores `Vec<&'static [u8]>` labels which +//! require custom serialization. We serialize them as `Vec>` and leak +//! the deserialized allocations so the `&'static` lifetime is valid. + +use serde::de::{Deserializer, SeqAccess, Visitor}; +use serde::ser::{SerializeSeq, Serializer}; +use std::fmt; + +/// Serialize `Vec<&'static [u8]>` as `Vec>`. 
+pub fn serialize_static_byte_slices( + labels: &[&'static [u8]], + serializer: S, +) -> Result +where + S: Serializer, +{ + let mut seq = serializer.serialize_seq(Some(labels.len()))?; + for label in labels { + seq.serialize_element(&label.to_vec())?; + } + seq.end() +} + +/// Deserialize `Vec>` into `Vec<&'static [u8]>` by leaking allocations. +/// +/// This is safe for proof deserialization where the proof lives for the duration +/// of the verification process. The leaked memory is small (label strings are +/// typically short domain-separation tags). +pub fn deserialize_static_byte_slices<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct StaticByteSlicesVisitor; + + impl<'de> Visitor<'de> for StaticByteSlicesVisitor { + type Value = Vec<&'static [u8]>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of byte arrays") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(bytes) = seq.next_element::>()? { + // Leak the allocation to get a &'static [u8]. + let leaked: &'static [u8] = Box::leak(bytes.into_boxed_slice()); + result.push(leaked); + } + Ok(result) + } + } + + deserializer.deserialize_seq(StaticByteSlicesVisitor) +} diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 83c86ae8..ae723da0 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -1680,6 +1680,309 @@ where Ok(true) } + /// Verify a proof using externally supplied step instance bundles. + /// + /// Unlike [`verify`] and [`verify_collected`] which use the session's internal + /// collected steps (populated by execution), this method takes the step bundles + /// directly. This enables **verify-only** flows where the verifier never + /// executes the program -- the steps come from the proof package. 
+ pub fn verify_with_external_steps( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_transcript(&mut tr, s, steps_public, run) + } + + /// Like [`verify_with_external_steps`] but with output binding. + pub fn verify_with_external_steps_and_output_binding( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_and_output_binding_transcript(&mut tr, s, steps_public, run, ob_cfg) + } + + /// Internal: verify with external steps + transcript. + fn verify_with_external_steps_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_step_linking_with_context( + self.mode.clone(), + tr, + 
&self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + }, + None => { + let mut msg = + "multi-step verification requires step linking; call FoldingSession::set_step_linking(...)" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + } + }; + + // Detect twist/shout from the provided steps (self.steps is empty for verify-only). + let has_twist_or_shout = steps_public.iter().any(|s| !s.mem_insts.is_empty()) + || steps_public.iter().any(|s| !s.lut_insts.is_empty()); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + + /// Internal: verify with external steps + output binding + transcript. 
+ fn verify_with_external_steps_and_output_binding_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_and_step_linking_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding_and_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + 
self.mixers, + ob_cfg, + )?, + }, + None => { + let mut msg = + "multi-step verification with output binding requires step linking" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + )?, + } + }; + + let has_twist_or_shout = !steps_public.is_empty() + && (steps_public.iter().any(|s| !s.lut_insts.is_empty()) + || steps_public.iter().any(|s| !s.mem_insts.is_empty())); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + /// Verify with output binding, managing the transcript internally. pub fn verify_with_output_binding_simple( &self, diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index 51534325..b19949d2 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -19,7 +19,7 @@ pub type ShoutProofK = neo_memory::shout::ShoutProof; /// Within each group, when a Shout lane is provably inactive for a step (no lookups), we can /// skip its address-domain sumcheck entirely. We still bind all `claimed_sums` to the transcript, /// but we include sumcheck rounds only for the active subset. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreProof { /// Claimed sums per Shout lane. 
/// @@ -32,7 +32,7 @@ pub struct ShoutAddrPreProof { pub groups: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreGroupProof { /// Address-bit width (sumcheck round count) for this group. pub ell_addr: u32, @@ -58,7 +58,7 @@ impl Default for ShoutAddrPreProof { } /// One fold step’s artifacts (Π_CCS → Π_RLC → Π_DEC). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct FoldStep { /// Π_CCS outputs (k ME(b,L) instances) pub ccs_out: Vec>, @@ -125,13 +125,17 @@ pub struct ShardFoldWitnesses { pub val_lane_wits: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub enum MemOrLutProof { Twist(TwistProofK), Shout(ShoutProofK), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound( + serialize = "C: serde::Serialize, FF: serde::Serialize, KK: serde::Serialize", + deserialize = "C: serde::de::DeserializeOwned, FF: serde::de::DeserializeOwned, KK: serde::de::DeserializeOwned + Default" +))] pub struct MemSidecarProof { /// CPU ME claims evaluated at `r_val` (Twist val-eval terminal point). /// @@ -147,20 +151,25 @@ pub struct MemSidecarProof { /// /// This batches CCS (row/time rounds) with Twist/Shout time-domain oracles so all /// protocols share the same transcript-derived `r` (enabling Π_RLC folding). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct BatchedTimeProof { /// Claimed sums per participating oracle (in the same order as `round_polys`). pub claimed_sums: Vec, /// Degree bounds per participating oracle. pub degree_bounds: Vec, /// Domain-separation labels per participating oracle. + /// Serialized as `Vec>`; the `&'static` lifetime is restored via `labels_static()`. 
+ #[serde( + serialize_with = "crate::serde_helpers::serialize_static_byte_slices", + deserialize_with = "crate::serde_helpers::deserialize_static_byte_slices" + )] pub labels: Vec<&'static [u8]>, /// Per-claim sum-check messages: `round_polys[claim][round] = coeffs`. pub round_polys: Vec>>, } /// Proof data for a standalone Π_RLC → Π_DEC lane (no Π_CCS). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct RlcDecProof { /// RLC mixing matrices ρ_i ∈ S ⊆ F^{D×D} pub rlc_rhos: Vec>, @@ -170,7 +179,7 @@ pub struct RlcDecProof { pub dec_children: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct StepProof { pub fold: FoldStep, pub mem: MemSidecarProof, @@ -179,7 +188,7 @@ pub struct StepProof { pub val_fold: Option, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShardProof { pub steps: Vec, /// Optional output binding proof (proves final memory matches claimed outputs). 
diff --git a/crates/neo-reductions/Cargo.toml b/crates/neo-reductions/Cargo.toml index 4f2636e7..2c7691a0 100644 --- a/crates/neo-reductions/Cargo.toml +++ b/crates/neo-reductions/Cargo.toml @@ -21,7 +21,7 @@ fs-guard = ["neo-transcript/fs-guard"] # Enable verbose logging for debugging (no-op in this crate; used for cfg gating) neo-logs = [] # Enable constraint diagnostic system -prove-diagnostics = ["serde", "serde_json", "flate2", "hex", "chrono"] +prove-diagnostics = ["serde_json", "flate2", "hex", "chrono"] # Enable CBOR format for diagnostics (smaller but less readable) prove-diagnostics-cbor = ["prove-diagnostics", "serde_cbor"] wasm-threads = ["neo-ajtai/wasm-threads", "neo-ccs/wasm-threads"] @@ -46,8 +46,8 @@ thiserror = { workspace = true } bincode = { workspace = true } blake3 = "1.5" -# Diagnostic system dependencies (feature-gated) -serde = { workspace = true, optional = true } +# Serde for proof serialization (always available; also used by diagnostics) +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, optional = true } flate2 = { version = "1.0", optional = true } hex = { version = "0.4", optional = true } diff --git a/crates/neo-reductions/src/engines/optimized_engine/common.rs b/crates/neo-reductions/src/engines/optimized_engine/common.rs index 19c81021..a14ea5bb 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/common.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/common.rs @@ -17,7 +17,7 @@ use p3_field::{Field, PrimeCharacteristicRing}; use rayon::prelude::*; /// Challenges sampled in Step 1 of the protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Challenges { /// α ∈ K^{log d} - for Ajtai dimension pub alpha: Vec, diff --git a/crates/neo-reductions/src/engines/optimized_engine/mod.rs b/crates/neo-reductions/src/engines/optimized_engine/mod.rs index 4cf317b3..b16c56f7 100644 --- 
a/crates/neo-reductions/src/engines/optimized_engine/mod.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/mod.rs @@ -20,7 +20,7 @@ pub mod verify; pub use common::Challenges; /// Proof format variant for Π_CCS. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum PiCcsProofVariant { /// Split-NC proof with two sumchecks: FE-only + NC-only. SplitNcV1, @@ -55,7 +55,7 @@ pub use common::{ }; /// Proof structure for the Π_CCS protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PiCcsProof { /// Proof format variant. pub variant: PiCcsProofVariant, From c8096d0328589c710539bee299404e01644b10e8 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 10:43:04 -0300 Subject: [PATCH 2/6] Enforce chunk0 memory integrity --- crates/neo-fold/src/lib.rs | 3 - .../src/memory_sidecar/route_a_time.rs | 4 +- crates/neo-fold/src/riscv_shard.rs | 13 +++- crates/neo-fold/src/serde_helpers.rs | 61 ------------------ crates/neo-fold/src/shard.rs | 8 ++- crates/neo-fold/src/shard_proof_types.rs | 7 +- .../riscv_build_verifier_statement_memory.rs | 64 +++++++++++++++++++ .../neo-fold/tests/twist_shout_soundness.rs | 2 +- 8 files changed, 87 insertions(+), 75 deletions(-) delete mode 100644 crates/neo-fold/src/serde_helpers.rs create mode 100644 crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs diff --git a/crates/neo-fold/src/lib.rs b/crates/neo-fold/src/lib.rs index aee318a1..dae0b088 100644 --- a/crates/neo-fold/src/lib.rs +++ b/crates/neo-fold/src/lib.rs @@ -33,6 +33,3 @@ pub mod riscv_shard; pub mod output_binding; mod shard_proof_types; - -/// Serde helpers for proof serialization (e.g. `&'static [u8]` labels). 
-pub mod serde_helpers; diff --git a/crates/neo-fold/src/memory_sidecar/route_a_time.rs b/crates/neo-fold/src/memory_sidecar/route_a_time.rs index ee39a129..c0e49ff5 100644 --- a/crates/neo-fold/src/memory_sidecar/route_a_time.rs +++ b/crates/neo-fold/src/memory_sidecar/route_a_time.rs @@ -137,7 +137,7 @@ pub fn prove_route_a_batched_time( let proof = BatchedTimeProof { claimed_sums: claimed_sums.clone(), degree_bounds: degree_bounds.clone(), - labels: labels.clone(), + labels: labels.iter().map(|label| label.to_vec()).collect(), round_polys: per_claim_results .iter() .map(|r| r.round_polys.clone()) @@ -223,7 +223,7 @@ pub fn verify_route_a_batched_time( ))); } for (i, (got, exp)) in proof.labels.iter().zip(expected_labels.iter()).enumerate() { - if (*got as &[u8]) != *exp { + if got.as_slice() != *exp { return Err(PiCcsError::ProtocolError(format!( "step {}: batched_time label mismatch at claim {}", step_idx, i diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 47c4160b..40e0029d 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -509,6 +509,8 @@ impl Rv32B1 { session, ccs, _layout: layout, + mem_layouts, + statement_initial_mem: initial_mem, ram_num_bits: d_ram, output_claims: self.output_claims, }) @@ -718,6 +720,8 @@ pub struct Rv32B1Verifier { session: FoldingSession, ccs: CcsStructure, _layout: Rv32B1Layout, + mem_layouts: HashMap, + statement_initial_mem: HashMap<(u32, u64), F>, ram_num_bits: usize, output_claims: ProgramIO, } @@ -734,8 +738,15 @@ impl Rv32B1Verifier { proof: &ShardProof, steps_public: &[StepInstanceBundle], ) -> Result { + rv32_b1_enforce_chunk0_mem_init_matches_statement( + &self.mem_layouts, + &self.statement_initial_mem, + steps_public, + )?; + if self.output_claims.is_empty() { - self.session.verify_with_external_steps(&self.ccs, steps_public, proof) + self.session + .verify_with_external_steps(&self.ccs, steps_public, proof) } else { let ob_cfg = 
OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); self.session diff --git a/crates/neo-fold/src/serde_helpers.rs b/crates/neo-fold/src/serde_helpers.rs deleted file mode 100644 index 4ad22f16..00000000 --- a/crates/neo-fold/src/serde_helpers.rs +++ /dev/null @@ -1,61 +0,0 @@ -//! Serde helpers for proof types that contain non-trivially-serializable fields. -//! -//! In particular, `BatchedTimeProof` stores `Vec<&'static [u8]>` labels which -//! require custom serialization. We serialize them as `Vec>` and leak -//! the deserialized allocations so the `&'static` lifetime is valid. - -use serde::de::{Deserializer, SeqAccess, Visitor}; -use serde::ser::{SerializeSeq, Serializer}; -use std::fmt; - -/// Serialize `Vec<&'static [u8]>` as `Vec>`. -pub fn serialize_static_byte_slices( - labels: &[&'static [u8]], - serializer: S, -) -> Result -where - S: Serializer, -{ - let mut seq = serializer.serialize_seq(Some(labels.len()))?; - for label in labels { - seq.serialize_element(&label.to_vec())?; - } - seq.end() -} - -/// Deserialize `Vec>` into `Vec<&'static [u8]>` by leaking allocations. -/// -/// This is safe for proof deserialization where the proof lives for the duration -/// of the verification process. The leaked memory is small (label strings are -/// typically short domain-separation tags). -pub fn deserialize_static_byte_slices<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - struct StaticByteSlicesVisitor; - - impl<'de> Visitor<'de> for StaticByteSlicesVisitor { - type Value = Vec<&'static [u8]>; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a sequence of byte arrays") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0)); - while let Some(bytes) = seq.next_element::>()? { - // Leak the allocation to get a &'static [u8]. 
- let leaked: &'static [u8] = Box::leak(bytes.into_boxed_slice()); - result.push(leaked); - } - Ok(result) - } - } - - deserializer.deserialize_seq(StaticByteSlicesVisitor) -} diff --git a/crates/neo-fold/src/shard.rs b/crates/neo-fold/src/shard.rs index 035bd1bb..b1100389 100644 --- a/crates/neo-fold/src/shard.rs +++ b/crates/neo-fold/src/shard.rs @@ -3527,7 +3527,13 @@ where .len() .checked_sub(1) .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claim".into()))?; - if step_proof.batched_time.labels.get(inc_idx).copied() != Some(crate::output_binding::OB_INC_TOTAL_LABEL) { + if step_proof + .batched_time + .labels + .get(inc_idx) + .map(|label| label.as_slice()) + != Some(crate::output_binding::OB_INC_TOTAL_LABEL) + { return Err(PiCcsError::ProtocolError("output binding claim not last".into())); } diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index b19949d2..c11a9490 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -158,12 +158,7 @@ pub struct BatchedTimeProof { /// Degree bounds per participating oracle. pub degree_bounds: Vec, /// Domain-separation labels per participating oracle. - /// Serialized as `Vec>`; the `&'static` lifetime is restored via `labels_static()`. - #[serde( - serialize_with = "crate::serde_helpers::serialize_static_byte_slices", - deserialize_with = "crate::serde_helpers::deserialize_static_byte_slices" - )] - pub labels: Vec<&'static [u8]>, + pub labels: Vec>, /// Per-claim sum-check messages: `round_polys[claim][round] = coeffs`. 
pub round_polys: Vec>>, } diff --git a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs new file mode 100644 index 00000000..fa5189c4 --- /dev/null +++ b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs @@ -0,0 +1,64 @@ +use neo_fold::riscv_shard::Rv32B1; +use neo_fold::PiCcsError; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; + +fn program_bytes_with_seed(seed: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: seed, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +#[test] +fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { + let program_a = program_bytes_with_seed(7); + let program_b = program_bytes_with_seed(9); + + let mut run_a = Rv32B1::from_rom(/*program_base=*/ 0, &program_a) + .chunk_size(1) + .prove() + .expect("prove statement A"); + run_a.verify().expect("self-verify statement A"); + + let proof_a = run_a.proof().clone(); + let steps_public_a = run_a.steps_public(); + + let verifier_a = Rv32B1::from_rom(/*program_base=*/ 0, &program_a) + .chunk_size(1) + .build_verifier() + .expect("build verifier for statement A"); + let ok = verifier_a + .verify(&proof_a, &steps_public_a) + .expect("matching statement should verify"); + assert!(ok, "matching statement should verify"); + + let verifier_b = Rv32B1::from_rom(/*program_base=*/ 0, &program_b) + .chunk_size(1) + .build_verifier() + .expect("build verifier for statement B"); + let err = verifier_b + .verify(&proof_a, &steps_public_a) + .expect_err("different statement memory must be rejected"); + + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("chunk0 MemInstance.init mismatch"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error 
kind: {other:?}"), + } +} diff --git a/crates/neo-fold/tests/twist_shout_soundness.rs b/crates/neo-fold/tests/twist_shout_soundness.rs index b7f35988..58e2f8b8 100644 --- a/crates/neo-fold/tests/twist_shout_soundness.rs +++ b/crates/neo-fold/tests/twist_shout_soundness.rs @@ -338,7 +338,7 @@ fn tamper_batched_time_label_fails() { ) .expect("prove should succeed"); - proof.steps[0].batched_time.labels[0] = b"bad/label".as_slice(); + proof.steps[0].batched_time.labels[0] = b"bad/label".to_vec(); let mut tr_verify = Poseidon2Transcript::new(b"soundness/tamper-batched-time-label"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; From 1aa6fe14dfe12eefe99f58b39f11bb8be76bb333 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 11:42:06 -0300 Subject: [PATCH 3/6] Validate Commitment invariants --- crates/neo-ajtai/src/types.rs | 48 ++++++++++++++++++- .../commitment_deserialize_invariants.rs | 41 ++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 crates/neo-ajtai/tests/commitment_deserialize_invariants.rs diff --git a/crates/neo-ajtai/src/types.rs b/crates/neo-ajtai/src/types.rs index 8fbaa37c..07527f3e 100644 --- a/crates/neo-ajtai/src/types.rs +++ b/crates/neo-ajtai/src/types.rs @@ -1,6 +1,7 @@ use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as Fq; -use serde::{Deserialize, Serialize}; +use serde::de::Error as _; +use serde::{Deserialize, Deserializer, Serialize}; /// Public parameters for Ajtai: M ∈ R_q^{κ×m}, stored row-major. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -13,7 +14,7 @@ pub struct PP { } /// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d). 
-#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)] pub struct Commitment { pub d: usize, pub kappa: usize, @@ -22,6 +23,26 @@ pub struct Commitment { } impl Commitment { + #[inline] + fn validate_shape(d: usize, kappa: usize, data_len: usize) -> Result<(), String> { + let expected_d = neo_math::ring::D; + if d != expected_d { + return Err(format!("invalid Commitment.d: expected {expected_d}, got {d}")); + } + + let expected_len = d + .checked_mul(kappa) + .ok_or_else(|| format!("invalid Commitment shape: d*kappa overflow (d={d}, kappa={kappa})"))?; + if data_len != expected_len { + return Err(format!( + "invalid Commitment shape: data.len()={} but d*kappa={expected_len}", + data_len + )); + } + + Ok(()) + } + pub fn zeros(d: usize, kappa: usize) -> Self { Self { d, @@ -48,3 +69,26 @@ impl Commitment { } } } + +impl<'de> Deserialize<'de> for Commitment { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct CommitmentWire { + d: usize, + kappa: usize, + data: Vec, + } + + let wire = CommitmentWire::deserialize(deserializer)?; + Commitment::validate_shape(wire.d, wire.kappa, wire.data.len()).map_err(D::Error::custom)?; + + Ok(Self { + d: wire.d, + kappa: wire.kappa, + data: wire.data, + }) + } +} diff --git a/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs new file mode 100644 index 00000000..9b743ec2 --- /dev/null +++ b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs @@ -0,0 +1,41 @@ +use neo_ajtai::Commitment; +use neo_math::ring::D; +use serde_json::Value; + +fn valid_commitment_json(kappa: usize) -> Value { + let c = Commitment::zeros(D, kappa); + serde_json::to_value(&c).expect("serialize valid commitment") +} + +#[test] +fn commitment_deserialize_rejects_wrong_d() { + let mut value = valid_commitment_json(2); + value["d"] = 
serde_json::json!(D - 1); + + let err = serde_json::from_value::(value).expect_err("wrong d must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("invalid Commitment.d"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_rejects_data_len_mismatch() { + let mut value = valid_commitment_json(2); + let data = value + .get_mut("data") + .and_then(Value::as_array_mut) + .expect("commitment.data array"); + data.pop().expect("non-empty data"); + + let err = serde_json::from_value::(value).expect_err("shape mismatch must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("data.len()"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_accepts_valid_shape() { + let value = valid_commitment_json(3); + let c: Commitment = serde_json::from_value(value).expect("valid shape should deserialize"); + assert_eq!(c.d, D); + assert_eq!(c.kappa, 3); + assert_eq!(c.data.len(), D * 3); +} From a6b22554a69106a1119f3d326a42bd9c234c741f Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 11:59:27 -0300 Subject: [PATCH 4/6] Enforce chunk0 memory match --- crates/neo-fold/src/session.rs | 22 +++++- .../riscv_build_verifier_statement_memory.rs | 73 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index ae723da0..0dece63c 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -1723,7 +1723,16 @@ where .filter(|cache| cache.src_ptr == src_ptr); let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); - let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } let s_prepared = 
self.prepared_ccs_for_accumulator(s)?; let seed_me: &[MeInstance] = match &self.acc0 { @@ -1858,7 +1867,16 @@ where .filter(|cache| cache.src_ptr == src_ptr); let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); - let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } let s_prepared = self.prepared_ccs_for_accumulator(s)?; let seed_me: &[MeInstance] = match &self.acc0 { diff --git a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs index fa5189c4..78605df7 100644 --- a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs +++ b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs @@ -1,6 +1,8 @@ use neo_fold::riscv_shard::Rv32B1; use neo_fold::PiCcsError; +use neo_math::F; use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; fn program_bytes_with_seed(seed: i32) -> Vec { let program = vec![ @@ -21,6 +23,18 @@ fn program_bytes_with_seed(seed: i32) -> Vec { encode_program(&program) } +fn assert_m_in_mismatch_rejected(err: PiCcsError) { + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("all steps must share the same m_in"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error kind: {other:?}"), + } +} + #[test] fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { let program_a = program_bytes_with_seed(7); @@ -62,3 +76,62 @@ fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { other => panic!("unexpected error kind: {other:?}"), } } + +#[test] +fn rv32_b1_build_verifier_rejects_external_steps_with_non_uniform_m_in() 
{ + let program = program_bytes_with_seed(11); + + let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .prove() + .expect("prove statement"); + run.verify().expect("self-verify statement"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .build_verifier() + .expect("build verifier"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} + +#[test] +fn rv32_b1_build_verifier_output_binding_rejects_external_steps_with_non_uniform_m_in() { + let program = program_bytes_with_seed(13); + + let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .prove() + .expect("prove statement with output binding"); + run.verify() + .expect("self-verify statement with output binding"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .build_verifier() + .expect("build verifier with output binding"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} From be7944c8ffb969a0649012cf5835696790238f3e Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 09:49:30 -0300 Subject: [PATCH 5/6] Add Poseidon2-Goldilocks hash support via ECALL precompile - Introduced Poseidon2 compute and read ECALLs for hashing 
Goldilocks field elements. - Implemented guest-side API for computing Poseidon2 hashes, allowing for efficient input handling and digest retrieval. - Enhanced RISC-V CPU to manage Poseidon2 ECALLs, including internal state for pending digest words. - Updated CCS layout and constraints to accommodate new ECALLs and ensure proper integration with existing functionality. --- crates/neo-fold/src/riscv_shard.rs | 205 +++++++++++++++++- .../riscv_poseidon2_ecall_prove_verify.rs | 93 ++++++++ crates/neo-memory/src/riscv/ccs.rs | 120 +++++++++- crates/neo-memory/src/riscv/ccs/layout.rs | 41 ++++ crates/neo-memory/src/riscv/ccs/witness.rs | 40 +++- crates/neo-memory/src/riscv/lookups/cpu.rs | 80 ++++++- crates/neo-memory/src/riscv/lookups/mod.rs | 16 ++ crates/neo-memory/tests/riscv_ccs_tests.rs | 94 +++++++- crates/neo-vm-trace/src/lib.rs | 14 ++ crates/nightstream-sdk/src/goldilocks.rs | 116 ++++++++++ crates/nightstream-sdk/src/lib.rs | 2 + crates/nightstream-sdk/src/poseidon2.rs | 111 ++++++++++ 12 files changed, 915 insertions(+), 17 deletions(-) create mode 100644 crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs create mode 100644 crates/nightstream-sdk/src/goldilocks.rs create mode 100644 crates/nightstream-sdk/src/poseidon2.rs diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 40e0029d..a5d1c932 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -7,11 +7,13 @@ use std::collections::HashMap; use std::collections::HashSet; +use std::sync::Arc; use std::time::Duration; use crate::output_binding::{simple_output_config, OutputBindingConfig}; use crate::pi_ccs::FoldingMode; use crate::session::FoldingSession; +use neo_reductions::engines::optimized_engine::oracle::SparseCache; use crate::shard::{ fold_shard_verify_with_output_binding_and_step_linking, fold_shard_verify_with_step_linking, CommitMixers, ShardFoldOutputs, ShardProof, StepLinkingConfig, @@ -73,6 +75,21 @@ fn 
elapsed_duration(start: TimePoint) -> Duration { } } +/// Per-phase timing breakdown for `Rv32B1::prove()`. +/// +/// Captures wall-clock time for each major phase so callers can identify bottlenecks. +#[derive(Clone, Debug, Default)] +pub struct ProveTimings { + /// Program decode, memory layout setup, and Shout table inference. + pub decode_and_setup: Duration, + /// CCS construction (base + shared-bus wiring) and session/committer creation. + pub ccs_and_shared_bus: Duration, + /// RISC-V VM execution and shard/witness collection. + pub vm_execution: Duration, + /// Fold-and-prove (sumcheck + Ajtai commitment). + pub fold_and_prove: Duration, +} + pub fn rv32_b1_step_linking_config(layout: &Rv32B1Layout) -> StepLinkingConfig { StepLinkingConfig::new(rv32_b1_step_linking_pairs(layout)) } @@ -267,6 +284,30 @@ fn all_shout_opcodes() -> HashSet { /// High-level “few lines” builder for proving/verifying an RV32 program using the B1 shared-bus step circuit. /// +/// Pre-computed SparseCache + matrix digest for RV32 B1 circuit preprocessing. +/// +/// The `SparseCache::build()` and matrix-digest computation are the most expensive parts +/// of both proving and verification. This struct captures those artefacts so they can be +/// computed once and reused across many `Rv32B1::prove()` / `build_verifier()` calls with +/// the same ROM, `ram_bytes`, and `chunk_size`. +/// +/// Build with [`Rv32B1::build_ccs_cache`], then inject into subsequent builders via +/// [`Rv32B1::with_ccs_cache`]. +/// +/// **Important**: the CCS structure itself is still built from scratch each time (it is fast). +/// Only the SparseCache (CSC decomposition + matrix digest) is cached. 
+pub struct Rv32B1CcsCache { + pub sparse: Arc>, +} + +impl std::fmt::Debug for Rv32B1CcsCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Rv32B1CcsCache") + .field("sparse_len", &self.sparse.len()) + .finish() + } +} + /// This: /// - chooses parameters + Ajtai committer automatically, /// - infers the minimal Shout table set from the program (unless overridden), @@ -285,6 +326,7 @@ pub struct Rv32B1 { shout_ops: Option>, output_claims: ProgramIO, ram_init: HashMap, + ccs_cache: Option>, } /// Default instruction cap for RV32B1 runs when `max_steps` is not specified. @@ -308,6 +350,7 @@ impl Rv32B1 { shout_ops: None, output_claims: ProgramIO::new(), ram_init: HashMap::new(), + ccs_cache: None, } } @@ -373,6 +416,114 @@ impl Rv32B1 { self } + /// Attach a pre-built CCS cache to skip CCS synthesis in [`prove`] and [`build_verifier`]. + /// + /// The cache **must** have been built from a builder with the same `program_bytes`, + /// `ram_bytes`, `chunk_size`, and Shout configuration. No runtime validation is performed; + /// mismatched caches produce undefined behaviour. + pub fn with_ccs_cache(mut self, cache: Arc) -> Self { + self.ccs_cache = Some(cache); + self + } + + /// Build the CCS preprocessing cache from the current builder configuration. + /// + /// This performs program decoding, memory-layout setup, Shout-table inference, + /// CCS construction, and `SparseCache` synthesis -- exactly the same work that + /// [`prove`] and [`build_verifier`] would do on first call -- and packages the + /// result so it can be shared across many runs. 
+ /// + /// ```ignore + /// let cache = Rv32B1::from_rom(0, &rom).ram_bytes(0x40000).chunk_size(1024) + /// .shout_auto_minimal().build_ccs_cache()?; + /// let cache = Arc::new(cache); + /// + /// // Subsequent prove/verify calls skip CCS synthesis: + /// let run = Rv32B1::from_rom(0, &rom).ram_bytes(0x40000).chunk_size(1024) + /// .shout_auto_minimal().with_ccs_cache(cache.clone()) + /// .ram_init_u32(0x104, 42).prove()?; + /// ``` + pub fn build_ccs_cache(&self) -> Result { + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words::( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); + let _ = d_ram; + let mem_layouts = HashMap::from([ + ( + neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: pow2_ceil_k(self.ram_bytes.max(4)).1, + n_side: 2, + lanes: 1, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + let shout = RiscvShoutTables::new(self.xlen); + let mut shout_ops = match &self.shout_ops { + Some(ops) => ops.clone(), + None if self.shout_auto_minimal => infer_required_shout_opcodes(&program), + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let mut table_specs: HashMap = HashMap::new(); + for op in &shout_ops { + let table_id = shout.opcode_to_id(*op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: *op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + + // Apply 
shared-bus wiring (adds Twist/Shout constraint matrices to the CCS). + // This uses ROM-only initial_mem since the CCS structure doesn't depend on + // witness-specific ram_init values. + let session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + let empty_tables: HashMap> = HashMap::new(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + let sparse = Arc::new(SparseCache::build(&cpu.ccs)); + + Ok(Rv32B1CcsCache { sparse }) + } + /// Build only the verification context (CCS + session) without executing the program or proving. /// /// This performs the same validation, program decoding, memory layout setup, CCS construction @@ -505,6 +656,8 @@ impl Rv32B1 { let ccs = cpu.ccs.clone(); // No execution, no fold_and_prove -- just the verification context. + // NOTE: verifier SparseCache preloading is deferred to `preload_sparse_cache()` + // because the CCS pointer changes when moved into the Rv32B1Verifier struct. 
Ok(Rv32B1Verifier { session, ccs, @@ -551,6 +704,9 @@ impl Rv32B1 { } } + // === Phase 1: Decode + setup === + let phase_start = time_now(); + let program = decode_program(&self.program_bytes) .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; let using_default_max_steps = self.max_steps.is_none(); @@ -621,6 +777,13 @@ impl Rv32B1 { let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); shout_table_ids.sort_unstable(); + let t_decode_and_setup = elapsed_duration(phase_start); + + // === Phase 2: CCS construction + shared-bus wiring === + let phase_start = time_now(); + + let prebuilt_sparse = self.ccs_cache.as_ref().map(|c| c.sparse.clone()); + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; @@ -656,6 +819,11 @@ impl Rv32B1 { // Always enforce step-to-step chunk chaining for RV32 B1. session.set_step_linking(rv32_b1_step_linking_config(&layout)); + let t_ccs_and_shared_bus = elapsed_duration(phase_start); + + // === Phase 3: VM execution + shard collection === + let phase_start = time_now(); + // Execute + collect step bundles (and aux for output binding). session.execute_shard_shared_cpu_bus( vm, @@ -687,15 +855,25 @@ impl Rv32B1 { let ccs = cpu.ccs.clone(); - // Prove phase (timed) - let prove_start = time_now(); + // Preload the SparseCache using the *final* CCS reference that fold_and_prove will use. + // (Must happen after cpu.ccs.clone() because the pointer-keyed cache checks identity.) + if let Some(ref sparse) = prebuilt_sparse { + session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + } + + let t_vm_execution = elapsed_duration(phase_start); + + // === Phase 4: Fold-and-prove === + let phase_start = time_now(); let proof = if self.output_claims.is_empty() { session.fold_and_prove(&ccs)? 
} else { let ob_cfg = OutputBindingConfig::new(d_ram, self.output_claims.clone()); session.fold_and_prove_with_output_binding_auto_simple(&ccs, &ob_cfg)? }; - let prove_duration = elapsed_duration(prove_start); + let t_fold_and_prove = elapsed_duration(phase_start); + + let prove_duration = t_decode_and_setup + t_ccs_and_shared_bus + t_vm_execution + t_fold_and_prove; Ok(Rv32B1Run { session, @@ -707,6 +885,12 @@ impl Rv32B1 { ram_num_bits: d_ram, output_claims: self.output_claims, prove_duration, + prove_timings: ProveTimings { + decode_and_setup: t_decode_and_setup, + ccs_and_shared_bus: t_ccs_and_shared_bus, + vm_execution: t_vm_execution, + fold_and_prove: t_fold_and_prove, + }, verify_duration: None, }) } @@ -727,6 +911,15 @@ pub struct Rv32B1Verifier { } impl Rv32B1Verifier { + /// Preload the verifier SparseCache using `&self.ccs` (final pointer). + /// + /// Call this once after `build_verifier()` returns, before the first `verify()`. + /// This ensures the pointer-keyed cache hits inside the session's verify path. + pub fn preload_sparse_cache(&mut self, sparse: Arc>) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache(&self.ccs, sparse) + } + /// Verify a `ShardProof` using the provided public step instance bundles. /// /// `steps_public` must be the step instance bundles produced by the prover (via @@ -765,6 +958,7 @@ pub struct Rv32B1Run { ram_num_bits: usize, output_claims: ProgramIO, prove_duration: Duration, + prove_timings: ProveTimings, verify_duration: Option, } @@ -930,6 +1124,11 @@ impl Rv32B1Run { self.prove_duration } + /// Per-phase timing breakdown for the prove call. 
+ pub fn prove_timings(&self) -> &ProveTimings { + &self.prove_timings + } + pub fn verify_duration(&self) -> Option { self.verify_duration } diff --git a/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs new file mode 100644 index 00000000..4256241a --- /dev/null +++ b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs @@ -0,0 +1,93 @@ +use neo_fold::riscv_shard::Rv32B1; +use neo_memory::riscv::lookups::{ + encode_program, RiscvInstruction, RiscvOpcode, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; + +fn load_u32_imm(rd: u8, value: u32) -> Vec { + let upper = ((value as i64 + 0x800) >> 12) as i32; + let lower = (value as i32) - (upper << 12); + vec![ + RiscvInstruction::Lui { rd, imm: upper }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd, + rs1: rd, + imm: lower, + }, + ] +} + +fn poseidon2_ecall_program() -> Vec { + let mut program = Vec::new(); + + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + + // a0 = POSEIDON2_ECALL_NUM -> compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. + for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + + // Clear a0 -> final halt. 
+ program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + program +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk1() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(1) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=1"); +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk4() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(4) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=4"); +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk32() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(32) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=32"); +} diff --git a/crates/neo-memory/src/riscv/ccs.rs b/crates/neo-memory/src/riscv/ccs.rs index 81e8a52e..2b8d860a 100644 --- a/crates/neo-memory/src/riscv/ccs.rs +++ b/crates/neo-memory/src/riscv/ccs.rs @@ -42,7 +42,7 @@ use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as F; use crate::plain::PlainMemLayout; -use crate::riscv::lookups::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID}; +use crate::riscv::lookups::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, PROG_ID, RAM_ID}; mod bus_bindings; mod config; @@ -1036,10 +1036,14 @@ fn semantic_constraints( )); } - // ECALL helpers (Jolt 
marker/print IDs). + // ECALL helpers (Jolt marker/print IDs + Poseidon2 precompile). let a0 = layout.reg_in(10, j); let ecall_is_cycle = layout.ecall_is_cycle(j); let ecall_is_print = layout.ecall_is_print(j); + let ecall_is_poseidon2 = layout.ecall_is_poseidon2(j); + let ecall_is_poseidon2_read = layout.ecall_is_poseidon2_read(j); + let halt_ecall_poseidon2_read = layout.halt_ecall_poseidon2_read(j); + let ecall_rd_val = layout.ecall_rd_val(j); let ecall_halts = layout.ecall_halts(j); let halt_effective = layout.halt_effective(j); @@ -1166,6 +1170,86 @@ fn semantic_constraints( } constraints.push(Constraint::mul(ecall_is_print, ecall_is_print, ecall_is_print)); + // ecall_is_poseidon2 = (a0 == POSEIDON2_ECALL_NUM). + let poseidon2_const = POSEIDON2_ECALL_NUM as u32; + { + let bit0 = layout.ecall_a0_bit(0, j); + let prefix0 = layout.ecall_poseidon2_prefix(0, j); + if (poseidon2_const & 1) == 1 { + constraints.push(Constraint::terms(one, false, vec![(prefix0, F::ONE), (bit0, -F::ONE)])); + } else { + constraints.push(Constraint::terms( + one, + false, + vec![(prefix0, F::ONE), (bit0, F::ONE), (one, -F::ONE)], + )); + } + } + for k in 1..31 { + let prev = layout.ecall_poseidon2_prefix(k - 1, j); + let next = layout.ecall_poseidon2_prefix(k, j); + let bit_col = layout.ecall_a0_bit(k, j); + let bit_is_one = ((poseidon2_const >> k) & 1) == 1; + let b_terms = if bit_is_one { + vec![(bit_col, F::ONE)] + } else { + vec![(one, F::ONE), (bit_col, -F::ONE)] + }; + constraints.push(Constraint { + condition_col: prev, + negate_condition: false, + additional_condition_cols: Vec::new(), + b_terms, + c_terms: vec![(next, F::ONE)], + }); + } + { + let prev = layout.ecall_poseidon2_prefix(30, j); + let bit_col = layout.ecall_a0_bit(31, j); + let bit_is_one = ((poseidon2_const >> 31) & 1) == 1; + let b_terms = if bit_is_one { + vec![(bit_col, F::ONE)] + } else { + vec![(one, F::ONE), (bit_col, -F::ONE)] + }; + constraints.push(Constraint { + condition_col: prev, + negate_condition: 
false, + additional_condition_cols: Vec::new(), + b_terms, + c_terms: vec![(ecall_is_poseidon2, F::ONE)], + }); + } + constraints.push(Constraint::mul(ecall_is_poseidon2, ecall_is_poseidon2, ecall_is_poseidon2)); + + // ecall_is_poseidon2_read = prefix_30 * bit_31. + // POSEIDON2_READ_ECALL_NUM (0x80504F53) shares bits 0-30 with POSEIDON2_ECALL_NUM + // (0x504F53) and has bit 31 set. The prefix chain already matches bits 0-30, + // so we branch on bit 31: compute = prefix_30 * (1-bit_31), read = prefix_30 * bit_31. + { + let prefix_30 = layout.ecall_poseidon2_prefix(30, j); + let bit_31 = layout.ecall_a0_bit(31, j); + constraints.push(Constraint::mul(prefix_30, bit_31, ecall_is_poseidon2_read)); + } + constraints.push(Constraint::mul(ecall_is_poseidon2_read, ecall_is_poseidon2_read, ecall_is_poseidon2_read)); + + // halt_ecall_poseidon2_read = is_halt * ecall_is_poseidon2_read. + // Gates the raw prefix-derived flag by is_halt so that it only activates + // on actual ECALL instructions, not on arbitrary instructions where a0 + // happens to contain POSEIDON2_READ_ECALL_NUM. + constraints.push(Constraint::mul(layout.is_halt(j), ecall_is_poseidon2_read, halt_ecall_poseidon2_read)); + constraints.push(Constraint::mul(halt_ecall_poseidon2_read, halt_ecall_poseidon2_read, halt_ecall_poseidon2_read)); + + // halt_ecall_poseidon2_read * (ecall_rd_val - rd_write_val) = 0 + // Links ecall_rd_val to rd_write_val for poseidon2_read steps, reusing the + // existing enforce_u32_bits range check on rd_write_val. + constraints.push(Constraint::terms( + halt_ecall_poseidon2_read, + false, + vec![(ecall_rd_val, F::ONE), (layout.rd_write_val(j), -F::ONE)], + )); + + // ecall_halts = 1 - ecall_is_cycle - ecall_is_print - ecall_is_poseidon2 - ecall_is_poseidon2_read. 
constraints.push(Constraint::terms( one, false, @@ -1173,6 +1257,8 @@ fn semantic_constraints( (ecall_halts, F::ONE), (ecall_is_cycle, F::ONE), (ecall_is_print, F::ONE), + (ecall_is_poseidon2, F::ONE), + (ecall_is_poseidon2_read, F::ONE), (one, -F::ONE), ], )); @@ -1692,17 +1778,37 @@ fn semantic_constraints( // Register update pattern for r=1..31: // - if rd_sel[r]=1 then reg_out[r] = rd_write_val // - else reg_out[r] = reg_in[r] + // Special case for r=10 (a0): the Poseidon2 read ECALL writes ecall_rd_val + // to reg_out[10] without going through rd_sel/rd_write_val. for r in 1..32 { constraints.push(Constraint::terms( layout.rd_sel(r, j), false, vec![(layout.reg_out(r, j), F::ONE), (layout.rd_write_val(j), -F::ONE)], )); - constraints.push(Constraint::terms( - layout.rd_sel(r, j), - true, - vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], - )); + if r == 10 { + // (1 - rd_sel[10] - halt_ecall_poseidon2_read) * (reg_out[10] - reg_in[10]) = 0 + // Uses halt-gated flag so this only activates on actual ECALL instructions. + constraints.push(Constraint { + condition_col: layout.rd_sel(r, j), + negate_condition: true, + additional_condition_cols: vec![halt_ecall_poseidon2_read], + b_terms: vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], + c_terms: Vec::new(), + }); + // halt_ecall_poseidon2_read * (reg_out[10] - ecall_rd_val) = 0 + constraints.push(Constraint::terms( + halt_ecall_poseidon2_read, + false, + vec![(layout.reg_out(r, j), F::ONE), (ecall_rd_val, -F::ONE)], + )); + } else { + constraints.push(Constraint::terms( + layout.rd_sel(r, j), + true, + vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], + )); + } } // RAM effective address is computed via the ADD Shout lookup (mod 2^32 semantics). 
diff --git a/crates/neo-memory/src/riscv/ccs/layout.rs b/crates/neo-memory/src/riscv/ccs/layout.rs index 3d4fca2b..fb113df5 100644 --- a/crates/neo-memory/src/riscv/ccs/layout.rs +++ b/crates/neo-memory/src/riscv/ccs/layout.rs @@ -188,6 +188,11 @@ pub struct Rv32B1Layout { pub ecall_is_cycle: usize, pub ecall_print_prefix_start: usize, // 31 pub ecall_is_print: usize, + pub ecall_poseidon2_prefix_start: usize, // 31 + pub ecall_is_poseidon2: usize, + pub ecall_is_poseidon2_read: usize, + pub halt_ecall_poseidon2_read: usize, + pub ecall_rd_val: usize, pub ecall_halts: usize, pub halt_effective: usize, @@ -549,6 +554,32 @@ impl Rv32B1Layout { self.cpu_cell(self.ecall_is_print, j) } + #[inline] + pub fn ecall_poseidon2_prefix(&self, k: usize, j: usize) -> usize { + debug_assert!(k < 31, "ecall_poseidon2_prefix k out of range"); + self.ecall_poseidon2_prefix_start + k * self.chunk_size + j + } + + #[inline] + pub fn ecall_is_poseidon2(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_is_poseidon2, j) + } + + #[inline] + pub fn ecall_is_poseidon2_read(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_is_poseidon2_read, j) + } + + #[inline] + pub fn halt_ecall_poseidon2_read(&self, j: usize) -> usize { + self.cpu_cell(self.halt_ecall_poseidon2_read, j) + } + + #[inline] + pub fn ecall_rd_val(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_rd_val, j) + } + #[inline] pub fn ecall_halts(&self, j: usize) -> usize { self.cpu_cell(self.ecall_halts, j) @@ -1134,6 +1165,11 @@ pub(super) fn build_layout_with_m( let ecall_is_cycle = alloc_scalar(&mut col); let ecall_print_prefix_start = alloc_array(&mut col, 31); let ecall_is_print = alloc_scalar(&mut col); + let ecall_poseidon2_prefix_start = alloc_array(&mut col, 31); + let ecall_is_poseidon2 = alloc_scalar(&mut col); + let ecall_is_poseidon2_read = alloc_scalar(&mut col); + let halt_ecall_poseidon2_read = alloc_scalar(&mut col); + let ecall_rd_val = alloc_scalar(&mut col); let ecall_halts = alloc_scalar(&mut 
col); let halt_effective = alloc_scalar(&mut col); @@ -1315,6 +1351,11 @@ pub(super) fn build_layout_with_m( ecall_is_cycle, ecall_print_prefix_start, ecall_is_print, + ecall_poseidon2_prefix_start, + ecall_is_poseidon2, + ecall_is_poseidon2_read, + halt_ecall_poseidon2_read, + ecall_rd_val, ecall_halts, halt_effective, bus, diff --git a/crates/neo-memory/src/riscv/ccs/witness.rs b/crates/neo-memory/src/riscv/ccs/witness.rs index a4f673a7..5fcdb14d 100644 --- a/crates/neo-memory/src/riscv/ccs/witness.rs +++ b/crates/neo-memory/src/riscv/ccs/witness.rs @@ -5,7 +5,7 @@ use neo_vm_trace::{StepTrace, TwistOpKind}; use crate::riscv::lookups::{ decode_instruction, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode, JOLT_CYCLE_TRACK_ECALL_NUM, - JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, + JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, PROG_ID, RAM_ID, }; use super::constants::{ @@ -78,7 +78,30 @@ fn set_ecall_helpers(z: &mut [F], layout: &Rv32B1Layout, j: usize, a0_u64: u64, let is_print = print_prefix == 1 && print_match31; z[layout.ecall_is_print(j)] = if is_print { F::ONE } else { F::ZERO }; - let ecall_halts = !(is_cycle || is_print); + let poseidon2_const = POSEIDON2_ECALL_NUM; + let mut poseidon2_prefix = if ((a0_u32 ^ poseidon2_const) & 1) == 0 { 1u32 } else { 0u32 }; + z[layout.ecall_poseidon2_prefix(0, j)] = if poseidon2_prefix == 1 { F::ONE } else { F::ZERO }; + for k in 1..31 { + let bit_match = ((a0_u32 >> k) ^ (poseidon2_const >> k)) & 1; + poseidon2_prefix &= 1u32 ^ bit_match; + z[layout.ecall_poseidon2_prefix(k, j)] = if poseidon2_prefix == 1 { F::ONE } else { F::ZERO }; + } + let poseidon2_match31 = (((a0_u32 >> 31) ^ (poseidon2_const >> 31)) & 1) == 0; + let is_poseidon2 = poseidon2_prefix == 1 && poseidon2_match31; + z[layout.ecall_is_poseidon2(j)] = if is_poseidon2 { F::ONE } else { F::ZERO }; + + // ecall_is_poseidon2_read = prefix_30 * bit_31 (bit 31 set = read ECALL). 
+ let bit_31_set = (a0_u32 >> 31) & 1 == 1; + let is_poseidon2_read = poseidon2_prefix == 1 && bit_31_set; + z[layout.ecall_is_poseidon2_read(j)] = if is_poseidon2_read { F::ONE } else { F::ZERO }; + + // halt_ecall_poseidon2_read = is_halt * ecall_is_poseidon2_read. + // Gates the raw prefix flag by is_halt so register update constraints only + // activate on actual ECALL instructions. + let halt_poseidon2_read = is_halt && is_poseidon2_read; + z[layout.halt_ecall_poseidon2_read(j)] = if halt_poseidon2_read { F::ONE } else { F::ZERO }; + + let ecall_halts = !(is_cycle || is_print || is_poseidon2 || is_poseidon2_read); z[layout.ecall_halts(j)] = if ecall_halts { F::ONE } else { F::ZERO }; z[layout.halt_effective(j)] = if is_halt && ecall_halts { F::ONE } else { F::ZERO }; @@ -644,6 +667,13 @@ fn rv32_b1_chunk_to_witness_internal( z[layout.is_halt(j)] = if is_halt { F::ONE } else { F::ZERO }; set_ecall_helpers(&mut z, layout, j, step.regs_before[10], is_halt)?; + // Detect Poseidon2 read ECALL and set ecall_rd_val. + let is_poseidon2_read_ecall = is_halt && step.regs_before[10] as u32 == POSEIDON2_READ_ECALL_NUM; + if is_poseidon2_read_ecall { + let rd_val = step.regs_after[10]; + z[layout.ecall_rd_val(j)] = F::from_u64(rd_val); + } + // One-hot register selectors. let rs1_idx = rs1 as usize; let rs2_idx = rs2 as usize; @@ -1152,6 +1182,12 @@ fn rv32_b1_chunk_to_witness_internal( z[layout.br_not_taken(j)] = F::ONE - z[layout.br_taken(j)]; } + // Poseidon2 read ECALL: set rd_write_val = ecall_rd_val so the existing + // enforce_u32_bits range check covers it. 
+ if is_poseidon2_read_ecall { + z[layout.rd_write_val(j)] = z[layout.ecall_rd_val(j)]; + } + let mul_carry = if is_mulh { let rhs = (mul_hi as i128) - (rs1_sign as i128) * (rs2_u64 as i128) - (rs2_sign as i128) * (rs1_u64 as i128) diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index 4c3094bb..f8fbb5b3 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -1,11 +1,13 @@ use neo_vm_trace::{Shout, Twist}; +use p3_field::{PrimeCharacteristicRing, PrimeField64}; +use p3_goldilocks::Goldilocks; use super::bits::interleave_bits; use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; -use super::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM}; +use super::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. /// @@ -25,6 +27,10 @@ pub struct RiscvCpu { program: Vec, /// Base address of the program. program_base: u64, + /// Pending Poseidon2 digest words (8 × u32, set by the compute ECALL). + poseidon2_pending: Option<[u32; 8]>, + /// Index into `poseidon2_pending` for the next read ECALL. + poseidon2_read_idx: usize, } impl RiscvCpu { @@ -38,6 +44,8 @@ impl RiscvCpu { halted: false, program: Vec::new(), program_base: 0, + poseidon2_pending: None, + poseidon2_read_idx: 0, } } @@ -86,10 +94,64 @@ impl RiscvCpu { fn handle_ecall(&mut self) { let call_id = self.get_reg(10) as u32; // a0 - if call_id != JOLT_CYCLE_TRACK_ECALL_NUM && call_id != JOLT_PRINT_ECALL_NUM { + if call_id != JOLT_CYCLE_TRACK_ECALL_NUM + && call_id != JOLT_PRINT_ECALL_NUM + && call_id != POSEIDON2_ECALL_NUM + && call_id != POSEIDON2_READ_ECALL_NUM + { self.halted = true; } } + + /// Execute the Poseidon2 hash compute precompile. 
+ /// + /// Reads `n_elements` Goldilocks field elements from RAM at `input_addr` + /// using `load_untraced` (no Twist bus events), computes the Poseidon2- + /// Goldilocks sponge hash, and stores the 4-element digest as 8 u32 words + /// in internal CPU state (`poseidon2_pending`). The guest retrieves output + /// words one at a time via `POSEIDON2_READ_ECALL_NUM`. + fn handle_poseidon2_ecall>(&mut self, twist: &mut T) { + use neo_ccs::crypto::poseidon2_goldilocks::poseidon2_hash; + + let ram = super::RAM_ID; + let n_elements = self.get_reg(11) as u32; // a1 + let input_addr = self.get_reg(12); // a2 + + let mut inputs = Vec::with_capacity(n_elements as usize); + for i in 0..n_elements { + let addr = input_addr + (i as u64) * 8; + let lo = twist.load_untraced(ram, addr) & 0xFFFF_FFFF; + let hi = twist.load_untraced(ram, addr + 4) & 0xFFFF_FFFF; + let val = lo | (hi << 32); + inputs.push(Goldilocks::from_u64(val)); + } + + let digest = poseidon2_hash(&inputs); + + let mut words = [0u32; 8]; + for (i, &elem) in digest.iter().enumerate() { + let val = elem.as_canonical_u64(); + words[i * 2] = val as u32; + words[i * 2 + 1] = (val >> 32) as u32; + } + self.poseidon2_pending = Some(words); + self.poseidon2_read_idx = 0; + } + + /// Read one u32 word of the pending Poseidon2 digest into register a0. + /// + /// Called 8 times by the guest to retrieve the full 4-element digest. + /// Each call returns the next word and advances the internal read index. 
+ fn handle_poseidon2_read_ecall(&mut self) { + let words = self + .poseidon2_pending + .expect("poseidon2 read ECALL called without a pending digest"); + let idx = self.poseidon2_read_idx; + assert!(idx < 8, "poseidon2 read ECALL: all 8 words already consumed"); + let word = words[idx] as u64; + self.set_reg(10, word); // a0 + self.poseidon2_read_idx = idx + 1; + } } impl neo_vm_trace::VmCpu for RiscvCpu { @@ -402,7 +464,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { } RiscvInstruction::Halt => { - // ECALL trap semantics (Jolt-style): no architectural effects, halt unless it's a known marker/print call. + let call_id = self.get_reg(10) as u32; + if call_id == POSEIDON2_ECALL_NUM { + self.handle_poseidon2_ecall(twist); + } else if call_id == POSEIDON2_READ_ECALL_NUM { + self.handle_poseidon2_read_ecall(); + } self.handle_ecall(); } @@ -558,7 +625,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { // === System Instructions === RiscvInstruction::Ecall => { - // ECALL - environment call (syscall). + let call_id = self.get_reg(10) as u32; + if call_id == POSEIDON2_ECALL_NUM { + self.handle_poseidon2_ecall(twist); + } else if call_id == POSEIDON2_READ_ECALL_NUM { + self.handle_poseidon2_read_ecall(); + } self.handle_ecall(); } diff --git a/crates/neo-memory/src/riscv/lookups/mod.rs b/crates/neo-memory/src/riscv/lookups/mod.rs index 4753eaec..a11bdb60 100644 --- a/crates/neo-memory/src/riscv/lookups/mod.rs +++ b/crates/neo-memory/src/riscv/lookups/mod.rs @@ -105,6 +105,22 @@ pub const PROG_ID: TwistId = TwistId(1); pub const JOLT_CYCLE_TRACK_ECALL_NUM: u32 = 0xC7C1E; pub const JOLT_PRINT_ECALL_NUM: u32 = 0x505249; +/// Poseidon2-Goldilocks hash compute ECALL identifier. +/// +/// ABI: a0 = POSEIDON2_ECALL_NUM, a1 = input element count, +/// a2 = input RAM address (elements as 2×u32 LE). +/// The host reads inputs via untraced loads, computes the hash, and stores the +/// 4-element digest in CPU-internal state. 
Use POSEIDON2_READ_ECALL_NUM to +/// retrieve output words one at a time via register a0. +pub const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2-Goldilocks digest read ECALL identifier (bit 31 set). +/// +/// ABI: a0 = POSEIDON2_READ_ECALL_NUM. Returns the next u32 word of the +/// pending Poseidon2 digest in register a0. Call 8 times (4 elements × 2 words) +/// to retrieve the full digest. +pub const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + pub use alu::{compute_op, lookup_entry}; pub use bits::{interleave_bits, uninterleave_bits}; pub use cpu::RiscvCpu; diff --git a/crates/neo-memory/tests/riscv_ccs_tests.rs b/crates/neo-memory/tests/riscv_ccs_tests.rs index 43e11ded..b9b9e819 100644 --- a/crates/neo-memory/tests/riscv_ccs_tests.rs +++ b/crates/neo-memory/tests/riscv_ccs_tests.rs @@ -12,7 +12,8 @@ use neo_memory::riscv::ccs::{ }; use neo_memory::riscv::lookups::{ decode_instruction, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, - RiscvOpcode, RiscvShoutTables, JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, + RiscvOpcode, RiscvShoutTables, JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, + POSEIDON2_READ_ECALL_NUM, PROG_ID, RAM_ID, }; use neo_memory::riscv::rom_init::prog_init_words; use neo_memory::witness::LutTableSpec; @@ -427,6 +428,97 @@ fn rv32_b1_ccs_happy_path_rv32i_ecall_markers_program() { } } +#[test] +fn rv32_b1_ccs_happy_path_poseidon2_ecall() { + let xlen = 32usize; + let mut program = Vec::new(); + + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + // a0 = POSEIDON2_ECALL_NUM → compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. 
+ for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + + // Clear a0 → final halt. + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + let program_bytes = encode_program(&program); + let mut cpu_vm = RiscvCpu::new(xlen); + cpu_vm.load_program(0, program.clone()); + let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); + let shout = RiscvShoutTables::new(xlen); + let trace = trace_program(cpu_vm, memory, shout, 256).expect("trace"); + assert!(trace.did_halt(), "expected Halt"); + + let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); + let (k_ram, d_ram) = pow2_ceil_k(0x40); + let mem_layouts = HashMap::from([ + ( + 0u32, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + ( + 1u32, + PlainMemLayout { + k: k_prog, + d: d_prog, + n_side: 2, + lanes: 1, + }, + ), + ]); + + let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); + + let shout_table_ids = RV32I_SHOUT_TABLE_IDS; + let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); + let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); + + let table_specs = rv32i_table_specs(xlen); + + let cpu = R1csCpu::new( + ccs, + params, + NoopCommit::default(), + layout.m_in, + &HashMap::new(), + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ) + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), + 1, + ) + .expect("shared bus"); + + let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); + for (mcs_inst, mcs_wit) in steps { + check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); + } +} + #[test] fn rv32_b1_ccs_happy_path_rv32m_program() { let xlen = 32usize; diff --git a/crates/neo-vm-trace/src/lib.rs 
b/crates/neo-vm-trace/src/lib.rs index f75f963d..e7750d77 100644 --- a/crates/neo-vm-trace/src/lib.rs +++ b/crates/neo-vm-trace/src/lib.rs @@ -340,6 +340,16 @@ pub trait Twist { self.store_lane(twist_id, addr, value, lane); } } + + /// Load a value from memory without recording a Twist event. + /// + /// Used by ECALL precompiles that need to read guest RAM without generating + /// per-step Twist bus entries (the ECALL computation is trusted host code). + /// The default implementation delegates to `load`. + #[inline] + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.load(twist_id, addr) + } } /// A tracing wrapper around any `Twist` implementation. @@ -431,6 +441,10 @@ where lane: Some(lane), }); } + + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.inner.load(twist_id, addr) + } } // ============================================================================ diff --git a/crates/nightstream-sdk/src/goldilocks.rs b/crates/nightstream-sdk/src/goldilocks.rs new file mode 100644 index 00000000..a25193c6 --- /dev/null +++ b/crates/nightstream-sdk/src/goldilocks.rs @@ -0,0 +1,116 @@ +//! Software Goldilocks field arithmetic for RISC-V guests. +//! +//! Field: p = 2^64 - 2^32 + 1 = 0xFFFF_FFFF_0000_0001 +//! +//! All values are canonical: in [0, p). + +#![allow(dead_code)] + +/// Goldilocks prime: p = 2^64 - 2^32 + 1. +pub const GL_P: u64 = 0xFFFF_FFFF_0000_0001; + +/// A Goldilocks field element digest: 4 elements (32 bytes). +pub type GlDigest = [u64; 4]; + +/// Reduce a u128 value mod p using the identity 2^64 ≡ 2^32 - 1 (mod p). +/// Avoids 128-bit division. 
+#[inline]
+fn reduce128(x: u128) -> u64 {
+    let lo = x as u64;
+    let hi = (x >> 64) as u64;
+
+    // x ≡ lo + hi * (2^32 - 1) (mod p)
+    let hi_times_correction = (hi as u128) * 0xFFFF_FFFFu128;
+    let sum = lo as u128 + hi_times_correction;
+
+    let lo2 = sum as u64;
+    let hi2 = (sum >> 64) as u64;
+
+    // hi2 can be as large as 2^32 - 1 (sum is up to ~2^96), not just 0 or 1.
+    // One more fold of hi2 * (2^32 - 1) — which fits in u64 — plus a single
+    // conditional subtraction is still enough to land in [0, p).
+    let (mut result, overflow) = lo2.overflowing_add(hi2.wrapping_mul(0xFFFF_FFFF));
+    if overflow || result >= GL_P {
+        result = result.wrapping_sub(GL_P);
+    }
+    result
+}
+
+/// Additive identity.
+pub const GL_ZERO: u64 = 0;
+
+/// Multiplicative identity.
+pub const GL_ONE: u64 = 1;
+
+/// Field addition: (a + b) mod p.
+#[inline]
+pub fn gl_add(a: u64, b: u64) -> u64 {
+    let sum = a as u128 + b as u128;
+    let s = sum as u64;
+    let carry = (sum >> 64) as u64;
+    // carry is 0 or 1
+    let (mut r, overflow) = s.overflowing_add(carry.wrapping_mul(0xFFFF_FFFF));
+    if overflow || r >= GL_P {
+        r = r.wrapping_sub(GL_P);
+    }
+    r
+}
+
+/// Field subtraction: (a - b) mod p.
+#[inline]
+pub fn gl_sub(a: u64, b: u64) -> u64 {
+    if a >= b {
+        let diff = a - b;
+        if diff >= GL_P { diff - GL_P } else { diff }
+    } else {
+        // a < b, so result = p - (b - a)
+        GL_P - (b - a)
+    }
+}
+
+/// Field negation: (-a) mod p.
+#[inline]
+pub fn gl_neg(a: u64) -> u64 {
+    if a == 0 { 0 } else { GL_P - a }
+}
+
+/// Field multiplication: (a * b) mod p.
+#[inline]
+pub fn gl_mul(a: u64, b: u64) -> u64 {
+    let prod = (a as u128) * (b as u128);
+    reduce128(prod)
+}
+
+/// Field squaring: (a * a) mod p.
+#[inline]
+pub fn gl_sqr(a: u64) -> u64 {
+    gl_mul(a, a)
+}
+
+/// Field exponentiation: a^exp mod p via square-and-multiply.
+pub fn gl_pow(mut base: u64, mut exp: u64) -> u64 {
+    let mut result = GL_ONE;
+    while exp > 0 {
+        if exp & 1 == 1 {
+            result = gl_mul(result, base);
+        }
+        base = gl_sqr(base);
+        exp >>= 1;
+    }
+    result
+}
+
+/// Field inverse: a^(p-2) mod p (Fermat's little theorem).
+///
+/// Debug builds assert that `a != 0`; release builds do NOT panic — `gl_inv(0)`
+/// returns 0 (0^(p-2) = 0), which is not a valid inverse. Callers must not pass zero.
+pub fn gl_inv(a: u64) -> u64 { + debug_assert!(a != 0, "inverse of zero"); + // p - 2 = 0xFFFF_FFFE_FFFF_FFFF + gl_pow(a, GL_P - 2) +} + +/// Check equality of two digests. +pub fn digest_eq(a: &GlDigest, b: &GlDigest) -> bool { + a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3] +} + +/// Zero digest. +pub const ZERO_DIGEST: GlDigest = [0; 4]; diff --git a/crates/nightstream-sdk/src/lib.rs b/crates/nightstream-sdk/src/lib.rs index f0a043c8..4cd5b989 100644 --- a/crates/nightstream-sdk/src/lib.rs +++ b/crates/nightstream-sdk/src/lib.rs @@ -3,6 +3,8 @@ pub use nightstream_sdk_macros::{entry, provable, NeoAbi}; pub mod abi; +pub mod goldilocks; +pub mod poseidon2; #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] mod riscv; diff --git a/crates/nightstream-sdk/src/poseidon2.rs b/crates/nightstream-sdk/src/poseidon2.rs new file mode 100644 index 00000000..1cf9b90d --- /dev/null +++ b/crates/nightstream-sdk/src/poseidon2.rs @@ -0,0 +1,111 @@ +//! Poseidon2-Goldilocks hash via ECALL precompile. +//! +//! Provides the guest-side API for computing Poseidon2 hashes over Goldilocks +//! field elements. On RISC-V targets, this issues ECALLs to the host: +//! 1. A "compute" ECALL that reads inputs (via untraced loads) and computes the +//! Poseidon2 hash, storing the digest in host-side CPU state. +//! 2. Eight "read" ECALLs that each return one u32 word of the digest in +//! register a0. +//! +//! On non-RISC-V targets, a stub panics. + +use crate::goldilocks::GlDigest; + +/// Poseidon2 compute ECALL number (must match neo-memory's POSEIDON2_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2 read ECALL number (must match neo-memory's POSEIDON2_READ_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + +/// Digest length in Goldilocks elements (matches neo-params DIGEST_LEN). 
+const DIGEST_LEN: usize = 4; + +/// Scratch buffer size: supports hashing up to 64 Goldilocks elements per call. +/// Each element occupies 2 u32 words (8 bytes). +const MAX_INPUT_ELEMENTS: usize = 64; + +/// Hash an arbitrary-length slice of Goldilocks field elements. +/// +/// Packs elements to a stack-allocated scratch buffer, issues the Poseidon2 +/// compute ECALL, then retrieves the 4-element digest via 8 read ECALLs +/// (each returning one u32 word in register a0). +/// +/// # Panics +/// +/// Panics if `input.len() > MAX_INPUT_ELEMENTS` (64). +pub fn poseidon2_hash(input: &[u64]) -> GlDigest { + assert!( + input.len() <= MAX_INPUT_ELEMENTS, + "poseidon2_hash: too many input elements" + ); + + // Scratch buffer for input (as u32 words). + let mut input_buf: [u32; MAX_INPUT_ELEMENTS * 2] = [0; MAX_INPUT_ELEMENTS * 2]; + + // Pack Goldilocks elements as pairs of u32 words (little-endian). + for (i, &elem) in input.iter().enumerate() { + input_buf[i * 2] = elem as u32; + input_buf[i * 2 + 1] = (elem >> 32) as u32; + } + + // Compute ECALL: host reads inputs via untraced loads and stores digest in CPU state. + poseidon2_ecall_compute(input.len() as u32, input_buf.as_ptr() as u32); + + // Read 8 u32 words of the digest via register a0. + let d: [u32; 8] = core::array::from_fn(|_| poseidon2_ecall_read()); + + // Unpack into 4 Goldilocks elements. + let mut digest = [0u64; DIGEST_LEN]; + for i in 0..DIGEST_LEN { + digest[i] = (d[i * 2] as u64) | ((d[i * 2 + 1] as u64) << 32); + } + digest +} + +/// Issue the Poseidon2 compute ECALL. +/// +/// Registers: a0 = ECALL ID, a1 = element count, a2 = input addr. +/// No output via registers; digest is stored in host CPU state. 
+#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_compute(n_elements: u32, input_addr: u32) { + unsafe { + core::arch::asm!( + "ecall", + in("a0") POSEIDON2_ECALL_NUM, + in("a1") n_elements, + in("a2") input_addr, + options(nostack), + ); + } +} + +/// Issue the Poseidon2 read ECALL and return the next digest word. +/// +/// The ECALL number is passed in a0, and the host returns the next u32 word +/// of the pending digest in a0. +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_read() -> u32 { + let result: u32; + unsafe { + core::arch::asm!( + "ecall", + inout("a0") POSEIDON2_READ_ECALL_NUM => result, + options(nostack), + ); + } + result +} + +/// Stub for non-RISC-V targets (used in native tests via the reference impl). +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_compute(_n_elements: u32, _input_addr: u32) { + unimplemented!("poseidon2_ecall_compute is only available on RISC-V targets") +} + +/// Stub for non-RISC-V targets. +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_read() -> u32 { + unimplemented!("poseidon2_ecall_read is only available on RISC-V targets") +} From 10ef34c4765035d98bfaaf323b13716223969bd4 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 11:18:19 -0300 Subject: [PATCH 6/6] Add detailed timing metrics for RISC-V VM and CPU witness phases - Introduced new timing fields in `ProveTimings` to track durations for VM trace collection and CPU witness building. - Updated `ShardWitnessAux` to include wall-clock time for VM trace and CPU witness phases. - Enhanced `R1csCpu` to log detailed timing for witness building, bus filling, matrix decomposition, and commitment processes. - Added methods in `Session` for preloading CCS preprocessing with matrix digests to optimize performance during proving and verification. 
--- crates/neo-fold/src/riscv_shard.rs | 247 ++++++++++++++++------ crates/neo-fold/src/session.rs | 49 ++++- crates/neo-memory/src/builder.rs | 10 + crates/neo-memory/src/cpu/r1cs_adapter.rs | 54 +++++ 4 files changed, 299 insertions(+), 61 deletions(-) diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index a5d1c932..6b30e799 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -88,6 +88,10 @@ pub struct ProveTimings { pub vm_execution: Duration, /// Fold-and-prove (sumcheck + Ajtai commitment). pub fold_and_prove: Duration, + /// Sub-phase: RISC-V VM trace collection (within vm_execution). + pub vm_trace: Duration, + /// Sub-phase: CPU witness building including decomp + commit (within vm_execution). + pub cpu_witness: Duration, } pub fn rv32_b1_step_linking_config(layout: &Rv32B1Layout) -> StepLinkingConfig { @@ -294,16 +298,33 @@ fn all_shout_opcodes() -> HashSet { /// Build with [`Rv32B1::build_ccs_cache`], then inject into subsequent builders via /// [`Rv32B1::with_ccs_cache`]. /// -/// **Important**: the CCS structure itself is still built from scratch each time (it is fast). -/// Only the SparseCache (CSC decomposition + matrix digest) is cached. +/// When `full_ccs` is present, `prove()` and `build_verifier()` skip `build_rv32_b1_step_ccs` +/// and `with_shared_cpu_bus`, reusing the cached CCS directly. This saves ~1s per call. +/// +/// When `params` and `committer` are also present, `prove()` and `build_verifier()` additionally +/// skip `FoldingSession::new_ajtai()` (which generates a fresh random kappa x m Ajtai matrix), +/// using `FoldingSession::new()` with the cached values instead. pub struct Rv32B1CcsCache { pub sparse: Arc>, + /// Full post-shared-bus CCS. When present, prove/verify skip CCS construction. + pub full_ccs: Option>, + /// Layout corresponding to the cached CCS. + pub layout: Option, + /// Cached NeoParams from the original session creation. 
+ pub params: Option, + /// Cached Ajtai committer (wraps `Arc`). Avoids re-generating the random matrix. + pub committer: Option, + /// Precomputed CCS matrix digest. Avoids re-hashing all matrices (~1.5s) on every prove/verify. + pub ccs_mat_digest: Option>, } impl std::fmt::Debug for Rv32B1CcsCache { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rv32B1CcsCache") .field("sparse_len", &self.sparse.len()) + .field("has_full_ccs", &self.full_ccs.is_some()) + .field("has_params", &self.params.is_some()) + .field("has_committer", &self.committer.is_some()) .finish() } } @@ -519,9 +540,21 @@ impl Rv32B1 { ) .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; - let sparse = Arc::new(SparseCache::build(&cpu.ccs)); + let cached_params = session.params().clone(); + let cached_committer = session.committer().clone(); + + let full_ccs = cpu.ccs.clone(); + let sparse = Arc::new(SparseCache::build(&full_ccs)); + let ccs_mat_digest = neo_reductions::engines::utils::digest_ccs_matrices(&full_ccs); - Ok(Rv32B1CcsCache { sparse }) + Ok(Rv32B1CcsCache { + sparse, + full_ccs: Some(full_ccs), + layout: Some(layout), + params: Some(cached_params), + committer: Some(cached_committer), + ccs_mat_digest: Some(ccs_mat_digest), + }) } /// Build only the verification context (CCS + session) without executing the program or proving. 
@@ -623,38 +656,62 @@ impl Rv32B1 { let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); shout_table_ids.sort_unstable(); - // --- CCS + Session (same as prove) --- - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) - .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; - - let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; - let params = session.params().clone(); - let committer = session.committer().clone(); + // --- CCS + Session --- + let has_full_ccs = self + .ccs_cache + .as_ref() + .map(|c| c.full_ccs.is_some() && c.layout.is_some()) + .unwrap_or(false); let empty_tables: HashMap> = HashMap::new(); - - // Build R1csCpu for CCS with shared-bus wiring (same as prove; no execution). - let mut cpu = R1csCpu::new( - ccs_base, - params, - committer, - layout.m_in, - &empty_tables, - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ); - cpu = cpu - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, - self.chunk_size, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + let (mut session, ccs, layout); + + if has_full_ccs { + let cache = self.ccs_cache.as_ref().unwrap(); + let cached_ccs = cache.full_ccs.as_ref().unwrap().clone(); + layout = cache.layout.as_ref().unwrap().clone(); + + // Use cached params+committer to skip the expensive setup_par() random matrix generation. 
+ let has_cached_session = cache.params.is_some() && cache.committer.is_some(); + if has_cached_session { + let cached_params = cache.params.as_ref().unwrap().clone(); + let cached_committer = cache.committer.as_ref().unwrap().clone(); + session = FoldingSession::new(self.mode.clone(), cached_params, cached_committer); + session.commit_m = Some(cached_ccs.m); + } else { + session = FoldingSession::::new_ajtai(self.mode.clone(), &cached_ccs)?; + } + ccs = cached_ccs; + } else { + let (ccs_base, layout_built) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + layout = layout_built; + + session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + ccs = cpu.ccs.clone(); + } session.set_step_linking(rv32_b1_step_linking_config(&layout)); - let ccs = cpu.ccs.clone(); - // No execution, no fold_and_prove -- just the verification context. // NOTE: verifier SparseCache preloading is deferred to `preload_sparse_cache()` // because the CCS pointer changes when moved into the Rv32B1Verifier struct. 
@@ -784,37 +841,82 @@ impl Rv32B1 { let prebuilt_sparse = self.ccs_cache.as_ref().map(|c| c.sparse.clone()); - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) - .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; - - // Session + Ajtai committer + params (auto-picked for this CCS). - let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; - let params = session.params().clone(); - let committer = session.committer().clone(); + // Check if the cache contains a full pre-wired CCS we can reuse. + let has_full_ccs = self + .ccs_cache + .as_ref() + .map(|c| c.full_ccs.is_some() && c.layout.is_some()) + .unwrap_or(false); let mut vm = RiscvCpu::new(self.xlen); vm.load_program(/*base=*/ 0, program); - let empty_tables: HashMap> = HashMap::new(); let lut_lanes: HashMap = HashMap::new(); - // CPU arithmetization (builds chunk witnesses and commits them). - let mut cpu = R1csCpu::new( - ccs_base, - params, - committer, - layout.m_in, - &empty_tables, - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ); - cpu = cpu - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, - self.chunk_size, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + let (mut session, cpu, layout); + + if has_full_ccs { + let cache = self.ccs_cache.as_ref().unwrap(); + let cached_ccs = cache.full_ccs.as_ref().unwrap().clone(); + layout = cache.layout.as_ref().unwrap().clone(); + + // Use cached params+committer to skip the expensive setup_par() random matrix generation. 
+ let has_cached_session = cache.params.is_some() && cache.committer.is_some(); + if has_cached_session { + let cached_params = cache.params.as_ref().unwrap().clone(); + let cached_committer = cache.committer.as_ref().unwrap().clone(); + session = FoldingSession::new(self.mode.clone(), cached_params, cached_committer); + session.commit_m = Some(cached_ccs.m); + } else { + session = FoldingSession::::new_ajtai(self.mode.clone(), &cached_ccs)?; + } + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu_inner = R1csCpu::new( + cached_ccs, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu_inner = cpu_inner + .with_shared_cpu_bus_witness_only( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus witness-only setup failed: {e}")))?; + cpu = cpu_inner; + } else { + let (ccs_base, layout_built) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + layout = layout_built; + + session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu_inner = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu_inner = cpu_inner + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared 
bus inject failed: {e}")))?; + cpu = cpu_inner; + } // Always enforce step-to-step chunk chaining for RV32 B1. session.set_step_linking(rv32_b1_step_linking_config(&layout)); @@ -849,16 +951,28 @@ impl Rv32B1 { } } + // Extract sub-phase timings from the shared-bus aux. + let (t_vm_trace, t_cpu_witness) = session + .shared_bus_aux() + .map(|aux| (aux.vm_trace_duration, aux.cpu_witness_duration)) + .unwrap_or((Duration::ZERO, Duration::ZERO)); + // Enforce that the *statement* initial memory matches chunk 0's public MemInit. let steps_public = session.steps_public(); rv32_b1_enforce_chunk0_mem_init_matches_statement(&mem_layouts, &initial_mem, &steps_public)?; - let ccs = cpu.ccs.clone(); + // Move the CCS out of cpu (no longer needed) to avoid an expensive clone. + let ccs = cpu.ccs; - // Preload the SparseCache using the *final* CCS reference that fold_and_prove will use. - // (Must happen after cpu.ccs.clone() because the pointer-keyed cache checks identity.) + // Preload the SparseCache + matrix digest. When the digest is cached, skip the + // expensive digest_ccs_matrices recomputation (~1.5s for large circuits). 
if let Some(ref sparse) = prebuilt_sparse { - session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + let cached_digest = self.ccs_cache.as_ref().and_then(|c| c.ccs_mat_digest.clone()); + if let Some(digest) = cached_digest { + session.preload_ccs_sparse_cache_with_digest(&ccs, sparse.clone(), digest)?; + } else { + session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + } } let t_vm_execution = elapsed_duration(phase_start); @@ -890,6 +1004,8 @@ impl Rv32B1 { ccs_and_shared_bus: t_ccs_and_shared_bus, vm_execution: t_vm_execution, fold_and_prove: t_fold_and_prove, + vm_trace: t_vm_trace, + cpu_witness: t_cpu_witness, }, verify_duration: None, }) @@ -920,6 +1036,17 @@ impl Rv32B1Verifier { .preload_verifier_ccs_sparse_cache(&self.ccs, sparse) } + /// Preload verifier SparseCache with a precomputed matrix digest, skipping the expensive + /// digest recomputation. + pub fn preload_sparse_cache_with_digest( + &mut self, + sparse: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache_with_digest(&self.ccs, sparse, ccs_mat_digest) + } + /// Verify a `ShardProof` using the provided public step instance bundles. /// /// `steps_public` must be the step instance bundles produced by the prover (via diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 0dece63c..d12fd552 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -592,7 +592,7 @@ where mode: FoldingMode, params: NeoParams, l: L, - commit_m: Option, + pub(crate) commit_m: Option, mixers: CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt>, // Cached CCS preprocessing for proving (best-effort reuse). @@ -1193,6 +1193,30 @@ where Ok(()) } + /// Preload prover-side CCS preprocessing with a precomputed matrix digest, skipping the + /// expensive `digest_ccs_matrices` call (~1.5s for large circuits). 
+ pub fn preload_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.prover_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Preload verifier-side CCS preprocessing (sparse-cache + matrix-digest). /// /// This does **not** affect proving. It exists so benchmarks can model a verifier that has @@ -1206,6 +1230,29 @@ where Ok(()) } + /// Preload verifier-side CCS preprocessing with a precomputed matrix digest. + pub fn preload_verifier_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.verifier_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Fold and prove: run folding over all collected steps and return a `FoldRun`. /// This is where the actual cryptographic work happens (Π_CCS → RLC → DEC for each step). /// This method manages the transcript internally for ease of use. 
diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index ed35c991..f80db964 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -51,6 +51,10 @@ pub struct ShardWitnessAux { pub mem_ids: Vec, /// Final sparse memory states at the end of the shard: mem_id -> (addr -> value), with zero cells omitted. pub final_mem_states: HashMap>, + /// Wall-clock time for VM trace collection (`trace_program`). + pub vm_trace_duration: std::time::Duration, + /// Wall-clock time for CPU witness building (`build_ccs_chunks`), includes decomp + commit. + pub cpu_witness_duration: std::time::Duration, } fn ell_from_pow2_n_side(n_side: usize) -> Result { @@ -138,8 +142,10 @@ where // // This keeps the proof size proportional to the executed trace length instead of the caller's // safety bound. + let t_vm_start = std::time::Instant::now(); let trace = neo_vm_trace::trace_program(vm, twist, shout, max_steps) .map_err(|e| ShardBuildError::VmError(e.to_string()))?; + let vm_trace_duration = t_vm_start.elapsed(); let original_len = trace.steps.len(); let did_halt = trace.did_halt(); debug_assert!( @@ -193,9 +199,11 @@ where table_ids.dedup(); // 3) CPU arithmetization chunks. 
+ let t_witness_start = std::time::Instant::now(); let mcss = cpu_arith .build_ccs_chunks(&trace, chunk_size) .map_err(|e| ShardBuildError::CcsError(e.to_string()))?; + let cpu_witness_duration = t_witness_start.elapsed(); if mcss.len() != chunks_len { return Err(ShardBuildError::CcsError(format!( "cpu arithmetization returned {} chunks, expected {} (steps={}, chunk_size={})", @@ -368,6 +376,8 @@ where chunk_size, mem_ids, final_mem_states: mem_states, + vm_trace_duration, + cpu_witness_duration, }; Ok((step_bundles, aux)) } diff --git a/crates/neo-memory/src/cpu/r1cs_adapter.rs b/crates/neo-memory/src/cpu/r1cs_adapter.rs index bf321513..5aa253d3 100644 --- a/crates/neo-memory/src/cpu/r1cs_adapter.rs +++ b/crates/neo-memory/src/cpu/r1cs_adapter.rs @@ -419,6 +419,38 @@ where }); Ok(self) } + + /// Like [`with_shared_cpu_bus`], but assumes the CCS already contains bus constraints. + /// + /// Use this when the `ccs` passed to [`R1csCpu::new`] is a pre-wired CCS from a cache + /// (e.g. from a previous `with_shared_cpu_bus` call). This sets up the internal + /// `SharedCpuBusState` needed for witness generation in [`build_ccs_chunks`] without + /// re-injecting constraint matrices into the CCS. 
+ pub fn with_shared_cpu_bus_witness_only( + mut self, + cfg: SharedCpuBusConfig, + chunk_size: usize, + ) -> Result { + if chunk_size == 0 { + return Err("shared_cpu_bus: chunk_size must be >= 1".into()); + } + if cfg.const_one_col >= self.m_in { + return Err(format!( + "shared_cpu_bus: const_one_col={} must be < m_in={}", + cfg.const_one_col, self.m_in + )); + } + + let (table_ids, mem_ids, layout) = self.shared_bus_schema(&cfg, chunk_size)?; + + self.shared_cpu_bus = Some(SharedCpuBusState { + cfg, + table_ids, + mem_ids, + layout, + }); + Ok(self) + } } // R1csCpu implementation specifically for Goldilocks field because neo_ajtai::decomp_b uses Goldilocks @@ -451,6 +483,10 @@ where } let mut mcss = Vec::with_capacity(trace.steps.len().div_ceil(chunk_size)); + let mut t_witness_build = std::time::Duration::ZERO; + let mut t_bus_fill = std::time::Duration::ZERO; + let mut t_decomp_b = std::time::Duration::ZERO; + let mut t_commit = std::time::Duration::ZERO; // Track sparse memory state across the full trace to compute inc_at_write_addr. let mut mem_state: HashMap> = HashMap::new(); @@ -480,6 +516,7 @@ where let chunk = &trace.steps[chunk_start..chunk_end]; // 1) Build witness z for this chunk. + let t0 = std::time::Instant::now(); let mut z_vec = (self.chunk_to_witness)(chunk); // Allow witness builders to omit trailing dummy variables (including the shared-bus tail). @@ -512,8 +549,10 @@ where } z_vec[shared.cfg.const_one_col] = Goldilocks::ONE; } + t_witness_build += t0.elapsed(); // 2) Overwrite the shared bus tail from the trace events. 
+ let t0 = std::time::Instant::now(); if let Some(shared) = shared { let bus_base = shared.layout.bus_base; let bus_region_len = shared.layout.bus_region_len(); @@ -778,7 +817,10 @@ where } } + t_bus_fill += t0.elapsed(); + // 3) Decompose z -> Z matrix + let t0 = std::time::Instant::now(); let d = self.params.d as usize; let m = z_vec.len(); // == ccs.m after padding @@ -799,9 +841,12 @@ where } } let z_mat = Mat::from_row_major(d, m, mat_data); + t_decomp_b += t0.elapsed(); // 4) Commit to Z + let t0 = std::time::Instant::now(); let c = self.committer.commit(&z_mat); + t_commit += t0.elapsed(); // 5) Build Instance/Witness let x = z_vec[..m_in].to_vec(); @@ -812,6 +857,15 @@ where chunk_start = chunk_end; } + eprintln!( + "[build_ccs_chunks] witness_build={}ms, bus_fill={}ms, decomp_b={}ms, commit={}ms, total={}ms", + t_witness_build.as_millis(), + t_bus_fill.as_millis(), + t_decomp_b.as_millis(), + t_commit.as_millis(), + (t_witness_build + t_bus_fill + t_decomp_b + t_commit).as_millis(), + ); + Ok(mcss) }