From e559fbbf285d29a919ee6d66f76df4b0f67a51f5 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 09:52:26 -0300 Subject: [PATCH 01/13] Enhance serialization support for proof types - Added `serde::Serialize` and `serde::Deserialize` derives to various structs including `Commitment`, `ShoutAddrPreProof`, and `BatchedTimeProof` for improved serialization capabilities. - Introduced `serde_helpers` module with custom serialization functions for handling non-trivially-serializable fields, specifically for `Vec<&'static [u8]>`. - Implemented verification methods in `FoldingSession` to allow proof verification using externally supplied step instance bundles without executing the program. This update enhances the flexibility and usability of proof serialization across the codebase. --- crates/neo-ajtai/src/types.rs | 2 +- crates/neo-fold/src/lib.rs | 3 + crates/neo-fold/src/riscv_shard.rs | 184 +++++++++++ crates/neo-fold/src/serde_helpers.rs | 61 ++++ crates/neo-fold/src/session.rs | 303 ++++++++++++++++++ crates/neo-fold/src/shard_proof_types.rs | 27 +- crates/neo-reductions/Cargo.toml | 6 +- .../src/engines/optimized_engine/common.rs | 2 +- .../src/engines/optimized_engine/mod.rs | 4 +- 9 files changed, 576 insertions(+), 16 deletions(-) create mode 100644 crates/neo-fold/src/serde_helpers.rs diff --git a/crates/neo-ajtai/src/types.rs b/crates/neo-ajtai/src/types.rs index 9f0b91e3..8fbaa37c 100644 --- a/crates/neo-ajtai/src/types.rs +++ b/crates/neo-ajtai/src/types.rs @@ -13,7 +13,7 @@ pub struct PP { } /// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d). 
-#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct Commitment { pub d: usize, pub kappa: usize, diff --git a/crates/neo-fold/src/lib.rs b/crates/neo-fold/src/lib.rs index dae0b088..aee318a1 100644 --- a/crates/neo-fold/src/lib.rs +++ b/crates/neo-fold/src/lib.rs @@ -33,3 +33,6 @@ pub mod riscv_shard; pub mod output_binding; mod shard_proof_types; + +/// Serde helpers for proof serialization (e.g. `&'static [u8]` labels). +pub mod serde_helpers; diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 25c14283..47c4160b 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -373,6 +373,147 @@ impl Rv32B1 { self } + /// Build only the verification context (CCS + session) without executing the program or proving. + /// + /// This performs the same validation, program decoding, memory layout setup, CCS construction + /// (including shared-bus wiring), and session creation that `prove()` does -- but stops before + /// any RISC-V execution or folding. + /// + /// The returned [`Rv32B1Verifier`] can verify a `ShardProof` given the public MCS instances + /// (`mcss_public`) that were produced by the prover. + /// + /// **Cost:** circuit synthesis only (~ms). No RISC-V execution, no folding. 
+ pub fn build_verifier(self) -> Result { + // --- Input validation (same as prove) --- + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 MVP requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if self.chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + if self.ram_bytes == 0 { + return Err(PiCcsError::InvalidInput("ram_bytes must be non-zero".into())); + } + if self.program_base != 0 { + return Err(PiCcsError::InvalidInput( + "RV32 B1 MVP requires program_base == 0 (addresses are indices into PROG/RAM layouts)".into(), + )); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RV32 B1 runner does not support RVC)".into(), + )); + } + for (i, chunk) in self.program_bytes.chunks_exact(4).enumerate() { + let first_half = u16::from_le_bytes([chunk[0], chunk[1]]); + if (first_half & 0b11) != 0b11 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 runner does not support compressed instructions (RVC): found compressed encoding at word index {i}" + ))); + } + } + + // --- Program decoding + memory layouts (same as prove, minus VM/Twist init) --- + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + let mut initial_mem = initial_mem; + for (&addr, &value) in &self.ram_init { + let value = value as u32 as u64; + initial_mem.insert((neo_memory::riscv::lookups::RAM_ID.0, addr), F::from_u64(value)); + } + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); 
+ let mem_layouts = HashMap::from([ + ( + neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + // --- Shout tables (same as prove) --- + let mut shout_ops = match &self.shout_ops { + Some(ops) => ops.clone(), + None if self.shout_auto_minimal => infer_required_shout_opcodes(&program), + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let shout = RiscvShoutTables::new(self.xlen); + let mut table_specs: HashMap = HashMap::new(); + for op in shout_ops { + let table_id = shout.opcode_to_id(op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + // --- CCS + Session (same as prove) --- + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + + let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let empty_tables: HashMap> = HashMap::new(); + + // Build R1csCpu for CCS with shared-bus wiring (same as prove; no execution). 
+ let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + session.set_step_linking(rv32_b1_step_linking_config(&layout)); + + let ccs = cpu.ccs.clone(); + + // No execution, no fold_and_prove -- just the verification context. + Ok(Rv32B1Verifier { + session, + ccs, + _layout: layout, + ram_num_bits: d_ram, + output_claims: self.output_claims, + }) + } + pub fn prove(self) -> Result { if self.xlen != 32 { return Err(PiCcsError::InvalidInput(format!( @@ -569,6 +710,40 @@ impl Rv32B1 { } } +/// Verification context for RV32 B1 proofs. +/// +/// Created by [`Rv32B1::build_verifier`]. Contains the CCS structure and folding session +/// needed to verify a `ShardProof` without executing the RISC-V program. +pub struct Rv32B1Verifier { + session: FoldingSession, + ccs: CcsStructure, + _layout: Rv32B1Layout, + ram_num_bits: usize, + output_claims: ProgramIO, +} + +impl Rv32B1Verifier { + /// Verify a `ShardProof` using the provided public step instance bundles. + /// + /// `steps_public` must be the step instance bundles produced by the prover (via + /// `Rv32B1Run::steps_public()`). The verifier checks the folding proof + /// against these instances and the CCS structure -- no RISC-V execution + /// is performed. 
+ pub fn verify( + &self, + proof: &ShardProof, + steps_public: &[StepInstanceBundle], + ) -> Result { + if self.output_claims.is_empty() { + self.session.verify_with_external_steps(&self.ccs, steps_public, proof) + } else { + let ob_cfg = OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); + self.session + .verify_with_external_steps_and_output_binding(&self.ccs, steps_public, proof, &ob_cfg) + } + } +} + pub struct Rv32B1Run { session: FoldingSession, proof: ShardProof, @@ -623,6 +798,15 @@ impl Rv32B1Run { self.session.steps_public() } + /// Return the public MCS instances for inclusion in a proof package. + /// + /// These instances must be transmitted alongside the `ShardProof` so that + /// a standalone verifier (via [`Rv32B1Verifier::verify`]) can check the proof + /// without re-executing the RISC-V program. + pub fn mcss_public(&self) -> Vec> { + self.session.mcss_public() + } + pub fn final_boundary_state(&self) -> Result { let steps_public = self.steps_public(); let last = steps_public diff --git a/crates/neo-fold/src/serde_helpers.rs b/crates/neo-fold/src/serde_helpers.rs new file mode 100644 index 00000000..4ad22f16 --- /dev/null +++ b/crates/neo-fold/src/serde_helpers.rs @@ -0,0 +1,61 @@ +//! Serde helpers for proof types that contain non-trivially-serializable fields. +//! +//! In particular, `BatchedTimeProof` stores `Vec<&'static [u8]>` labels which +//! require custom serialization. We serialize them as `Vec>` and leak +//! the deserialized allocations so the `&'static` lifetime is valid. + +use serde::de::{Deserializer, SeqAccess, Visitor}; +use serde::ser::{SerializeSeq, Serializer}; +use std::fmt; + +/// Serialize `Vec<&'static [u8]>` as `Vec>`. 
+pub fn serialize_static_byte_slices( + labels: &[&'static [u8]], + serializer: S, +) -> Result +where + S: Serializer, +{ + let mut seq = serializer.serialize_seq(Some(labels.len()))?; + for label in labels { + seq.serialize_element(&label.to_vec())?; + } + seq.end() +} + +/// Deserialize `Vec>` into `Vec<&'static [u8]>` by leaking allocations. +/// +/// This is safe for proof deserialization where the proof lives for the duration +/// of the verification process. The leaked memory is small (label strings are +/// typically short domain-separation tags). +pub fn deserialize_static_byte_slices<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct StaticByteSlicesVisitor; + + impl<'de> Visitor<'de> for StaticByteSlicesVisitor { + type Value = Vec<&'static [u8]>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of byte arrays") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0)); + while let Some(bytes) = seq.next_element::>()? { + // Leak the allocation to get a &'static [u8]. + let leaked: &'static [u8] = Box::leak(bytes.into_boxed_slice()); + result.push(leaked); + } + Ok(result) + } + } + + deserializer.deserialize_seq(StaticByteSlicesVisitor) +} diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 83c86ae8..ae723da0 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -1680,6 +1680,309 @@ where Ok(true) } + /// Verify a proof using externally supplied step instance bundles. + /// + /// Unlike [`verify`] and [`verify_collected`] which use the session's internal + /// collected steps (populated by execution), this method takes the step bundles + /// directly. This enables **verify-only** flows where the verifier never + /// executes the program -- the steps come from the proof package. 
+ pub fn verify_with_external_steps( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_transcript(&mut tr, s, steps_public, run) + } + + /// Like [`verify_with_external_steps`] but with output binding. + pub fn verify_with_external_steps_and_output_binding( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_and_output_binding_transcript(&mut tr, s, steps_public, run, ob_cfg) + } + + /// Internal: verify with external steps + transcript. + fn verify_with_external_steps_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_step_linking_with_context( + self.mode.clone(), + tr, + 
&self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + }, + None => { + let mut msg = + "multi-step verification requires step linking; call FoldingSession::set_step_linking(...)" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + } + }; + + // Detect twist/shout from the provided steps (self.steps is empty for verify-only). + let has_twist_or_shout = steps_public.iter().any(|s| !s.mem_insts.is_empty()) + || steps_public.iter().any(|s| !s.lut_insts.is_empty()); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + + /// Internal: verify with external steps + output binding + transcript. 
+ fn verify_with_external_steps_and_output_binding_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_and_step_linking_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding_and_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + 
self.mixers, + ob_cfg, + )?, + }, + None => { + let mut msg = + "multi-step verification with output binding requires step linking" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + )?, + } + }; + + let has_twist_or_shout = !steps_public.is_empty() + && (steps_public.iter().any(|s| !s.lut_insts.is_empty()) + || steps_public.iter().any(|s| !s.mem_insts.is_empty())); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + /// Verify with output binding, managing the transcript internally. pub fn verify_with_output_binding_simple( &self, diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index 51534325..b19949d2 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -19,7 +19,7 @@ pub type ShoutProofK = neo_memory::shout::ShoutProof; /// Within each group, when a Shout lane is provably inactive for a step (no lookups), we can /// skip its address-domain sumcheck entirely. We still bind all `claimed_sums` to the transcript, /// but we include sumcheck rounds only for the active subset. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreProof { /// Claimed sums per Shout lane. 
/// @@ -32,7 +32,7 @@ pub struct ShoutAddrPreProof { pub groups: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreGroupProof { /// Address-bit width (sumcheck round count) for this group. pub ell_addr: u32, @@ -58,7 +58,7 @@ impl Default for ShoutAddrPreProof { } /// One fold step’s artifacts (Π_CCS → Π_RLC → Π_DEC). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct FoldStep { /// Π_CCS outputs (k ME(b,L) instances) pub ccs_out: Vec>, @@ -125,13 +125,17 @@ pub struct ShardFoldWitnesses { pub val_lane_wits: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub enum MemOrLutProof { Twist(TwistProofK), Shout(ShoutProofK), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound( + serialize = "C: serde::Serialize, FF: serde::Serialize, KK: serde::Serialize", + deserialize = "C: serde::de::DeserializeOwned, FF: serde::de::DeserializeOwned, KK: serde::de::DeserializeOwned + Default" +))] pub struct MemSidecarProof { /// CPU ME claims evaluated at `r_val` (Twist val-eval terminal point). /// @@ -147,20 +151,25 @@ pub struct MemSidecarProof { /// /// This batches CCS (row/time rounds) with Twist/Shout time-domain oracles so all /// protocols share the same transcript-derived `r` (enabling Π_RLC folding). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct BatchedTimeProof { /// Claimed sums per participating oracle (in the same order as `round_polys`). pub claimed_sums: Vec, /// Degree bounds per participating oracle. pub degree_bounds: Vec, /// Domain-separation labels per participating oracle. + /// Serialized as `Vec>`; the `&'static` lifetime is restored via `labels_static()`. 
+ #[serde( + serialize_with = "crate::serde_helpers::serialize_static_byte_slices", + deserialize_with = "crate::serde_helpers::deserialize_static_byte_slices" + )] pub labels: Vec<&'static [u8]>, /// Per-claim sum-check messages: `round_polys[claim][round] = coeffs`. pub round_polys: Vec>>, } /// Proof data for a standalone Π_RLC → Π_DEC lane (no Π_CCS). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct RlcDecProof { /// RLC mixing matrices ρ_i ∈ S ⊆ F^{D×D} pub rlc_rhos: Vec>, @@ -170,7 +179,7 @@ pub struct RlcDecProof { pub dec_children: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct StepProof { pub fold: FoldStep, pub mem: MemSidecarProof, @@ -179,7 +188,7 @@ pub struct StepProof { pub val_fold: Option, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShardProof { pub steps: Vec, /// Optional output binding proof (proves final memory matches claimed outputs). 
diff --git a/crates/neo-reductions/Cargo.toml b/crates/neo-reductions/Cargo.toml index 4f2636e7..2c7691a0 100644 --- a/crates/neo-reductions/Cargo.toml +++ b/crates/neo-reductions/Cargo.toml @@ -21,7 +21,7 @@ fs-guard = ["neo-transcript/fs-guard"] # Enable verbose logging for debugging (no-op in this crate; used for cfg gating) neo-logs = [] # Enable constraint diagnostic system -prove-diagnostics = ["serde", "serde_json", "flate2", "hex", "chrono"] +prove-diagnostics = ["serde_json", "flate2", "hex", "chrono"] # Enable CBOR format for diagnostics (smaller but less readable) prove-diagnostics-cbor = ["prove-diagnostics", "serde_cbor"] wasm-threads = ["neo-ajtai/wasm-threads", "neo-ccs/wasm-threads"] @@ -46,8 +46,8 @@ thiserror = { workspace = true } bincode = { workspace = true } blake3 = "1.5" -# Diagnostic system dependencies (feature-gated) -serde = { workspace = true, optional = true } +# Serde for proof serialization (always available; also used by diagnostics) +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, optional = true } flate2 = { version = "1.0", optional = true } hex = { version = "0.4", optional = true } diff --git a/crates/neo-reductions/src/engines/optimized_engine/common.rs b/crates/neo-reductions/src/engines/optimized_engine/common.rs index 19c81021..a14ea5bb 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/common.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/common.rs @@ -17,7 +17,7 @@ use p3_field::{Field, PrimeCharacteristicRing}; use rayon::prelude::*; /// Challenges sampled in Step 1 of the protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Challenges { /// α ∈ K^{log d} - for Ajtai dimension pub alpha: Vec, diff --git a/crates/neo-reductions/src/engines/optimized_engine/mod.rs b/crates/neo-reductions/src/engines/optimized_engine/mod.rs index 4cf317b3..b16c56f7 100644 --- 
a/crates/neo-reductions/src/engines/optimized_engine/mod.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/mod.rs @@ -20,7 +20,7 @@ pub mod verify; pub use common::Challenges; /// Proof format variant for Π_CCS. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum PiCcsProofVariant { /// Split-NC proof with two sumchecks: FE-only + NC-only. SplitNcV1, @@ -55,7 +55,7 @@ pub use common::{ }; /// Proof structure for the Π_CCS protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PiCcsProof { /// Proof format variant. pub variant: PiCcsProofVariant, From c8096d0328589c710539bee299404e01644b10e8 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 10:43:04 -0300 Subject: [PATCH 02/13] Enforce chunk0 memory integrity --- crates/neo-fold/src/lib.rs | 3 - .../src/memory_sidecar/route_a_time.rs | 4 +- crates/neo-fold/src/riscv_shard.rs | 13 +++- crates/neo-fold/src/serde_helpers.rs | 61 ------------------ crates/neo-fold/src/shard.rs | 8 ++- crates/neo-fold/src/shard_proof_types.rs | 7 +- .../riscv_build_verifier_statement_memory.rs | 64 +++++++++++++++++++ .../neo-fold/tests/twist_shout_soundness.rs | 2 +- 8 files changed, 87 insertions(+), 75 deletions(-) delete mode 100644 crates/neo-fold/src/serde_helpers.rs create mode 100644 crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs diff --git a/crates/neo-fold/src/lib.rs b/crates/neo-fold/src/lib.rs index aee318a1..dae0b088 100644 --- a/crates/neo-fold/src/lib.rs +++ b/crates/neo-fold/src/lib.rs @@ -33,6 +33,3 @@ pub mod riscv_shard; pub mod output_binding; mod shard_proof_types; - -/// Serde helpers for proof serialization (e.g. `&'static [u8]` labels). 
-pub mod serde_helpers; diff --git a/crates/neo-fold/src/memory_sidecar/route_a_time.rs b/crates/neo-fold/src/memory_sidecar/route_a_time.rs index ee39a129..c0e49ff5 100644 --- a/crates/neo-fold/src/memory_sidecar/route_a_time.rs +++ b/crates/neo-fold/src/memory_sidecar/route_a_time.rs @@ -137,7 +137,7 @@ pub fn prove_route_a_batched_time( let proof = BatchedTimeProof { claimed_sums: claimed_sums.clone(), degree_bounds: degree_bounds.clone(), - labels: labels.clone(), + labels: labels.iter().map(|label| label.to_vec()).collect(), round_polys: per_claim_results .iter() .map(|r| r.round_polys.clone()) @@ -223,7 +223,7 @@ pub fn verify_route_a_batched_time( ))); } for (i, (got, exp)) in proof.labels.iter().zip(expected_labels.iter()).enumerate() { - if (*got as &[u8]) != *exp { + if got.as_slice() != *exp { return Err(PiCcsError::ProtocolError(format!( "step {}: batched_time label mismatch at claim {}", step_idx, i diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 47c4160b..40e0029d 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -509,6 +509,8 @@ impl Rv32B1 { session, ccs, _layout: layout, + mem_layouts, + statement_initial_mem: initial_mem, ram_num_bits: d_ram, output_claims: self.output_claims, }) @@ -718,6 +720,8 @@ pub struct Rv32B1Verifier { session: FoldingSession, ccs: CcsStructure, _layout: Rv32B1Layout, + mem_layouts: HashMap, + statement_initial_mem: HashMap<(u32, u64), F>, ram_num_bits: usize, output_claims: ProgramIO, } @@ -734,8 +738,15 @@ impl Rv32B1Verifier { proof: &ShardProof, steps_public: &[StepInstanceBundle], ) -> Result { + rv32_b1_enforce_chunk0_mem_init_matches_statement( + &self.mem_layouts, + &self.statement_initial_mem, + steps_public, + )?; + if self.output_claims.is_empty() { - self.session.verify_with_external_steps(&self.ccs, steps_public, proof) + self.session + .verify_with_external_steps(&self.ccs, steps_public, proof) } else { let ob_cfg = 
OutputBindingConfig::new(self.ram_num_bits, self.output_claims.clone()); self.session diff --git a/crates/neo-fold/src/serde_helpers.rs b/crates/neo-fold/src/serde_helpers.rs deleted file mode 100644 index 4ad22f16..00000000 --- a/crates/neo-fold/src/serde_helpers.rs +++ /dev/null @@ -1,61 +0,0 @@ -//! Serde helpers for proof types that contain non-trivially-serializable fields. -//! -//! In particular, `BatchedTimeProof` stores `Vec<&'static [u8]>` labels which -//! require custom serialization. We serialize them as `Vec>` and leak -//! the deserialized allocations so the `&'static` lifetime is valid. - -use serde::de::{Deserializer, SeqAccess, Visitor}; -use serde::ser::{SerializeSeq, Serializer}; -use std::fmt; - -/// Serialize `Vec<&'static [u8]>` as `Vec>`. -pub fn serialize_static_byte_slices( - labels: &[&'static [u8]], - serializer: S, -) -> Result -where - S: Serializer, -{ - let mut seq = serializer.serialize_seq(Some(labels.len()))?; - for label in labels { - seq.serialize_element(&label.to_vec())?; - } - seq.end() -} - -/// Deserialize `Vec>` into `Vec<&'static [u8]>` by leaking allocations. -/// -/// This is safe for proof deserialization where the proof lives for the duration -/// of the verification process. The leaked memory is small (label strings are -/// typically short domain-separation tags). -pub fn deserialize_static_byte_slices<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - struct StaticByteSlicesVisitor; - - impl<'de> Visitor<'de> for StaticByteSlicesVisitor { - type Value = Vec<&'static [u8]>; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a sequence of byte arrays") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: SeqAccess<'de>, - { - let mut result = Vec::with_capacity(seq.size_hint().unwrap_or(0)); - while let Some(bytes) = seq.next_element::>()? { - // Leak the allocation to get a &'static [u8]. 
- let leaked: &'static [u8] = Box::leak(bytes.into_boxed_slice()); - result.push(leaked); - } - Ok(result) - } - } - - deserializer.deserialize_seq(StaticByteSlicesVisitor) -} diff --git a/crates/neo-fold/src/shard.rs b/crates/neo-fold/src/shard.rs index 035bd1bb..b1100389 100644 --- a/crates/neo-fold/src/shard.rs +++ b/crates/neo-fold/src/shard.rs @@ -3527,7 +3527,13 @@ where .len() .checked_sub(1) .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claim".into()))?; - if step_proof.batched_time.labels.get(inc_idx).copied() != Some(crate::output_binding::OB_INC_TOTAL_LABEL) { + if step_proof + .batched_time + .labels + .get(inc_idx) + .map(|label| label.as_slice()) + != Some(crate::output_binding::OB_INC_TOTAL_LABEL) + { return Err(PiCcsError::ProtocolError("output binding claim not last".into())); } diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index b19949d2..c11a9490 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -158,12 +158,7 @@ pub struct BatchedTimeProof { /// Degree bounds per participating oracle. pub degree_bounds: Vec, /// Domain-separation labels per participating oracle. - /// Serialized as `Vec>`; the `&'static` lifetime is restored via `labels_static()`. - #[serde( - serialize_with = "crate::serde_helpers::serialize_static_byte_slices", - deserialize_with = "crate::serde_helpers::deserialize_static_byte_slices" - )] - pub labels: Vec<&'static [u8]>, + pub labels: Vec>, /// Per-claim sum-check messages: `round_polys[claim][round] = coeffs`. 
pub round_polys: Vec>>, } diff --git a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs new file mode 100644 index 00000000..fa5189c4 --- /dev/null +++ b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs @@ -0,0 +1,64 @@ +use neo_fold::riscv_shard::Rv32B1; +use neo_fold::PiCcsError; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; + +fn program_bytes_with_seed(seed: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: seed, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +#[test] +fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { + let program_a = program_bytes_with_seed(7); + let program_b = program_bytes_with_seed(9); + + let mut run_a = Rv32B1::from_rom(/*program_base=*/ 0, &program_a) + .chunk_size(1) + .prove() + .expect("prove statement A"); + run_a.verify().expect("self-verify statement A"); + + let proof_a = run_a.proof().clone(); + let steps_public_a = run_a.steps_public(); + + let verifier_a = Rv32B1::from_rom(/*program_base=*/ 0, &program_a) + .chunk_size(1) + .build_verifier() + .expect("build verifier for statement A"); + let ok = verifier_a + .verify(&proof_a, &steps_public_a) + .expect("matching statement should verify"); + assert!(ok, "matching statement should verify"); + + let verifier_b = Rv32B1::from_rom(/*program_base=*/ 0, &program_b) + .chunk_size(1) + .build_verifier() + .expect("build verifier for statement B"); + let err = verifier_b + .verify(&proof_a, &steps_public_a) + .expect_err("different statement memory must be rejected"); + + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("chunk0 MemInstance.init mismatch"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error 
kind: {other:?}"), + } +} diff --git a/crates/neo-fold/tests/twist_shout_soundness.rs b/crates/neo-fold/tests/twist_shout_soundness.rs index b7f35988..58e2f8b8 100644 --- a/crates/neo-fold/tests/twist_shout_soundness.rs +++ b/crates/neo-fold/tests/twist_shout_soundness.rs @@ -338,7 +338,7 @@ fn tamper_batched_time_label_fails() { ) .expect("prove should succeed"); - proof.steps[0].batched_time.labels[0] = b"bad/label".as_slice(); + proof.steps[0].batched_time.labels[0] = b"bad/label".to_vec(); let mut tr_verify = Poseidon2Transcript::new(b"soundness/tamper-batched-time-label"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; From 1aa6fe14dfe12eefe99f58b39f11bb8be76bb333 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 11:42:06 -0300 Subject: [PATCH 03/13] Validate Commitment invariants --- crates/neo-ajtai/src/types.rs | 48 ++++++++++++++++++- .../commitment_deserialize_invariants.rs | 41 ++++++++++++++++ 2 files changed, 87 insertions(+), 2 deletions(-) create mode 100644 crates/neo-ajtai/tests/commitment_deserialize_invariants.rs diff --git a/crates/neo-ajtai/src/types.rs b/crates/neo-ajtai/src/types.rs index 8fbaa37c..07527f3e 100644 --- a/crates/neo-ajtai/src/types.rs +++ b/crates/neo-ajtai/src/types.rs @@ -1,6 +1,7 @@ use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as Fq; -use serde::{Deserialize, Serialize}; +use serde::de::Error as _; +use serde::{Deserialize, Deserializer, Serialize}; /// Public parameters for Ajtai: M ∈ R_q^{κ×m}, stored row-major. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -13,7 +14,7 @@ pub struct PP { } /// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d). 
-#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)] pub struct Commitment { pub d: usize, pub kappa: usize, @@ -22,6 +23,26 @@ pub struct Commitment { } impl Commitment { + #[inline] + fn validate_shape(d: usize, kappa: usize, data_len: usize) -> Result<(), String> { + let expected_d = neo_math::ring::D; + if d != expected_d { + return Err(format!("invalid Commitment.d: expected {expected_d}, got {d}")); + } + + let expected_len = d + .checked_mul(kappa) + .ok_or_else(|| format!("invalid Commitment shape: d*kappa overflow (d={d}, kappa={kappa})"))?; + if data_len != expected_len { + return Err(format!( + "invalid Commitment shape: data.len()={} but d*kappa={expected_len}", + data_len + )); + } + + Ok(()) + } + pub fn zeros(d: usize, kappa: usize) -> Self { Self { d, @@ -48,3 +69,26 @@ impl Commitment { } } } + +impl<'de> Deserialize<'de> for Commitment { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct CommitmentWire { + d: usize, + kappa: usize, + data: Vec, + } + + let wire = CommitmentWire::deserialize(deserializer)?; + Commitment::validate_shape(wire.d, wire.kappa, wire.data.len()).map_err(D::Error::custom)?; + + Ok(Self { + d: wire.d, + kappa: wire.kappa, + data: wire.data, + }) + } +} diff --git a/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs new file mode 100644 index 00000000..9b743ec2 --- /dev/null +++ b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs @@ -0,0 +1,41 @@ +use neo_ajtai::Commitment; +use neo_math::ring::D; +use serde_json::Value; + +fn valid_commitment_json(kappa: usize) -> Value { + let c = Commitment::zeros(D, kappa); + serde_json::to_value(&c).expect("serialize valid commitment") +} + +#[test] +fn commitment_deserialize_rejects_wrong_d() { + let mut value = valid_commitment_json(2); + value["d"] = 
serde_json::json!(D - 1); + + let err = serde_json::from_value::(value).expect_err("wrong d must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("invalid Commitment.d"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_rejects_data_len_mismatch() { + let mut value = valid_commitment_json(2); + let data = value + .get_mut("data") + .and_then(Value::as_array_mut) + .expect("commitment.data array"); + data.pop().expect("non-empty data"); + + let err = serde_json::from_value::(value).expect_err("shape mismatch must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("data.len()"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_accepts_valid_shape() { + let value = valid_commitment_json(3); + let c: Commitment = serde_json::from_value(value).expect("valid shape should deserialize"); + assert_eq!(c.d, D); + assert_eq!(c.kappa, 3); + assert_eq!(c.data.len(), D * 3); +} From a6b22554a69106a1119f3d326a42bd9c234c741f Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Fri, 13 Feb 2026 11:59:27 -0300 Subject: [PATCH 04/13] Enforce chunk0 memory match --- crates/neo-fold/src/session.rs | 22 +++++- .../riscv_build_verifier_statement_memory.rs | 73 +++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index ae723da0..0dece63c 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -1723,7 +1723,16 @@ where .filter(|cache| cache.src_ptr == src_ptr); let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); - let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } let s_prepared = 
self.prepared_ccs_for_accumulator(s)?; let seed_me: &[MeInstance] = match &self.acc0 { @@ -1858,7 +1867,16 @@ where .filter(|cache| cache.src_ptr == src_ptr); let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); - let m_in_steps = steps_public.first().map(|inst| inst.mcs_inst.m_in).unwrap_or(0); + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } let s_prepared = self.prepared_ccs_for_accumulator(s)?; let seed_me: &[MeInstance] = match &self.acc0 { diff --git a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs index fa5189c4..78605df7 100644 --- a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs +++ b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs @@ -1,6 +1,8 @@ use neo_fold::riscv_shard::Rv32B1; use neo_fold::PiCcsError; +use neo_math::F; use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; fn program_bytes_with_seed(seed: i32) -> Vec { let program = vec![ @@ -21,6 +23,18 @@ fn program_bytes_with_seed(seed: i32) -> Vec { encode_program(&program) } +fn assert_m_in_mismatch_rejected(err: PiCcsError) { + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("all steps must share the same m_in"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error kind: {other:?}"), + } +} + #[test] fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { let program_a = program_bytes_with_seed(7); @@ -62,3 +76,62 @@ fn rv32_b1_build_verifier_binds_statement_memory_to_chunk0_mem_init() { other => panic!("unexpected error kind: {other:?}"), } } + +#[test] +fn rv32_b1_build_verifier_rejects_external_steps_with_non_uniform_m_in() 
{ + let program = program_bytes_with_seed(11); + + let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .prove() + .expect("prove statement"); + run.verify().expect("self-verify statement"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .build_verifier() + .expect("build verifier"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} + +#[test] +fn rv32_b1_build_verifier_output_binding_rejects_external_steps_with_non_uniform_m_in() { + let program = program_bytes_with_seed(13); + + let mut run = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .prove() + .expect("prove statement with output binding"); + run.verify() + .expect("self-verify statement with output binding"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32B1::from_rom(/*program_base=*/ 0, &program) + .chunk_size(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .build_verifier() + .expect("build verifier with output binding"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} From be7944c8ffb969a0649012cf5835696790238f3e Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 09:49:30 -0300 Subject: [PATCH 05/13] Add Poseidon2-Goldilocks hash support via ECALL precompile - Introduced Poseidon2 compute and read ECALLs for hashing 
Goldilocks field elements. - Implemented guest-side API for computing Poseidon2 hashes, allowing for efficient input handling and digest retrieval. - Enhanced RISC-V CPU to manage Poseidon2 ECALLs, including internal state for pending digest words. - Updated CCS layout and constraints to accommodate new ECALLs and ensure proper integration with existing functionality. --- crates/neo-fold/src/riscv_shard.rs | 205 +++++++++++++++++- .../riscv_poseidon2_ecall_prove_verify.rs | 93 ++++++++ crates/neo-memory/src/riscv/ccs.rs | 120 +++++++++- crates/neo-memory/src/riscv/ccs/layout.rs | 41 ++++ crates/neo-memory/src/riscv/ccs/witness.rs | 40 +++- crates/neo-memory/src/riscv/lookups/cpu.rs | 80 ++++++- crates/neo-memory/src/riscv/lookups/mod.rs | 16 ++ crates/neo-memory/tests/riscv_ccs_tests.rs | 94 +++++++- crates/neo-vm-trace/src/lib.rs | 14 ++ crates/nightstream-sdk/src/goldilocks.rs | 116 ++++++++++ crates/nightstream-sdk/src/lib.rs | 2 + crates/nightstream-sdk/src/poseidon2.rs | 111 ++++++++++ 12 files changed, 915 insertions(+), 17 deletions(-) create mode 100644 crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs create mode 100644 crates/nightstream-sdk/src/goldilocks.rs create mode 100644 crates/nightstream-sdk/src/poseidon2.rs diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 40e0029d..a5d1c932 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -7,11 +7,13 @@ use std::collections::HashMap; use std::collections::HashSet; +use std::sync::Arc; use std::time::Duration; use crate::output_binding::{simple_output_config, OutputBindingConfig}; use crate::pi_ccs::FoldingMode; use crate::session::FoldingSession; +use neo_reductions::engines::optimized_engine::oracle::SparseCache; use crate::shard::{ fold_shard_verify_with_output_binding_and_step_linking, fold_shard_verify_with_step_linking, CommitMixers, ShardFoldOutputs, ShardProof, StepLinkingConfig, @@ -73,6 +75,21 @@ fn 
elapsed_duration(start: TimePoint) -> Duration { } } +/// Per-phase timing breakdown for `Rv32B1::prove()`. +/// +/// Captures wall-clock time for each major phase so callers can identify bottlenecks. +#[derive(Clone, Debug, Default)] +pub struct ProveTimings { + /// Program decode, memory layout setup, and Shout table inference. + pub decode_and_setup: Duration, + /// CCS construction (base + shared-bus wiring) and session/committer creation. + pub ccs_and_shared_bus: Duration, + /// RISC-V VM execution and shard/witness collection. + pub vm_execution: Duration, + /// Fold-and-prove (sumcheck + Ajtai commitment). + pub fold_and_prove: Duration, +} + pub fn rv32_b1_step_linking_config(layout: &Rv32B1Layout) -> StepLinkingConfig { StepLinkingConfig::new(rv32_b1_step_linking_pairs(layout)) } @@ -267,6 +284,30 @@ fn all_shout_opcodes() -> HashSet { /// High-level “few lines” builder for proving/verifying an RV32 program using the B1 shared-bus step circuit. /// +/// Pre-computed SparseCache + matrix digest for RV32 B1 circuit preprocessing. +/// +/// The `SparseCache::build()` and matrix-digest computation are the most expensive parts +/// of both proving and verification. This struct captures those artefacts so they can be +/// computed once and reused across many `Rv32B1::prove()` / `build_verifier()` calls with +/// the same ROM, `ram_bytes`, and `chunk_size`. +/// +/// Build with [`Rv32B1::build_ccs_cache`], then inject into subsequent builders via +/// [`Rv32B1::with_ccs_cache`]. +/// +/// **Important**: the CCS structure itself is still built from scratch each time (it is fast). +/// Only the SparseCache (CSC decomposition + matrix digest) is cached. 
+pub struct Rv32B1CcsCache { + pub sparse: Arc>, +} + +impl std::fmt::Debug for Rv32B1CcsCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Rv32B1CcsCache") + .field("sparse_len", &self.sparse.len()) + .finish() + } +} + /// This: /// - chooses parameters + Ajtai committer automatically, /// - infers the minimal Shout table set from the program (unless overridden), @@ -285,6 +326,7 @@ pub struct Rv32B1 { shout_ops: Option>, output_claims: ProgramIO, ram_init: HashMap, + ccs_cache: Option>, } /// Default instruction cap for RV32B1 runs when `max_steps` is not specified. @@ -308,6 +350,7 @@ impl Rv32B1 { shout_ops: None, output_claims: ProgramIO::new(), ram_init: HashMap::new(), + ccs_cache: None, } } @@ -373,6 +416,114 @@ impl Rv32B1 { self } + /// Attach a pre-built CCS cache so [`prove`] reuses its `SparseCache`; the CCS itself is still rebuilt. + /// + /// The cache **must** have been built from a builder with the same `program_bytes`, + /// `ram_bytes`, `chunk_size`, and Shout configuration. No runtime validation is performed; + /// mismatched caches produce undefined behaviour. + pub fn with_ccs_cache(mut self, cache: Arc) -> Self { + self.ccs_cache = Some(cache); + self + } + + /// Build the CCS preprocessing cache from the current builder configuration. + /// + /// This performs program decoding, memory-layout setup, Shout-table inference, + /// CCS construction, and `SparseCache` synthesis -- exactly the same work that + /// [`prove`] and [`build_verifier`] would do on first call -- and packages the + /// result so it can be shared across many runs.
+ /// + /// ```ignore + /// let cache = Rv32B1::from_rom(0, &rom).ram_bytes(0x40000).chunk_size(1024) + /// .shout_auto_minimal().build_ccs_cache()?; + /// let cache = Arc::new(cache); + /// + /// // Subsequent prove calls reuse the cached SparseCache (the CCS itself is still rebuilt): + /// let run = Rv32B1::from_rom(0, &rom).ram_bytes(0x40000).chunk_size(1024) + /// .shout_auto_minimal().with_ccs_cache(cache.clone()) + /// .ram_init_u32(0x104, 42).prove()?; + /// ``` + pub fn build_ccs_cache(&self) -> Result { + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words::( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); + let _ = d_ram; + let mem_layouts = HashMap::from([ + ( + neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: pow2_ceil_k(self.ram_bytes.max(4)).1, + n_side: 2, + lanes: 1, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + let shout = RiscvShoutTables::new(self.xlen); + let mut shout_ops = match &self.shout_ops { + Some(ops) => ops.clone(), + None if self.shout_auto_minimal => infer_required_shout_opcodes(&program), + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let mut table_specs: HashMap = HashMap::new(); + for op in &shout_ops { + let table_id = shout.opcode_to_id(*op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: *op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + + // Apply
shared-bus wiring (adds Twist/Shout constraint matrices to the CCS). + // This uses ROM-only initial_mem since the CCS structure doesn't depend on + // witness-specific ram_init values. + let session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + let empty_tables: HashMap> = HashMap::new(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + let sparse = Arc::new(SparseCache::build(&cpu.ccs)); + + Ok(Rv32B1CcsCache { sparse }) + } + /// Build only the verification context (CCS + session) without executing the program or proving. /// /// This performs the same validation, program decoding, memory layout setup, CCS construction @@ -505,6 +656,8 @@ impl Rv32B1 { let ccs = cpu.ccs.clone(); // No execution, no fold_and_prove -- just the verification context. + // NOTE: verifier SparseCache preloading is deferred to `preload_sparse_cache()` + // because the CCS pointer changes when moved into the Rv32B1Verifier struct. 
Ok(Rv32B1Verifier { session, ccs, @@ -551,6 +704,9 @@ impl Rv32B1 { } } + // === Phase 1: Decode + setup === + let phase_start = time_now(); + let program = decode_program(&self.program_bytes) .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; let using_default_max_steps = self.max_steps.is_none(); @@ -621,6 +777,13 @@ impl Rv32B1 { let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); shout_table_ids.sort_unstable(); + let t_decode_and_setup = elapsed_duration(phase_start); + + // === Phase 2: CCS construction + shared-bus wiring === + let phase_start = time_now(); + + let prebuilt_sparse = self.ccs_cache.as_ref().map(|c| c.sparse.clone()); + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; @@ -656,6 +819,11 @@ impl Rv32B1 { // Always enforce step-to-step chunk chaining for RV32 B1. session.set_step_linking(rv32_b1_step_linking_config(&layout)); + let t_ccs_and_shared_bus = elapsed_duration(phase_start); + + // === Phase 3: VM execution + shard collection === + let phase_start = time_now(); + // Execute + collect step bundles (and aux for output binding). session.execute_shard_shared_cpu_bus( vm, @@ -687,15 +855,25 @@ impl Rv32B1 { let ccs = cpu.ccs.clone(); - // Prove phase (timed) - let prove_start = time_now(); + // Preload the SparseCache using the *final* CCS reference that fold_and_prove will use. + // (Must happen after cpu.ccs.clone() because the pointer-keyed cache checks identity.) + if let Some(ref sparse) = prebuilt_sparse { + session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + } + + let t_vm_execution = elapsed_duration(phase_start); + + // === Phase 4: Fold-and-prove === + let phase_start = time_now(); let proof = if self.output_claims.is_empty() { session.fold_and_prove(&ccs)? 
} else { let ob_cfg = OutputBindingConfig::new(d_ram, self.output_claims.clone()); session.fold_and_prove_with_output_binding_auto_simple(&ccs, &ob_cfg)? }; - let prove_duration = elapsed_duration(prove_start); + let t_fold_and_prove = elapsed_duration(phase_start); + + let prove_duration = t_decode_and_setup + t_ccs_and_shared_bus + t_vm_execution + t_fold_and_prove; Ok(Rv32B1Run { session, @@ -707,6 +885,12 @@ impl Rv32B1 { ram_num_bits: d_ram, output_claims: self.output_claims, prove_duration, + prove_timings: ProveTimings { + decode_and_setup: t_decode_and_setup, + ccs_and_shared_bus: t_ccs_and_shared_bus, + vm_execution: t_vm_execution, + fold_and_prove: t_fold_and_prove, + }, verify_duration: None, }) } @@ -727,6 +911,15 @@ pub struct Rv32B1Verifier { } impl Rv32B1Verifier { + /// Preload the verifier SparseCache using `&self.ccs` (final pointer). + /// + /// Call this once after `build_verifier()` returns, before the first `verify()`. + /// This ensures the pointer-keyed cache hits inside the session's verify path. + pub fn preload_sparse_cache(&mut self, sparse: Arc>) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache(&self.ccs, sparse) + } + /// Verify a `ShardProof` using the provided public step instance bundles. /// /// `steps_public` must be the step instance bundles produced by the prover (via @@ -765,6 +958,7 @@ pub struct Rv32B1Run { ram_num_bits: usize, output_claims: ProgramIO, prove_duration: Duration, + prove_timings: ProveTimings, verify_duration: Option, } @@ -930,6 +1124,11 @@ impl Rv32B1Run { self.prove_duration } + /// Per-phase timing breakdown for the prove call. 
+ pub fn prove_timings(&self) -> &ProveTimings { + &self.prove_timings + } + pub fn verify_duration(&self) -> Option { self.verify_duration } diff --git a/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs new file mode 100644 index 00000000..4256241a --- /dev/null +++ b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs @@ -0,0 +1,93 @@ +use neo_fold::riscv_shard::Rv32B1; +use neo_memory::riscv::lookups::{ + encode_program, RiscvInstruction, RiscvOpcode, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; + +fn load_u32_imm(rd: u8, value: u32) -> Vec { + let upper = ((value as i64 + 0x800) >> 12) as i32; + let lower = (value as i32) - (upper << 12); + vec![ + RiscvInstruction::Lui { rd, imm: upper }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd, + rs1: rd, + imm: lower, + }, + ] +} + +fn poseidon2_ecall_program() -> Vec { + let mut program = Vec::new(); + + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + + // a0 = POSEIDON2_ECALL_NUM -> compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. + for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + + // Clear a0 -> final halt. 
+ program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + program +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk1() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(1) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=1"); +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk4() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(4) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=4"); +} + +#[test] +fn rv32_b1_prove_verify_poseidon2_ecall_chunk32() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32B1::from_rom(0, &program_bytes) + .chunk_size(32) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=32"); +} diff --git a/crates/neo-memory/src/riscv/ccs.rs b/crates/neo-memory/src/riscv/ccs.rs index 81e8a52e..2b8d860a 100644 --- a/crates/neo-memory/src/riscv/ccs.rs +++ b/crates/neo-memory/src/riscv/ccs.rs @@ -42,7 +42,7 @@ use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as F; use crate::plain::PlainMemLayout; -use crate::riscv::lookups::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID}; +use crate::riscv::lookups::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, PROG_ID, RAM_ID}; mod bus_bindings; mod config; @@ -1036,10 +1036,14 @@ fn semantic_constraints( )); } - // ECALL helpers (Jolt 
marker/print IDs). + // ECALL helpers (Jolt marker/print IDs + Poseidon2 precompile). let a0 = layout.reg_in(10, j); let ecall_is_cycle = layout.ecall_is_cycle(j); let ecall_is_print = layout.ecall_is_print(j); + let ecall_is_poseidon2 = layout.ecall_is_poseidon2(j); + let ecall_is_poseidon2_read = layout.ecall_is_poseidon2_read(j); + let halt_ecall_poseidon2_read = layout.halt_ecall_poseidon2_read(j); + let ecall_rd_val = layout.ecall_rd_val(j); let ecall_halts = layout.ecall_halts(j); let halt_effective = layout.halt_effective(j); @@ -1166,6 +1170,86 @@ fn semantic_constraints( } constraints.push(Constraint::mul(ecall_is_print, ecall_is_print, ecall_is_print)); + // ecall_is_poseidon2 = (a0 == POSEIDON2_ECALL_NUM). + let poseidon2_const = POSEIDON2_ECALL_NUM as u32; + { + let bit0 = layout.ecall_a0_bit(0, j); + let prefix0 = layout.ecall_poseidon2_prefix(0, j); + if (poseidon2_const & 1) == 1 { + constraints.push(Constraint::terms(one, false, vec![(prefix0, F::ONE), (bit0, -F::ONE)])); + } else { + constraints.push(Constraint::terms( + one, + false, + vec![(prefix0, F::ONE), (bit0, F::ONE), (one, -F::ONE)], + )); + } + } + for k in 1..31 { + let prev = layout.ecall_poseidon2_prefix(k - 1, j); + let next = layout.ecall_poseidon2_prefix(k, j); + let bit_col = layout.ecall_a0_bit(k, j); + let bit_is_one = ((poseidon2_const >> k) & 1) == 1; + let b_terms = if bit_is_one { + vec![(bit_col, F::ONE)] + } else { + vec![(one, F::ONE), (bit_col, -F::ONE)] + }; + constraints.push(Constraint { + condition_col: prev, + negate_condition: false, + additional_condition_cols: Vec::new(), + b_terms, + c_terms: vec![(next, F::ONE)], + }); + } + { + let prev = layout.ecall_poseidon2_prefix(30, j); + let bit_col = layout.ecall_a0_bit(31, j); + let bit_is_one = ((poseidon2_const >> 31) & 1) == 1; + let b_terms = if bit_is_one { + vec![(bit_col, F::ONE)] + } else { + vec![(one, F::ONE), (bit_col, -F::ONE)] + }; + constraints.push(Constraint { + condition_col: prev, + negate_condition: 
false, + additional_condition_cols: Vec::new(), + b_terms, + c_terms: vec![(ecall_is_poseidon2, F::ONE)], + }); + } + constraints.push(Constraint::mul(ecall_is_poseidon2, ecall_is_poseidon2, ecall_is_poseidon2)); + + // ecall_is_poseidon2_read = prefix_30 * bit_31. + // POSEIDON2_READ_ECALL_NUM (0x80504F53) shares bits 0-30 with POSEIDON2_ECALL_NUM + // (0x504F53) and has bit 31 set. The prefix chain already matches bits 0-30, + // so we branch on bit 31: compute = prefix_30 * (1-bit_31), read = prefix_30 * bit_31. + { + let prefix_30 = layout.ecall_poseidon2_prefix(30, j); + let bit_31 = layout.ecall_a0_bit(31, j); + constraints.push(Constraint::mul(prefix_30, bit_31, ecall_is_poseidon2_read)); + } + constraints.push(Constraint::mul(ecall_is_poseidon2_read, ecall_is_poseidon2_read, ecall_is_poseidon2_read)); + + // halt_ecall_poseidon2_read = is_halt * ecall_is_poseidon2_read. + // Gates the raw prefix-derived flag by is_halt so that it only activates + // on actual ECALL instructions, not on arbitrary instructions where a0 + // happens to contain POSEIDON2_READ_ECALL_NUM. + constraints.push(Constraint::mul(layout.is_halt(j), ecall_is_poseidon2_read, halt_ecall_poseidon2_read)); + constraints.push(Constraint::mul(halt_ecall_poseidon2_read, halt_ecall_poseidon2_read, halt_ecall_poseidon2_read)); + + // halt_ecall_poseidon2_read * (ecall_rd_val - rd_write_val) = 0 + // Links ecall_rd_val to rd_write_val for poseidon2_read steps, reusing the + // existing enforce_u32_bits range check on rd_write_val. + constraints.push(Constraint::terms( + halt_ecall_poseidon2_read, + false, + vec![(ecall_rd_val, F::ONE), (layout.rd_write_val(j), -F::ONE)], + )); + + // ecall_halts = 1 - ecall_is_cycle - ecall_is_print - ecall_is_poseidon2 - ecall_is_poseidon2_read. 
constraints.push(Constraint::terms( one, false, @@ -1173,6 +1257,8 @@ fn semantic_constraints( (ecall_halts, F::ONE), (ecall_is_cycle, F::ONE), (ecall_is_print, F::ONE), + (ecall_is_poseidon2, F::ONE), + (ecall_is_poseidon2_read, F::ONE), (one, -F::ONE), ], )); @@ -1692,17 +1778,37 @@ fn semantic_constraints( // Register update pattern for r=1..31: // - if rd_sel[r]=1 then reg_out[r] = rd_write_val // - else reg_out[r] = reg_in[r] + // Special case for r=10 (a0): the Poseidon2 read ECALL writes ecall_rd_val + // to reg_out[10] without going through rd_sel/rd_write_val. for r in 1..32 { constraints.push(Constraint::terms( layout.rd_sel(r, j), false, vec![(layout.reg_out(r, j), F::ONE), (layout.rd_write_val(j), -F::ONE)], )); - constraints.push(Constraint::terms( - layout.rd_sel(r, j), - true, - vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], - )); + if r == 10 { + // (1 - rd_sel[10] - halt_ecall_poseidon2_read) * (reg_out[10] - reg_in[10]) = 0 + // Uses halt-gated flag so this only activates on actual ECALL instructions. + constraints.push(Constraint { + condition_col: layout.rd_sel(r, j), + negate_condition: true, + additional_condition_cols: vec![halt_ecall_poseidon2_read], + b_terms: vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], + c_terms: Vec::new(), + }); + // halt_ecall_poseidon2_read * (reg_out[10] - ecall_rd_val) = 0 + constraints.push(Constraint::terms( + halt_ecall_poseidon2_read, + false, + vec![(layout.reg_out(r, j), F::ONE), (ecall_rd_val, -F::ONE)], + )); + } else { + constraints.push(Constraint::terms( + layout.rd_sel(r, j), + true, + vec![(layout.reg_out(r, j), F::ONE), (layout.reg_in(r, j), -F::ONE)], + )); + } } // RAM effective address is computed via the ADD Shout lookup (mod 2^32 semantics). 
diff --git a/crates/neo-memory/src/riscv/ccs/layout.rs b/crates/neo-memory/src/riscv/ccs/layout.rs index 3d4fca2b..fb113df5 100644 --- a/crates/neo-memory/src/riscv/ccs/layout.rs +++ b/crates/neo-memory/src/riscv/ccs/layout.rs @@ -188,6 +188,11 @@ pub struct Rv32B1Layout { pub ecall_is_cycle: usize, pub ecall_print_prefix_start: usize, // 31 pub ecall_is_print: usize, + pub ecall_poseidon2_prefix_start: usize, // 31 + pub ecall_is_poseidon2: usize, + pub ecall_is_poseidon2_read: usize, + pub halt_ecall_poseidon2_read: usize, + pub ecall_rd_val: usize, pub ecall_halts: usize, pub halt_effective: usize, @@ -549,6 +554,32 @@ impl Rv32B1Layout { self.cpu_cell(self.ecall_is_print, j) } + #[inline] + pub fn ecall_poseidon2_prefix(&self, k: usize, j: usize) -> usize { + debug_assert!(k < 31, "ecall_poseidon2_prefix k out of range"); + self.ecall_poseidon2_prefix_start + k * self.chunk_size + j + } + + #[inline] + pub fn ecall_is_poseidon2(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_is_poseidon2, j) + } + + #[inline] + pub fn ecall_is_poseidon2_read(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_is_poseidon2_read, j) + } + + #[inline] + pub fn halt_ecall_poseidon2_read(&self, j: usize) -> usize { + self.cpu_cell(self.halt_ecall_poseidon2_read, j) + } + + #[inline] + pub fn ecall_rd_val(&self, j: usize) -> usize { + self.cpu_cell(self.ecall_rd_val, j) + } + #[inline] pub fn ecall_halts(&self, j: usize) -> usize { self.cpu_cell(self.ecall_halts, j) @@ -1134,6 +1165,11 @@ pub(super) fn build_layout_with_m( let ecall_is_cycle = alloc_scalar(&mut col); let ecall_print_prefix_start = alloc_array(&mut col, 31); let ecall_is_print = alloc_scalar(&mut col); + let ecall_poseidon2_prefix_start = alloc_array(&mut col, 31); + let ecall_is_poseidon2 = alloc_scalar(&mut col); + let ecall_is_poseidon2_read = alloc_scalar(&mut col); + let halt_ecall_poseidon2_read = alloc_scalar(&mut col); + let ecall_rd_val = alloc_scalar(&mut col); let ecall_halts = alloc_scalar(&mut 
col); let halt_effective = alloc_scalar(&mut col); @@ -1315,6 +1351,11 @@ pub(super) fn build_layout_with_m( ecall_is_cycle, ecall_print_prefix_start, ecall_is_print, + ecall_poseidon2_prefix_start, + ecall_is_poseidon2, + ecall_is_poseidon2_read, + halt_ecall_poseidon2_read, + ecall_rd_val, ecall_halts, halt_effective, bus, diff --git a/crates/neo-memory/src/riscv/ccs/witness.rs b/crates/neo-memory/src/riscv/ccs/witness.rs index a4f673a7..5fcdb14d 100644 --- a/crates/neo-memory/src/riscv/ccs/witness.rs +++ b/crates/neo-memory/src/riscv/ccs/witness.rs @@ -5,7 +5,7 @@ use neo_vm_trace::{StepTrace, TwistOpKind}; use crate::riscv::lookups::{ decode_instruction, BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode, JOLT_CYCLE_TRACK_ECALL_NUM, - JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, + JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, PROG_ID, RAM_ID, }; use super::constants::{ @@ -78,7 +78,30 @@ fn set_ecall_helpers(z: &mut [F], layout: &Rv32B1Layout, j: usize, a0_u64: u64, let is_print = print_prefix == 1 && print_match31; z[layout.ecall_is_print(j)] = if is_print { F::ONE } else { F::ZERO }; - let ecall_halts = !(is_cycle || is_print); + let poseidon2_const = POSEIDON2_ECALL_NUM; + let mut poseidon2_prefix = if ((a0_u32 ^ poseidon2_const) & 1) == 0 { 1u32 } else { 0u32 }; + z[layout.ecall_poseidon2_prefix(0, j)] = if poseidon2_prefix == 1 { F::ONE } else { F::ZERO }; + for k in 1..31 { + let bit_match = ((a0_u32 >> k) ^ (poseidon2_const >> k)) & 1; + poseidon2_prefix &= 1u32 ^ bit_match; + z[layout.ecall_poseidon2_prefix(k, j)] = if poseidon2_prefix == 1 { F::ONE } else { F::ZERO }; + } + let poseidon2_match31 = (((a0_u32 >> 31) ^ (poseidon2_const >> 31)) & 1) == 0; + let is_poseidon2 = poseidon2_prefix == 1 && poseidon2_match31; + z[layout.ecall_is_poseidon2(j)] = if is_poseidon2 { F::ONE } else { F::ZERO }; + + // ecall_is_poseidon2_read = prefix_30 * bit_31 (bit 31 set = read ECALL). 
+ let bit_31_set = (a0_u32 >> 31) & 1 == 1; + let is_poseidon2_read = poseidon2_prefix == 1 && bit_31_set; + z[layout.ecall_is_poseidon2_read(j)] = if is_poseidon2_read { F::ONE } else { F::ZERO }; + + // halt_ecall_poseidon2_read = is_halt * ecall_is_poseidon2_read. + // Gates the raw prefix flag by is_halt so register update constraints only + // activate on actual ECALL instructions. + let halt_poseidon2_read = is_halt && is_poseidon2_read; + z[layout.halt_ecall_poseidon2_read(j)] = if halt_poseidon2_read { F::ONE } else { F::ZERO }; + + let ecall_halts = !(is_cycle || is_print || is_poseidon2 || is_poseidon2_read); z[layout.ecall_halts(j)] = if ecall_halts { F::ONE } else { F::ZERO }; z[layout.halt_effective(j)] = if is_halt && ecall_halts { F::ONE } else { F::ZERO }; @@ -644,6 +667,13 @@ fn rv32_b1_chunk_to_witness_internal( z[layout.is_halt(j)] = if is_halt { F::ONE } else { F::ZERO }; set_ecall_helpers(&mut z, layout, j, step.regs_before[10], is_halt)?; + // Detect Poseidon2 read ECALL and set ecall_rd_val. + let is_poseidon2_read_ecall = is_halt && step.regs_before[10] as u32 == POSEIDON2_READ_ECALL_NUM; + if is_poseidon2_read_ecall { + let rd_val = step.regs_after[10]; + z[layout.ecall_rd_val(j)] = F::from_u64(rd_val); + } + // One-hot register selectors. let rs1_idx = rs1 as usize; let rs2_idx = rs2 as usize; @@ -1152,6 +1182,12 @@ fn rv32_b1_chunk_to_witness_internal( z[layout.br_not_taken(j)] = F::ONE - z[layout.br_taken(j)]; } + // Poseidon2 read ECALL: set rd_write_val = ecall_rd_val so the existing + // enforce_u32_bits range check covers it. 
+ if is_poseidon2_read_ecall { + z[layout.rd_write_val(j)] = z[layout.ecall_rd_val(j)]; + } + let mul_carry = if is_mulh { let rhs = (mul_hi as i128) - (rs1_sign as i128) * (rs2_u64 as i128) - (rs2_sign as i128) * (rs1_u64 as i128) diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index 4c3094bb..f8fbb5b3 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -1,11 +1,13 @@ use neo_vm_trace::{Shout, Twist}; +use p3_field::{PrimeCharacteristicRing, PrimeField64}; +use p3_goldilocks::Goldilocks; use super::bits::interleave_bits; use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; -use super::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM}; +use super::{JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. /// @@ -25,6 +27,10 @@ pub struct RiscvCpu { program: Vec, /// Base address of the program. program_base: u64, + /// Pending Poseidon2 digest words (8 × u32, set by the compute ECALL). + poseidon2_pending: Option<[u32; 8]>, + /// Index into `poseidon2_pending` for the next read ECALL. + poseidon2_read_idx: usize, } impl RiscvCpu { @@ -38,6 +44,8 @@ impl RiscvCpu { halted: false, program: Vec::new(), program_base: 0, + poseidon2_pending: None, + poseidon2_read_idx: 0, } } @@ -86,10 +94,64 @@ impl RiscvCpu { fn handle_ecall(&mut self) { let call_id = self.get_reg(10) as u32; // a0 - if call_id != JOLT_CYCLE_TRACK_ECALL_NUM && call_id != JOLT_PRINT_ECALL_NUM { + if call_id != JOLT_CYCLE_TRACK_ECALL_NUM + && call_id != JOLT_PRINT_ECALL_NUM + && call_id != POSEIDON2_ECALL_NUM + && call_id != POSEIDON2_READ_ECALL_NUM + { self.halted = true; } } + + /// Execute the Poseidon2 hash compute precompile. 
+ /// + /// Reads `n_elements` Goldilocks field elements from RAM at `input_addr` + /// using `load_untraced` (no Twist bus events), computes the Poseidon2- + /// Goldilocks sponge hash, and stores the 4-element digest as 8 u32 words + /// in internal CPU state (`poseidon2_pending`). The guest retrieves output + /// words one at a time via `POSEIDON2_READ_ECALL_NUM`. + fn handle_poseidon2_ecall>(&mut self, twist: &mut T) { + use neo_ccs::crypto::poseidon2_goldilocks::poseidon2_hash; + + let ram = super::RAM_ID; + let n_elements = self.get_reg(11) as u32; // a1 + let input_addr = self.get_reg(12); // a2 + + let mut inputs = Vec::with_capacity(n_elements as usize); + for i in 0..n_elements { + let addr = input_addr + (i as u64) * 8; + let lo = twist.load_untraced(ram, addr) & 0xFFFF_FFFF; + let hi = twist.load_untraced(ram, addr + 4) & 0xFFFF_FFFF; + let val = lo | (hi << 32); + inputs.push(Goldilocks::from_u64(val)); + } + + let digest = poseidon2_hash(&inputs); + + let mut words = [0u32; 8]; + for (i, &elem) in digest.iter().enumerate() { + let val = elem.as_canonical_u64(); + words[i * 2] = val as u32; + words[i * 2 + 1] = (val >> 32) as u32; + } + self.poseidon2_pending = Some(words); + self.poseidon2_read_idx = 0; + } + + /// Read one u32 word of the pending Poseidon2 digest into register a0. + /// + /// Called 8 times by the guest to retrieve the full 4-element digest. + /// Each call returns the next word and advances the internal read index. 
+ fn handle_poseidon2_read_ecall(&mut self) { + let words = self + .poseidon2_pending + .expect("poseidon2 read ECALL called without a pending digest"); + let idx = self.poseidon2_read_idx; + assert!(idx < 8, "poseidon2 read ECALL: all 8 words already consumed"); + let word = words[idx] as u64; + self.set_reg(10, word); // a0 + self.poseidon2_read_idx = idx + 1; + } } impl neo_vm_trace::VmCpu for RiscvCpu { @@ -402,7 +464,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { } RiscvInstruction::Halt => { - // ECALL trap semantics (Jolt-style): no architectural effects, halt unless it's a known marker/print call. + let call_id = self.get_reg(10) as u32; + if call_id == POSEIDON2_ECALL_NUM { + self.handle_poseidon2_ecall(twist); + } else if call_id == POSEIDON2_READ_ECALL_NUM { + self.handle_poseidon2_read_ecall(); + } self.handle_ecall(); } @@ -558,7 +625,12 @@ impl neo_vm_trace::VmCpu for RiscvCpu { // === System Instructions === RiscvInstruction::Ecall => { - // ECALL - environment call (syscall). + let call_id = self.get_reg(10) as u32; + if call_id == POSEIDON2_ECALL_NUM { + self.handle_poseidon2_ecall(twist); + } else if call_id == POSEIDON2_READ_ECALL_NUM { + self.handle_poseidon2_read_ecall(); + } self.handle_ecall(); } diff --git a/crates/neo-memory/src/riscv/lookups/mod.rs b/crates/neo-memory/src/riscv/lookups/mod.rs index 4753eaec..a11bdb60 100644 --- a/crates/neo-memory/src/riscv/lookups/mod.rs +++ b/crates/neo-memory/src/riscv/lookups/mod.rs @@ -105,6 +105,22 @@ pub const PROG_ID: TwistId = TwistId(1); pub const JOLT_CYCLE_TRACK_ECALL_NUM: u32 = 0xC7C1E; pub const JOLT_PRINT_ECALL_NUM: u32 = 0x505249; +/// Poseidon2-Goldilocks hash compute ECALL identifier. +/// +/// ABI: a0 = POSEIDON2_ECALL_NUM, a1 = input element count, +/// a2 = input RAM address (elements as 2×u32 LE). +/// The host reads inputs via untraced loads, computes the hash, and stores the +/// 4-element digest in CPU-internal state. 
Use POSEIDON2_READ_ECALL_NUM to +/// retrieve output words one at a time via register a0. +pub const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2-Goldilocks digest read ECALL identifier (bit 31 set). +/// +/// ABI: a0 = POSEIDON2_READ_ECALL_NUM. Returns the next u32 word of the +/// pending Poseidon2 digest in register a0. Call 8 times (4 elements × 2 words) +/// to retrieve the full digest. +pub const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + pub use alu::{compute_op, lookup_entry}; pub use bits::{interleave_bits, uninterleave_bits}; pub use cpu::RiscvCpu; diff --git a/crates/neo-memory/tests/riscv_ccs_tests.rs b/crates/neo-memory/tests/riscv_ccs_tests.rs index 43e11ded..b9b9e819 100644 --- a/crates/neo-memory/tests/riscv_ccs_tests.rs +++ b/crates/neo-memory/tests/riscv_ccs_tests.rs @@ -12,7 +12,8 @@ use neo_memory::riscv::ccs::{ }; use neo_memory::riscv::lookups::{ decode_instruction, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, - RiscvOpcode, RiscvShoutTables, JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, PROG_ID, RAM_ID, + RiscvOpcode, RiscvShoutTables, JOLT_CYCLE_TRACK_ECALL_NUM, JOLT_PRINT_ECALL_NUM, POSEIDON2_ECALL_NUM, + POSEIDON2_READ_ECALL_NUM, PROG_ID, RAM_ID, }; use neo_memory::riscv::rom_init::prog_init_words; use neo_memory::witness::LutTableSpec; @@ -427,6 +428,97 @@ fn rv32_b1_ccs_happy_path_rv32i_ecall_markers_program() { } } +#[test] +fn rv32_b1_ccs_happy_path_poseidon2_ecall() { + let xlen = 32usize; + let mut program = Vec::new(); + + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + // a0 = POSEIDON2_ECALL_NUM → compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. 
+ for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + + // Clear a0 → final halt. + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + let program_bytes = encode_program(&program); + let mut cpu_vm = RiscvCpu::new(xlen); + cpu_vm.load_program(0, program.clone()); + let memory = RiscvMemory::with_program_in_twist(xlen, PROG_ID, 0, &program_bytes); + let shout = RiscvShoutTables::new(xlen); + let trace = trace_program(cpu_vm, memory, shout, 256).expect("trace"); + assert!(trace.did_halt(), "expected Halt"); + + let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); + let (k_ram, d_ram) = pow2_ceil_k(0x40); + let mem_layouts = HashMap::from([ + ( + 0u32, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + ( + 1u32, + PlainMemLayout { + k: k_prog, + d: d_prog, + n_side: 2, + lanes: 1, + }, + ), + ]); + + let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); + + let shout_table_ids = RV32I_SHOUT_TABLE_IDS; + let (ccs, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, 1).expect("ccs"); + let params = NeoParams::goldilocks_auto_r1cs_ccs(ccs.n).expect("params"); + + let table_specs = rv32i_table_specs(xlen); + + let cpu = R1csCpu::new( + ccs, + params, + NoopCommit::default(), + layout.m_in, + &HashMap::new(), + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ) + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), + 1, + ) + .expect("shared bus"); + + let steps = CpuArithmetization::build_ccs_steps(&cpu, &trace).expect("build steps"); + for (mcs_inst, mcs_wit) in steps { + check_ccs_rowwise_zero(&cpu.ccs, &mcs_inst.x, &mcs_wit.w).expect("CCS satisfied"); + } +} + #[test] fn rv32_b1_ccs_happy_path_rv32m_program() { let xlen = 32usize; diff --git a/crates/neo-vm-trace/src/lib.rs 
b/crates/neo-vm-trace/src/lib.rs index f75f963d..e7750d77 100644 --- a/crates/neo-vm-trace/src/lib.rs +++ b/crates/neo-vm-trace/src/lib.rs @@ -340,6 +340,16 @@ pub trait Twist { self.store_lane(twist_id, addr, value, lane); } } + + /// Load a value from memory without recording a Twist event. + /// + /// Used by ECALL precompiles that need to read guest RAM without generating + /// per-step Twist bus entries (the ECALL computation is trusted host code). + /// The default implementation delegates to `load`. + #[inline] + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.load(twist_id, addr) + } } /// A tracing wrapper around any `Twist` implementation. @@ -431,6 +441,10 @@ where lane: Some(lane), }); } + + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.inner.load(twist_id, addr) + } } // ============================================================================ diff --git a/crates/nightstream-sdk/src/goldilocks.rs b/crates/nightstream-sdk/src/goldilocks.rs new file mode 100644 index 00000000..a25193c6 --- /dev/null +++ b/crates/nightstream-sdk/src/goldilocks.rs @@ -0,0 +1,116 @@ +//! Software Goldilocks field arithmetic for RISC-V guests. +//! +//! Field: p = 2^64 - 2^32 + 1 = 0xFFFF_FFFF_0000_0001 +//! +//! All values are canonical: in [0, p). + +#![allow(dead_code)] + +/// Goldilocks prime: p = 2^64 - 2^32 + 1. +pub const GL_P: u64 = 0xFFFF_FFFF_0000_0001; + +/// A Goldilocks field element digest: 4 elements (32 bytes). +pub type GlDigest = [u64; 4]; + +/// Reduce a u128 value mod p using the identity 2^64 ≡ 2^32 - 1 (mod p). +/// Avoids 128-bit division. 
+#[inline] +fn reduce128(x: u128) -> u64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + + // x ≡ lo + hi * (2^32 - 1) (mod p) + let hi_times_correction = (hi as u128) * 0xFFFF_FFFFu128; + let sum = lo as u128 + hi_times_correction; + + let lo2 = sum as u64; + let hi2 = (sum >> 64) as u64; + + // hi2 < 2^32 (not just 0 or 1), so hi2 * (2^32 - 1) fits in a u64; fold it in once more + let (mut result, overflow) = lo2.overflowing_add(hi2.wrapping_mul(0xFFFF_FFFF)); + if overflow || result >= GL_P { + result = result.wrapping_sub(GL_P); + } + result +} + +/// Additive identity. +pub const GL_ZERO: u64 = 0; + +/// Multiplicative identity. +pub const GL_ONE: u64 = 1; + +/// Field addition: (a + b) mod p. +#[inline] +pub fn gl_add(a: u64, b: u64) -> u64 { + let sum = a as u128 + b as u128; + let s = sum as u64; + let carry = (sum >> 64) as u64; + // carry is 0 or 1 + let (mut r, overflow) = s.overflowing_add(carry.wrapping_mul(0xFFFF_FFFF)); + if overflow || r >= GL_P { + r = r.wrapping_sub(GL_P); + } + r +} + +/// Field subtraction: (a - b) mod p. +#[inline] +pub fn gl_sub(a: u64, b: u64) -> u64 { + if a >= b { + let diff = a - b; + if diff >= GL_P { diff - GL_P } else { diff } + } else { + // a < b, so result = p - (b - a) + GL_P - (b - a) + } +} + +/// Field negation: (-a) mod p. +#[inline] +pub fn gl_neg(a: u64) -> u64 { + if a == 0 { 0 } else { GL_P - a } +} + +/// Field multiplication: (a * b) mod p. +#[inline] +pub fn gl_mul(a: u64, b: u64) -> u64 { + let prod = (a as u128) * (b as u128); + reduce128(prod) +} + +/// Field squaring: (a * a) mod p. +#[inline] +pub fn gl_sqr(a: u64) -> u64 { + gl_mul(a, a) +} + +/// Field exponentiation: a^exp mod p via square-and-multiply. +pub fn gl_pow(mut base: u64, mut exp: u64) -> u64 { + let mut result = GL_ONE; + while exp > 0 { + if exp & 1 == 1 { + result = gl_mul(result, base); + } + base = gl_sqr(base); + exp >>= 1; + } + result +} + +/// Field inverse: a^(p-2) mod p (Fermat's little theorem). +/// +/// Panics in debug builds if a == 0 (`debug_assert!`); with debug assertions off it returns 0.
+pub fn gl_inv(a: u64) -> u64 { + debug_assert!(a != 0, "inverse of zero"); + // p - 2 = 0xFFFF_FFFE_FFFF_FFFF + gl_pow(a, GL_P - 2) +} + +/// Check equality of two digests. +pub fn digest_eq(a: &GlDigest, b: &GlDigest) -> bool { + a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3] +} + +/// Zero digest. +pub const ZERO_DIGEST: GlDigest = [0; 4]; diff --git a/crates/nightstream-sdk/src/lib.rs b/crates/nightstream-sdk/src/lib.rs index f0a043c8..4cd5b989 100644 --- a/crates/nightstream-sdk/src/lib.rs +++ b/crates/nightstream-sdk/src/lib.rs @@ -3,6 +3,8 @@ pub use nightstream_sdk_macros::{entry, provable, NeoAbi}; pub mod abi; +pub mod goldilocks; +pub mod poseidon2; #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] mod riscv; diff --git a/crates/nightstream-sdk/src/poseidon2.rs b/crates/nightstream-sdk/src/poseidon2.rs new file mode 100644 index 00000000..1cf9b90d --- /dev/null +++ b/crates/nightstream-sdk/src/poseidon2.rs @@ -0,0 +1,111 @@ +//! Poseidon2-Goldilocks hash via ECALL precompile. +//! +//! Provides the guest-side API for computing Poseidon2 hashes over Goldilocks +//! field elements. On RISC-V targets, this issues ECALLs to the host: +//! 1. A "compute" ECALL that reads inputs (via untraced loads) and computes the +//! Poseidon2 hash, storing the digest in host-side CPU state. +//! 2. Eight "read" ECALLs that each return one u32 word of the digest in +//! register a0. +//! +//! On non-RISC-V targets, a stub panics. + +use crate::goldilocks::GlDigest; + +/// Poseidon2 compute ECALL number (must match neo-memory's POSEIDON2_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2 read ECALL number (must match neo-memory's POSEIDON2_READ_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + +/// Digest length in Goldilocks elements (matches neo-params DIGEST_LEN). 
+const DIGEST_LEN: usize = 4; + +/// Scratch buffer size: supports hashing up to 64 Goldilocks elements per call. +/// Each element occupies 2 u32 words (8 bytes). +const MAX_INPUT_ELEMENTS: usize = 64; + +/// Hash an arbitrary-length slice of Goldilocks field elements. +/// +/// Packs elements to a stack-allocated scratch buffer, issues the Poseidon2 +/// compute ECALL, then retrieves the 4-element digest via 8 read ECALLs +/// (each returning one u32 word in register a0). +/// +/// # Panics +/// +/// Panics if `input.len() > MAX_INPUT_ELEMENTS` (64). +pub fn poseidon2_hash(input: &[u64]) -> GlDigest { + assert!( + input.len() <= MAX_INPUT_ELEMENTS, + "poseidon2_hash: too many input elements" + ); + + // Scratch buffer for input (as u32 words). + let mut input_buf: [u32; MAX_INPUT_ELEMENTS * 2] = [0; MAX_INPUT_ELEMENTS * 2]; + + // Pack Goldilocks elements as pairs of u32 words (little-endian). + for (i, &elem) in input.iter().enumerate() { + input_buf[i * 2] = elem as u32; + input_buf[i * 2 + 1] = (elem >> 32) as u32; + } + + // Compute ECALL: host reads inputs via untraced loads and stores digest in CPU state. + poseidon2_ecall_compute(input.len() as u32, input_buf.as_ptr() as u32); + + // Read 8 u32 words of the digest via register a0. + let d: [u32; 8] = core::array::from_fn(|_| poseidon2_ecall_read()); + + // Unpack into 4 Goldilocks elements. + let mut digest = [0u64; DIGEST_LEN]; + for i in 0..DIGEST_LEN { + digest[i] = (d[i * 2] as u64) | ((d[i * 2 + 1] as u64) << 32); + } + digest +} + +/// Issue the Poseidon2 compute ECALL. +/// +/// Registers: a0 = ECALL ID, a1 = element count, a2 = input addr. +/// No output via registers; digest is stored in host CPU state. 
+#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_compute(n_elements: u32, input_addr: u32) { + unsafe { + core::arch::asm!( + "ecall", + in("a0") POSEIDON2_ECALL_NUM, + in("a1") n_elements, + in("a2") input_addr, + options(nostack), + ); + } +} + +/// Issue the Poseidon2 read ECALL and return the next digest word. +/// +/// The ECALL number is passed in a0, and the host returns the next u32 word +/// of the pending digest in a0. +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_read() -> u32 { + let result: u32; + unsafe { + core::arch::asm!( + "ecall", + inout("a0") POSEIDON2_READ_ECALL_NUM => result, + options(nostack), + ); + } + result +} + +/// Stub for non-RISC-V targets (used in native tests via the reference impl). +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_compute(_n_elements: u32, _input_addr: u32) { + unimplemented!("poseidon2_ecall_compute is only available on RISC-V targets") +} + +/// Stub for non-RISC-V targets. +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_read() -> u32 { + unimplemented!("poseidon2_ecall_read is only available on RISC-V targets") +} From 10ef34c4765035d98bfaaf323b13716223969bd4 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 11:18:19 -0300 Subject: [PATCH 06/13] Add detailed timing metrics for RISC-V VM and CPU witness phases - Introduced new timing fields in `ProveTimings` to track durations for VM trace collection and CPU witness building. - Updated `ShardWitnessAux` to include wall-clock time for VM trace and CPU witness phases. - Enhanced `R1csCpu` to log detailed timing for witness building, bus filling, matrix decomposition, and commitment processes. - Added methods in `Session` for preloading CCS preprocessing with matrix digests to optimize performance during proving and verification. 
--- crates/neo-fold/src/riscv_shard.rs | 247 ++++++++++++++++------ crates/neo-fold/src/session.rs | 49 ++++- crates/neo-memory/src/builder.rs | 10 + crates/neo-memory/src/cpu/r1cs_adapter.rs | 54 +++++ 4 files changed, 299 insertions(+), 61 deletions(-) diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index a5d1c932..6b30e799 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -88,6 +88,10 @@ pub struct ProveTimings { pub vm_execution: Duration, /// Fold-and-prove (sumcheck + Ajtai commitment). pub fold_and_prove: Duration, + /// Sub-phase: RISC-V VM trace collection (within vm_execution). + pub vm_trace: Duration, + /// Sub-phase: CPU witness building including decomp + commit (within vm_execution). + pub cpu_witness: Duration, } pub fn rv32_b1_step_linking_config(layout: &Rv32B1Layout) -> StepLinkingConfig { @@ -294,16 +298,33 @@ fn all_shout_opcodes() -> HashSet { /// Build with [`Rv32B1::build_ccs_cache`], then inject into subsequent builders via /// [`Rv32B1::with_ccs_cache`]. /// -/// **Important**: the CCS structure itself is still built from scratch each time (it is fast). -/// Only the SparseCache (CSC decomposition + matrix digest) is cached. +/// When `full_ccs` is present, `prove()` and `build_verifier()` skip `build_rv32_b1_step_ccs` +/// and `with_shared_cpu_bus`, reusing the cached CCS directly. This saves ~1s per call. +/// +/// When `params` and `committer` are also present, `prove()` and `build_verifier()` additionally +/// skip `FoldingSession::new_ajtai()` (which generates a fresh random kappa x m Ajtai matrix), +/// using `FoldingSession::new()` with the cached values instead. pub struct Rv32B1CcsCache { pub sparse: Arc>, + /// Full post-shared-bus CCS. When present, prove/verify skip CCS construction. + pub full_ccs: Option>, + /// Layout corresponding to the cached CCS. + pub layout: Option, + /// Cached NeoParams from the original session creation. 
+ pub params: Option, + /// Cached Ajtai committer (wraps `Arc`). Avoids re-generating the random matrix. + pub committer: Option, + /// Precomputed CCS matrix digest. Avoids re-hashing all matrices (~1.5s) on every prove/verify. + pub ccs_mat_digest: Option>, } impl std::fmt::Debug for Rv32B1CcsCache { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rv32B1CcsCache") .field("sparse_len", &self.sparse.len()) + .field("has_full_ccs", &self.full_ccs.is_some()) + .field("has_params", &self.params.is_some()) + .field("has_committer", &self.committer.is_some()) .finish() } } @@ -519,9 +540,21 @@ impl Rv32B1 { ) .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; - let sparse = Arc::new(SparseCache::build(&cpu.ccs)); + let cached_params = session.params().clone(); + let cached_committer = session.committer().clone(); + + let full_ccs = cpu.ccs.clone(); + let sparse = Arc::new(SparseCache::build(&full_ccs)); + let ccs_mat_digest = neo_reductions::engines::utils::digest_ccs_matrices(&full_ccs); - Ok(Rv32B1CcsCache { sparse }) + Ok(Rv32B1CcsCache { + sparse, + full_ccs: Some(full_ccs), + layout: Some(layout), + params: Some(cached_params), + committer: Some(cached_committer), + ccs_mat_digest: Some(ccs_mat_digest), + }) } /// Build only the verification context (CCS + session) without executing the program or proving. 
@@ -623,38 +656,62 @@ impl Rv32B1 { let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); shout_table_ids.sort_unstable(); - // --- CCS + Session (same as prove) --- - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) - .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; - - let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; - let params = session.params().clone(); - let committer = session.committer().clone(); + // --- CCS + Session --- + let has_full_ccs = self + .ccs_cache + .as_ref() + .map(|c| c.full_ccs.is_some() && c.layout.is_some()) + .unwrap_or(false); let empty_tables: HashMap> = HashMap::new(); - - // Build R1csCpu for CCS with shared-bus wiring (same as prove; no execution). - let mut cpu = R1csCpu::new( - ccs_base, - params, - committer, - layout.m_in, - &empty_tables, - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ); - cpu = cpu - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, - self.chunk_size, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + let (mut session, ccs, layout); + + if has_full_ccs { + let cache = self.ccs_cache.as_ref().unwrap(); + let cached_ccs = cache.full_ccs.as_ref().unwrap().clone(); + layout = cache.layout.as_ref().unwrap().clone(); + + // Use cached params+committer to skip the expensive setup_par() random matrix generation. 
+ let has_cached_session = cache.params.is_some() && cache.committer.is_some(); + if has_cached_session { + let cached_params = cache.params.as_ref().unwrap().clone(); + let cached_committer = cache.committer.as_ref().unwrap().clone(); + session = FoldingSession::new(self.mode.clone(), cached_params, cached_committer); + session.commit_m = Some(cached_ccs.m); + } else { + session = FoldingSession::::new_ajtai(self.mode.clone(), &cached_ccs)?; + } + ccs = cached_ccs; + } else { + let (ccs_base, layout_built) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + layout = layout_built; + + session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + ccs = cpu.ccs.clone(); + } session.set_step_linking(rv32_b1_step_linking_config(&layout)); - let ccs = cpu.ccs.clone(); - // No execution, no fold_and_prove -- just the verification context. // NOTE: verifier SparseCache preloading is deferred to `preload_sparse_cache()` // because the CCS pointer changes when moved into the Rv32B1Verifier struct. 
@@ -784,37 +841,82 @@ impl Rv32B1 { let prebuilt_sparse = self.ccs_cache.as_ref().map(|c| c.sparse.clone()); - let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) - .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; - - // Session + Ajtai committer + params (auto-picked for this CCS). - let mut session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; - let params = session.params().clone(); - let committer = session.committer().clone(); + // Check if the cache contains a full pre-wired CCS we can reuse. + let has_full_ccs = self + .ccs_cache + .as_ref() + .map(|c| c.full_ccs.is_some() && c.layout.is_some()) + .unwrap_or(false); let mut vm = RiscvCpu::new(self.xlen); vm.load_program(/*base=*/ 0, program); - let empty_tables: HashMap> = HashMap::new(); let lut_lanes: HashMap = HashMap::new(); - // CPU arithmetization (builds chunk witnesses and commits them). - let mut cpu = R1csCpu::new( - ccs_base, - params, - committer, - layout.m_in, - &empty_tables, - &table_specs, - rv32_b1_chunk_to_witness(layout.clone()), - ); - cpu = cpu - .with_shared_cpu_bus( - rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) - .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, - self.chunk_size, - ) - .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + let (mut session, cpu, layout); + + if has_full_ccs { + let cache = self.ccs_cache.as_ref().unwrap(); + let cached_ccs = cache.full_ccs.as_ref().unwrap().clone(); + layout = cache.layout.as_ref().unwrap().clone(); + + // Use cached params+committer to skip the expensive setup_par() random matrix generation. 
+ let has_cached_session = cache.params.is_some() && cache.committer.is_some(); + if has_cached_session { + let cached_params = cache.params.as_ref().unwrap().clone(); + let cached_committer = cache.committer.as_ref().unwrap().clone(); + session = FoldingSession::new(self.mode.clone(), cached_params, cached_committer); + session.commit_m = Some(cached_ccs.m); + } else { + session = FoldingSession::::new_ajtai(self.mode.clone(), &cached_ccs)?; + } + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu_inner = R1csCpu::new( + cached_ccs, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu_inner = cpu_inner + .with_shared_cpu_bus_witness_only( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus witness-only setup failed: {e}")))?; + cpu = cpu_inner; + } else { + let (ccs_base, layout_built) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, self.chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + layout = layout_built; + + session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu_inner = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ); + cpu_inner = cpu_inner + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + self.chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared 
bus inject failed: {e}")))?; + cpu = cpu_inner; + } // Always enforce step-to-step chunk chaining for RV32 B1. session.set_step_linking(rv32_b1_step_linking_config(&layout)); @@ -849,16 +951,28 @@ impl Rv32B1 { } } + // Extract sub-phase timings from the shared-bus aux. + let (t_vm_trace, t_cpu_witness) = session + .shared_bus_aux() + .map(|aux| (aux.vm_trace_duration, aux.cpu_witness_duration)) + .unwrap_or((Duration::ZERO, Duration::ZERO)); + // Enforce that the *statement* initial memory matches chunk 0's public MemInit. let steps_public = session.steps_public(); rv32_b1_enforce_chunk0_mem_init_matches_statement(&mem_layouts, &initial_mem, &steps_public)?; - let ccs = cpu.ccs.clone(); + // Move the CCS out of cpu (no longer needed) to avoid an expensive clone. + let ccs = cpu.ccs; - // Preload the SparseCache using the *final* CCS reference that fold_and_prove will use. - // (Must happen after cpu.ccs.clone() because the pointer-keyed cache checks identity.) + // Preload the SparseCache + matrix digest. When the digest is cached, skip the + // expensive digest_ccs_matrices recomputation (~1.5s for large circuits). 
if let Some(ref sparse) = prebuilt_sparse { - session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + let cached_digest = self.ccs_cache.as_ref().and_then(|c| c.ccs_mat_digest.clone()); + if let Some(digest) = cached_digest { + session.preload_ccs_sparse_cache_with_digest(&ccs, sparse.clone(), digest)?; + } else { + session.preload_ccs_sparse_cache(&ccs, sparse.clone())?; + } } let t_vm_execution = elapsed_duration(phase_start); @@ -890,6 +1004,8 @@ impl Rv32B1 { ccs_and_shared_bus: t_ccs_and_shared_bus, vm_execution: t_vm_execution, fold_and_prove: t_fold_and_prove, + vm_trace: t_vm_trace, + cpu_witness: t_cpu_witness, }, verify_duration: None, }) @@ -920,6 +1036,17 @@ impl Rv32B1Verifier { .preload_verifier_ccs_sparse_cache(&self.ccs, sparse) } + /// Preload verifier SparseCache with a precomputed matrix digest, skipping the expensive + /// digest recomputation. + pub fn preload_sparse_cache_with_digest( + &mut self, + sparse: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache_with_digest(&self.ccs, sparse, ccs_mat_digest) + } + /// Verify a `ShardProof` using the provided public step instance bundles. /// /// `steps_public` must be the step instance bundles produced by the prover (via diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 0dece63c..d12fd552 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -592,7 +592,7 @@ where mode: FoldingMode, params: NeoParams, l: L, - commit_m: Option, + pub(crate) commit_m: Option, mixers: CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt>, // Cached CCS preprocessing for proving (best-effort reuse). @@ -1193,6 +1193,30 @@ where Ok(()) } + /// Preload prover-side CCS preprocessing with a precomputed matrix digest, skipping the + /// expensive `digest_ccs_matrices` call (~1.5s for large circuits). 
+ pub fn preload_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.prover_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Preload verifier-side CCS preprocessing (sparse-cache + matrix-digest). /// /// This does **not** affect proving. It exists so benchmarks can model a verifier that has @@ -1206,6 +1230,29 @@ where Ok(()) } + /// Preload verifier-side CCS preprocessing with a precomputed matrix digest. + pub fn preload_verifier_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.verifier_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Fold and prove: run folding over all collected steps and return a `FoldRun`. /// This is where the actual cryptographic work happens (Π_CCS → RLC → DEC for each step). /// This method manages the transcript internally for ease of use. 
diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index ed35c991..f80db964 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -51,6 +51,10 @@ pub struct ShardWitnessAux { pub mem_ids: Vec, /// Final sparse memory states at the end of the shard: mem_id -> (addr -> value), with zero cells omitted. pub final_mem_states: HashMap>, + /// Wall-clock time for VM trace collection (`trace_program`). + pub vm_trace_duration: std::time::Duration, + /// Wall-clock time for CPU witness building (`build_ccs_chunks`), includes decomp + commit. + pub cpu_witness_duration: std::time::Duration, } fn ell_from_pow2_n_side(n_side: usize) -> Result { @@ -138,8 +142,10 @@ where // // This keeps the proof size proportional to the executed trace length instead of the caller's // safety bound. + let t_vm_start = std::time::Instant::now(); let trace = neo_vm_trace::trace_program(vm, twist, shout, max_steps) .map_err(|e| ShardBuildError::VmError(e.to_string()))?; + let vm_trace_duration = t_vm_start.elapsed(); let original_len = trace.steps.len(); let did_halt = trace.did_halt(); debug_assert!( @@ -193,9 +199,11 @@ where table_ids.dedup(); // 3) CPU arithmetization chunks. 
+ let t_witness_start = std::time::Instant::now(); let mcss = cpu_arith .build_ccs_chunks(&trace, chunk_size) .map_err(|e| ShardBuildError::CcsError(e.to_string()))?; + let cpu_witness_duration = t_witness_start.elapsed(); if mcss.len() != chunks_len { return Err(ShardBuildError::CcsError(format!( "cpu arithmetization returned {} chunks, expected {} (steps={}, chunk_size={})", @@ -368,6 +376,8 @@ where chunk_size, mem_ids, final_mem_states: mem_states, + vm_trace_duration, + cpu_witness_duration, }; Ok((step_bundles, aux)) } diff --git a/crates/neo-memory/src/cpu/r1cs_adapter.rs b/crates/neo-memory/src/cpu/r1cs_adapter.rs index bf321513..5aa253d3 100644 --- a/crates/neo-memory/src/cpu/r1cs_adapter.rs +++ b/crates/neo-memory/src/cpu/r1cs_adapter.rs @@ -419,6 +419,38 @@ where }); Ok(self) } + + /// Like [`with_shared_cpu_bus`], but assumes the CCS already contains bus constraints. + /// + /// Use this when the `ccs` passed to [`R1csCpu::new`] is a pre-wired CCS from a cache + /// (e.g. from a previous `with_shared_cpu_bus` call). This sets up the internal + /// `SharedCpuBusState` needed for witness generation in [`build_ccs_chunks`] without + /// re-injecting constraint matrices into the CCS. 
+ pub fn with_shared_cpu_bus_witness_only( + mut self, + cfg: SharedCpuBusConfig, + chunk_size: usize, + ) -> Result { + if chunk_size == 0 { + return Err("shared_cpu_bus: chunk_size must be >= 1".into()); + } + if cfg.const_one_col >= self.m_in { + return Err(format!( + "shared_cpu_bus: const_one_col={} must be < m_in={}", + cfg.const_one_col, self.m_in + )); + } + + let (table_ids, mem_ids, layout) = self.shared_bus_schema(&cfg, chunk_size)?; + + self.shared_cpu_bus = Some(SharedCpuBusState { + cfg, + table_ids, + mem_ids, + layout, + }); + Ok(self) + } } // R1csCpu implementation specifically for Goldilocks field because neo_ajtai::decomp_b uses Goldilocks @@ -451,6 +483,10 @@ where } let mut mcss = Vec::with_capacity(trace.steps.len().div_ceil(chunk_size)); + let mut t_witness_build = std::time::Duration::ZERO; + let mut t_bus_fill = std::time::Duration::ZERO; + let mut t_decomp_b = std::time::Duration::ZERO; + let mut t_commit = std::time::Duration::ZERO; // Track sparse memory state across the full trace to compute inc_at_write_addr. let mut mem_state: HashMap> = HashMap::new(); @@ -480,6 +516,7 @@ where let chunk = &trace.steps[chunk_start..chunk_end]; // 1) Build witness z for this chunk. + let t0 = std::time::Instant::now(); let mut z_vec = (self.chunk_to_witness)(chunk); // Allow witness builders to omit trailing dummy variables (including the shared-bus tail). @@ -512,8 +549,10 @@ where } z_vec[shared.cfg.const_one_col] = Goldilocks::ONE; } + t_witness_build += t0.elapsed(); // 2) Overwrite the shared bus tail from the trace events. 
+ let t0 = std::time::Instant::now(); if let Some(shared) = shared { let bus_base = shared.layout.bus_base; let bus_region_len = shared.layout.bus_region_len(); @@ -778,7 +817,10 @@ where } } + t_bus_fill += t0.elapsed(); + // 3) Decompose z -> Z matrix + let t0 = std::time::Instant::now(); let d = self.params.d as usize; let m = z_vec.len(); // == ccs.m after padding @@ -799,9 +841,12 @@ where } } let z_mat = Mat::from_row_major(d, m, mat_data); + t_decomp_b += t0.elapsed(); // 4) Commit to Z + let t0 = std::time::Instant::now(); let c = self.committer.commit(&z_mat); + t_commit += t0.elapsed(); // 5) Build Instance/Witness let x = z_vec[..m_in].to_vec(); @@ -812,6 +857,15 @@ where chunk_start = chunk_end; } + eprintln!( + "[build_ccs_chunks] witness_build={}ms, bus_fill={}ms, decomp_b={}ms, commit={}ms, total={}ms", + t_witness_build.as_millis(), + t_bus_fill.as_millis(), + t_decomp_b.as_millis(), + t_commit.as_millis(), + (t_witness_build + t_bus_fill + t_decomp_b + t_commit).as_millis(), + ); + Ok(mcss) } From 89ccc93120e50b7c2409676be2f79c8fde7caf81 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 13:32:07 -0300 Subject: [PATCH 07/13] Fix compile breakages after sovereign merge --- crates/neo-fold/src/riscv_shard.rs | 12 ++++-------- crates/neo-memory/src/builder.rs | 2 +- crates/neo-memory/src/riscv/ccs/witness.rs | 6 ------ crates/neo-memory/src/riscv/lookups/cpu.rs | 1 + 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index 5e23e3a3..ecd1389c 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -18,6 +18,7 @@ use crate::PiCcsError; use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; use neo_ccs::{CcsStructure, Mat, MeInstance}; use neo_math::{F, K}; +use neo_reductions::engines::optimized_engine::oracle::SparseCache; use neo_memory::mem_init_from_initial_mem; use neo_memory::output_check::ProgramIO; use 
neo_memory::plain::LutTable; @@ -778,10 +779,10 @@ impl Rv32B1 { // Always enforce step-to-step chunk chaining for RV32 B1. session.set_step_linking(rv32_b1_step_linking_config(&layout)); - let t_ccs_and_shared_bus = elapsed_duration(phase_start); + let _t_ccs_and_shared_bus = elapsed_duration(phase_start); // === Phase 3: VM execution + shard collection === - let phase_start = time_now(); + let _phase_start = time_now(); // Execute + collect step bundles (and aux for output binding). let build_start = time_now(); @@ -811,7 +812,7 @@ impl Rv32B1 { } // Extract sub-phase timings from the shared-bus aux. - let (t_vm_trace, t_cpu_witness) = session + let (_t_vm_trace, _t_cpu_witness) = session .shared_bus_aux() .map(|aux| (aux.vm_trace_duration, aux.cpu_witness_duration)) .unwrap_or((Duration::ZERO, Duration::ZERO)); @@ -1594,11 +1595,6 @@ impl Rv32B1Run { self.prove_duration } - /// Per-phase timing breakdown for the prove call. - pub fn prove_timings(&self) -> &ProveTimings { - &self.prove_timings - } - pub fn verify_duration(&self) -> Option { self.verify_duration } diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index 3f9df1a2..d2dbdea3 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -404,7 +404,7 @@ where chunk_size, mem_ids, final_mem_states: mem_states, - vm_trace_duration, + vm_trace_duration: std::time::Duration::ZERO, cpu_witness_duration, }; Ok((step_bundles, aux)) diff --git a/crates/neo-memory/src/riscv/ccs/witness.rs b/crates/neo-memory/src/riscv/ccs/witness.rs index 1d7e2437..3c5890ab 100644 --- a/crates/neo-memory/src/riscv/ccs/witness.rs +++ b/crates/neo-memory/src/riscv/ccs/witness.rs @@ -1317,12 +1317,6 @@ fn rv32_b1_chunk_to_witness_internal( z[layout.br_not_taken(j)] = F::ONE - taken; } - // Poseidon2 read ECALL: set rd_write_val = ecall_rd_val so the existing - // enforce_u32_bits range check covers it. 
- if is_poseidon2_read_ecall { - z[layout.rd_write_val(j)] = z[layout.ecall_rd_val(j)]; - } - let mul_carry = if is_mulh { let rhs = (mul_hi as i128) - (rs1_sign as i128) * (rs2_u64 as i128) - (rs2_sign as i128) * (rs1_u64 as i128) diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index de9f76b8..624c63b1 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -7,6 +7,7 @@ use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; +use super::{POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. /// From d78286d5c18a0935f09d90b2b23b5abd5ecb9c83 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 13:39:41 -0300 Subject: [PATCH 08/13] Restore sovereign RV32 builder/verifier compatibility APIs --- crates/neo-fold/src/riscv_shard.rs | 479 ++++++++++++++++++++++++++++- 1 file changed, 478 insertions(+), 1 deletion(-) diff --git a/crates/neo-fold/src/riscv_shard.rs b/crates/neo-fold/src/riscv_shard.rs index ecd1389c..446133a5 100644 --- a/crates/neo-fold/src/riscv_shard.rs +++ b/crates/neo-fold/src/riscv_shard.rs @@ -371,6 +371,7 @@ pub struct Rv32B1 { output_target: OutputTarget, ram_init: HashMap, reg_init: HashMap, + ccs_cache: Option>, } /// Default instruction cap for RV32B1 runs when `max_steps` is not specified. @@ -416,6 +417,7 @@ impl Rv32B1 { output_target: OutputTarget::Ram, ram_init: HashMap::new(), reg_init: HashMap::new(), + ccs_cache: None, } } @@ -532,6 +534,406 @@ impl Rv32B1 { self } + /// Attach a pre-built CCS cache to skip CCS synthesis in [`prove`] and [`build_verifier`]. 
+ pub fn with_ccs_cache(mut self, cache: Arc) -> Self { + self.ccs_cache = Some(cache); + self + } + + /// Build the CCS preprocessing cache from the current builder configuration. + pub fn build_ccs_cache(&self) -> Result { + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 MVP requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if !self.chunk_size_auto && self.chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + if self.ram_bytes == 0 { + return Err(PiCcsError::InvalidInput("ram_bytes must be non-zero".into())); + } + if self.program_base != 0 { + return Err(PiCcsError::InvalidInput( + "RV32 B1 MVP requires program_base == 0 (addresses are indices into PROG/RAM layouts)".into(), + )); + } + + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let estimated_steps = match self.max_steps { + Some(n) => { + if n == 0 { + return Err(PiCcsError::InvalidInput("max_steps must be non-zero".into())); + } + n + } + None => program.len().max(1), + }; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words::( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + let mut initial_mem = initial_mem; + for (&addr, &value) in &self.ram_init { + let value = value as u32 as u64; + initial_mem.insert((neo_memory::riscv::lookups::RAM_ID.0, addr), F::from_u64(value)); + } + for (®, &value) in &self.reg_init { + let value = value as u32 as u64; + initial_mem.insert((neo_memory::riscv::lookups::REG_ID.0, reg), F::from_u64(value)); + } + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); + let mem_layouts = HashMap::from([ + ( + 
neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + ( + neo_memory::riscv::lookups::REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + let shout = RiscvShoutTables::new(self.xlen); + let inferred_shout_ops = infer_required_shout_opcodes(&program); + let mut shout_ops = match &self.shout_ops { + Some(ops) => { + let missing: HashSet = inferred_shout_ops.difference(ops).copied().collect(); + if !missing.is_empty() { + let mut missing_names: Vec = missing.into_iter().map(|op| format!("{op:?}")).collect(); + missing_names.sort_unstable(); + return Err(PiCcsError::InvalidInput(format!( + "shout_ops override must be a superset of required opcodes; missing [{}]", + missing_names.join(", ") + ))); + } + ops.clone() + } + None if self.shout_auto_minimal => inferred_shout_ops, + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let mut table_specs: HashMap = HashMap::new(); + for op in shout_ops { + let table_id = shout.opcode_to_id(op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + let chunk_size = if self.chunk_size_auto { + choose_rv32_b1_chunk_size(&mem_layouts, &shout_table_ids, estimated_steps) + .map_err(|e| PiCcsError::InvalidInput(format!("auto chunk_size failed: {e}")))? 
+ } else { + self.chunk_size + }; + if chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + + let (ccs_base, layout) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + + let session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + let empty_tables: HashMap> = HashMap::new(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))?; + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + let full_ccs = cpu.ccs.clone(); + let sparse = Arc::new(SparseCache::build(&full_ccs)); + let ccs_mat_digest = neo_reductions::engines::utils::digest_ccs_matrices(&full_ccs); + + Ok(Rv32B1CcsCache { + sparse, + full_ccs: Some(full_ccs), + layout: Some(layout), + params: Some(session.params().clone()), + committer: Some(session.committer().clone()), + ccs_mat_digest: Some(ccs_mat_digest), + }) + } + + /// Build only the verification context (CCS + session) without executing the program or proving. 
+ pub fn build_verifier(self) -> Result { + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 MVP requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if !self.chunk_size_auto && self.chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + if self.ram_bytes == 0 { + return Err(PiCcsError::InvalidInput("ram_bytes must be non-zero".into())); + } + if self.program_base != 0 { + return Err(PiCcsError::InvalidInput( + "RV32 B1 MVP requires program_base == 0 (addresses are indices into PROG/RAM layouts)".into(), + )); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RV32 B1 runner does not support RVC)".into(), + )); + } + for (i, chunk) in self.program_bytes.chunks_exact(4).enumerate() { + let first_half = u16::from_le_bytes([chunk[0], chunk[1]]); + if (first_half & 0b11) != 0b11 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 B1 runner does not support compressed instructions (RVC): found compressed encoding at word index {i}" + ))); + } + } + + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let estimated_steps = match self.max_steps { + Some(n) => { + if n == 0 { + return Err(PiCcsError::InvalidInput("max_steps must be non-zero".into())); + } + n + } + None => program.len().max(1), + }; + + let (prog_layout, initial_mem) = neo_memory::riscv::rom_init::prog_rom_layout_and_init_words( + PROG_ID, + /*base_addr=*/ 0, + &self.program_bytes, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + let mut initial_mem = initial_mem; + for (&addr, &value) in &self.ram_init { + let value = value as u32 as u64; + 
initial_mem.insert((neo_memory::riscv::lookups::RAM_ID.0, addr), F::from_u64(value)); + } + for (®, &value) in &self.reg_init { + let value = value as u32 as u64; + initial_mem.insert((neo_memory::riscv::lookups::REG_ID.0, reg), F::from_u64(value)); + } + + let (k_ram, d_ram) = pow2_ceil_k(self.ram_bytes.max(4)); + let mem_layouts = HashMap::from([ + ( + neo_memory::riscv::lookups::RAM_ID.0, + PlainMemLayout { + k: k_ram, + d: d_ram, + n_side: 2, + lanes: 1, + }, + ), + ( + neo_memory::riscv::lookups::REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + (PROG_ID.0, prog_layout), + ]); + + let shout = RiscvShoutTables::new(self.xlen); + let inferred_shout_ops = infer_required_shout_opcodes(&program); + let mut shout_ops = match &self.shout_ops { + Some(ops) => { + let missing: HashSet = inferred_shout_ops.difference(ops).copied().collect(); + if !missing.is_empty() { + let mut missing_names: Vec = missing.into_iter().map(|op| format!("{op:?}")).collect(); + missing_names.sort_unstable(); + return Err(PiCcsError::InvalidInput(format!( + "shout_ops override must be a superset of required opcodes; missing [{}]", + missing_names.join(", ") + ))); + } + ops.clone() + } + None if self.shout_auto_minimal => inferred_shout_ops, + None => all_shout_opcodes(), + }; + shout_ops.insert(RiscvOpcode::Add); + + let mut table_specs: HashMap = HashMap::new(); + for op in shout_ops { + let table_id = shout.opcode_to_id(op).0; + table_specs.insert( + table_id, + LutTableSpec::RiscvOpcode { + opcode: op, + xlen: self.xlen, + }, + ); + } + let mut shout_table_ids: Vec = table_specs.keys().copied().collect(); + shout_table_ids.sort_unstable(); + + let chunk_size = if self.chunk_size_auto { + choose_rv32_b1_chunk_size(&mem_layouts, &shout_table_ids, estimated_steps) + .map_err(|e| PiCcsError::InvalidInput(format!("auto chunk_size failed: {e}")))? 
+ } else { + self.chunk_size + }; + if chunk_size == 0 { + return Err(PiCcsError::InvalidInput("chunk_size must be non-zero".into())); + } + + let empty_tables: HashMap> = HashMap::new(); + let has_full_ccs = self + .ccs_cache + .as_ref() + .map(|c| c.full_ccs.is_some() && c.layout.is_some()) + .unwrap_or(false); + + let mut session: FoldingSession; + let ccs: CcsStructure; + let layout: Rv32B1Layout; + + if has_full_ccs { + let cache = self.ccs_cache.as_ref().expect("ccs_cache checked above"); + let cached_ccs = cache.full_ccs.as_ref().expect("cache has full_ccs").clone(); + layout = cache.layout.as_ref().expect("cache has layout").clone(); + if let (Some(cached_params), Some(cached_committer)) = (&cache.params, &cache.committer) { + session = FoldingSession::new(self.mode.clone(), cached_params.clone(), cached_committer.clone()); + session.commit_m = Some(cached_ccs.m); + } else { + session = FoldingSession::::new_ajtai(self.mode.clone(), &cached_ccs)?; + } + ccs = cached_ccs; + } else { + let (ccs_base, layout_built) = build_rv32_b1_step_ccs(&mem_layouts, &shout_table_ids, chunk_size) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_b1_step_ccs failed: {e}")))?; + layout = layout_built; + + session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + let params = session.params().clone(); + let committer = session.committer().clone(); + + let mut cpu = R1csCpu::new( + ccs_base, + params, + committer, + layout.m_in, + &empty_tables, + &table_specs, + rv32_b1_chunk_to_witness(layout.clone()), + ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))?; + cpu = cpu + .with_shared_cpu_bus( + rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts.clone(), initial_mem.clone()) + .map_err(|e| PiCcsError::InvalidInput(format!("rv32_b1_shared_cpu_bus_config failed: {e}")))?, + chunk_size, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + ccs = cpu.ccs.clone(); + } + + 
session.set_step_linking(rv32_b1_step_linking_config(&layout)); + + let output_binding_cfg = if self.output_claims.is_empty() { + None + } else { + let out_mem_id = match self.output_target { + OutputTarget::Ram => neo_memory::riscv::lookups::RAM_ID.0, + OutputTarget::Reg => neo_memory::riscv::lookups::REG_ID.0, + }; + let out_layout = mem_layouts.get(&out_mem_id).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "output binding: missing PlainMemLayout for mem_id={out_mem_id}" + )) + })?; + let expected_k = 1usize + .checked_shl(out_layout.d as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^d overflow".into()))?; + if out_layout.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: mem_id={out_mem_id} has k={}, but expected 2^d={} (d={})", + out_layout.k, expected_k, out_layout.d + ))); + } + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + let mem_idx = mem_ids + .iter() + .position(|&id| id == out_mem_id) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: mem_id not in mem_layouts".into()))?; + Some(OutputBindingConfig::new(out_layout.d, self.output_claims.clone()).with_mem_idx(mem_idx)) + }; + + let mut verifier = Rv32B1Verifier { + session, + ccs, + _layout: layout, + mem_layouts, + statement_initial_mem: initial_mem, + output_binding_cfg, + }; + + if let Some(cache) = self.ccs_cache { + if let Some(digest) = cache.ccs_mat_digest.clone() { + verifier.preload_sparse_cache_with_digest(cache.sparse.clone(), digest)?; + } else { + verifier.preload_sparse_cache(cache.sparse.clone())?; + } + } + + Ok(verifier) + } + /// Prove/verify only the Tier 2.1 trace-wiring CCS (time-in-rows). /// /// `chunk_size`, `chunk_size_auto`, `ram_bytes`, and Shout-table selection knobs are ignored @@ -1054,6 +1456,58 @@ impl Rv32B1 { } } +/// Verification context for RV32 B1 proofs. +/// +/// Created by [`Rv32B1::build_verifier`]. 
Contains CCS/session state required to +/// verify proofs without re-running the guest. +pub struct Rv32B1Verifier { + session: FoldingSession, + ccs: CcsStructure, + _layout: Rv32B1Layout, + mem_layouts: HashMap, + statement_initial_mem: HashMap<(u32, u64), F>, + output_binding_cfg: Option, +} + +impl Rv32B1Verifier { + /// Preload the verifier SparseCache using `&self.ccs` (final pointer). + pub fn preload_sparse_cache(&mut self, sparse: Arc>) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache(&self.ccs, sparse) + } + + /// Preload verifier SparseCache with a precomputed matrix digest. + pub fn preload_sparse_cache_with_digest( + &mut self, + sparse: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + self.session + .preload_verifier_ccs_sparse_cache_with_digest(&self.ccs, sparse, ccs_mat_digest) + } + + /// Verify a main `ShardProof` against externally supplied public steps. + pub fn verify( + &self, + proof: &ShardProof, + steps_public: &[StepInstanceBundle], + ) -> Result { + rv32_b1_enforce_chunk0_mem_init_matches_statement( + &self.mem_layouts, + &self.statement_initial_mem, + steps_public, + )?; + + if let Some(ob_cfg) = &self.output_binding_cfg { + self.session + .verify_with_external_steps_and_output_binding(&self.ccs, steps_public, proof, ob_cfg) + } else { + self.session + .verify_with_external_steps(&self.ccs, steps_public, proof) + } + } +} + #[derive(Clone, Debug)] pub struct PiCcsProofBundle { pub num_steps: usize, @@ -1407,7 +1861,13 @@ impl Rv32B1Run { Ok(()) } - pub fn proof(&self) -> &Rv32B1ProofBundle { + /// Main folding proof (without sidecar proofs). + pub fn proof(&self) -> &ShardProof { + &self.proof_bundle.main + } + + /// Full RV32 B1 proof bundle (main + sidecars). + pub fn proof_bundle(&self) -> &Rv32B1ProofBundle { &self.proof_bundle } @@ -1595,6 +2055,23 @@ impl Rv32B1Run { self.prove_duration } + /// Compatibility timing breakdown used by downstream adapters. 
+ pub fn prove_timings(&self) -> ProveTimings { + let (vm_trace, cpu_witness) = self + .session + .shared_bus_aux() + .map(|aux| (aux.vm_trace_duration, aux.cpu_witness_duration)) + .unwrap_or((Duration::ZERO, Duration::ZERO)); + ProveTimings { + decode_and_setup: self.prove_phase_durations.setup, + ccs_and_shared_bus: Duration::ZERO, + vm_execution: self.prove_phase_durations.build_commit, + fold_and_prove: self.prove_phase_durations.fold_and_prove, + vm_trace, + cpu_witness, + } + } + pub fn verify_duration(&self) -> Option { self.verify_duration } From cca6616a2702091e67e271aca2b8dc21bb4b798b Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 17 Feb 2026 15:45:45 -0300 Subject: [PATCH 09/13] Refactor proof handling in RISC-V tests and memory sidecar - Updated proof retrieval in RISC-V tests to use `proof_bundle()` instead of `proof()`, ensuring consistency across test cases. - Modified memory sidecar functions to utilize `rv32_trace_lookup_addr_group_for_table_shape` for address group lookups, enhancing shape awareness in shared-bus scenarios. - Introduced new helper functions for checking memory ID instances and shout specifications in the context of RV32 trace, improving clarity and maintainability of the code. 
--- ...cv_fibonacci_compiled_full_prove_verify.rs | 2 +- .../test_riscv_program_full_prove_verify.rs | 4 +- .../neo-fold/src/memory_sidecar/claim_plan.rs | 4 +- crates/neo-fold/src/memory_sidecar/cpu_bus.rs | 8 ++- crates/neo-fold/src/memory_sidecar/memory.rs | 68 ++++++++++++++++++- .../perf/single_addi_metrics_nightstream.rs | 14 ++-- .../suites/redteam/riscv_verifier_gaps.rs | 8 +-- .../riscv_bus_binding_redteam.rs | 9 ++- .../redteam_riscv/riscv_main_proof_redteam.rs | 12 ++-- .../riscv_twist_shout_redteam.rs | 18 ++--- .../rv32m/rv32m_sidecar_sparse_steps.rs | 6 +- crates/neo-memory/src/cpu/constraints.rs | 4 +- crates/neo-memory/src/cpu/r1cs_adapter.rs | 4 +- .../neo-memory/src/riscv/ccs/bus_bindings.rs | 13 ++-- crates/neo-memory/src/riscv/ccs/layout.rs | 17 ++++- crates/neo-memory/src/riscv/ccs/witness.rs | 10 +-- crates/neo-memory/src/riscv/trace/mod.rs | 14 ++++ crates/neo-memory/tests/riscv_ccs_tests.rs | 7 +- 18 files changed, 160 insertions(+), 62 deletions(-) diff --git a/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs index ccd0c73d..b5ec38d0 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_fibonacci_compiled_full_prove_verify.rs @@ -133,7 +133,7 @@ fn test_riscv_fibonacci_compiled_full_prove_verify() { // Print proof size estimate { - let proof = run.proof(); + let proof = run.proof_bundle(); let num_steps = proof.main.steps.len(); // Each MeInstance has exactly one commitment let num_commitments: usize = proof diff --git a/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs b/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs index 444c2449..06f46295 100644 --- a/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs +++ b/crates/neo-fold/riscv-tests/test_riscv_program_full_prove_verify.rs @@ -70,7 +70,7 @@ 
fn test_riscv_program_full_prove_verify() { // Ensure the Shout addr-pre proof skips inactive tables. // This program uses only the ADD lookup; LUI and HALT use no Shout lookups and should skip entirely. - let proof = run.proof(); + let proof = run.proof_bundle(); let mut saw_skipped = false; let mut saw_add_only = false; for step in &proof.main.steps { @@ -420,7 +420,7 @@ fn test_riscv_program_rv32m_full_prove_verify() { assert_eq!(rv32m_chunks, vec![2, 3], "expected RV32M rows on the MUL/DIV chunks"); let rv32m = run - .proof() + .proof_bundle() .rv32m .as_ref() .expect("expected RV32M sidecar proofs"); diff --git a/crates/neo-fold/src/memory_sidecar/claim_plan.rs b/crates/neo-fold/src/memory_sidecar/claim_plan.rs index 58188ab6..41ed7528 100644 --- a/crates/neo-fold/src/memory_sidecar/claim_plan.rs +++ b/crates/neo-fold/src/memory_sidecar/claim_plan.rs @@ -1,7 +1,7 @@ use neo_ajtai::Commitment as Cmt; use neo_math::{F, K}; use neo_memory::riscv::lookups::RiscvOpcode; -use neo_memory::riscv::trace::rv32_trace_lookup_addr_group_for_table_id; +use neo_memory::riscv::trace::rv32_trace_lookup_addr_group_for_table_shape; use neo_memory::witness::{LutInstance, LutTableSpec, MemInstance, StepInstanceBundle}; use crate::PiCcsError; @@ -104,7 +104,7 @@ impl RouteATimeClaimPlan { for (inst_idx, lut_inst) in lut_insts.iter().enumerate() { let lanes = lut_inst.lanes.max(1); let ell_addr = lut_inst.d * lut_inst.ell; - let addr_group = rv32_trace_lookup_addr_group_for_table_id(lut_inst.table_id); + let addr_group = rv32_trace_lookup_addr_group_for_table_shape(lut_inst.table_id, ell_addr); let is_packed = matches!( lut_inst.table_spec, Some(LutTableSpec::RiscvOpcodePacked { .. } | LutTableSpec::RiscvOpcodeEventTablePacked { .. 
}) diff --git a/crates/neo-fold/src/memory_sidecar/cpu_bus.rs b/crates/neo-fold/src/memory_sidecar/cpu_bus.rs index fcbaa0bc..2b13d2c5 100644 --- a/crates/neo-fold/src/memory_sidecar/cpu_bus.rs +++ b/crates/neo-fold/src/memory_sidecar/cpu_bus.rs @@ -8,7 +8,7 @@ use neo_memory::cpu::{ }; use neo_memory::riscv::lookups::{PROG_ID, REG_ID}; use neo_memory::riscv::trace::{ - rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, rv32_trace_lookup_addr_group_for_table_id, + rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id, }; use neo_memory::sparse_time::SparseIdxVec; @@ -127,7 +127,8 @@ fn infer_bus_layout_for_steps>( .collect(); let base_shout_addr_groups: Vec> = (0..steps[0].lut_insts_len()) .map(|i| { - rv32_trace_lookup_addr_group_for_table_id(steps[0].lut_inst(i).table_id).map(|v| v as u64) + let inst = steps[0].lut_inst(i); + rv32_trace_lookup_addr_group_for_table_shape(inst.table_id, inst.d * inst.ell).map(|v| v as u64) }) .collect(); let base_shout_selector_groups: Vec> = (0..steps[0].lut_insts_len()) @@ -161,7 +162,8 @@ fn infer_bus_layout_for_steps>( .collect(); let cur_shout_addr_groups: Vec> = (0..step.lut_insts_len()) .map(|j| { - rv32_trace_lookup_addr_group_for_table_id(step.lut_inst(j).table_id).map(|v| v as u64) + let inst = step.lut_inst(j); + rv32_trace_lookup_addr_group_for_table_shape(inst.table_id, inst.d * inst.ell).map(|v| v as u64) }) .collect(); let cur_shout_selector_groups: Vec> = (0..step.lut_insts_len()) diff --git a/crates/neo-fold/src/memory_sidecar/memory.rs b/crates/neo-fold/src/memory_sidecar/memory.rs index 1bb8236e..652da0df 100644 --- a/crates/neo-fold/src/memory_sidecar/memory.rs +++ b/crates/neo-fold/src/memory_sidecar/memory.rs @@ -22,7 +22,7 @@ use neo_memory::riscv::lookups::{PROG_ID, RAM_ID, REG_ID}; use neo_memory::riscv::shout_oracle::RiscvAddressLookupOracleSparse; use neo_memory::riscv::trace::{ 
rv32_decode_lookup_backed_cols, rv32_decode_lookup_backed_row_from_instr_word, rv32_decode_lookup_table_id_for_col, - rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, rv32_trace_lookup_addr_group_for_table_id, + rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id, rv32_width_lookup_backed_cols, rv32_width_lookup_table_id_for_col, Rv32DecodeSidecarLayout, Rv32TraceLayout, Rv32WidthSidecarLayout, @@ -5511,16 +5511,80 @@ fn has_trace_lookup_families_witness(step: &StepWitnessBundle) -> boo }) } +#[inline] +fn has_rv32_trace_twist_mem_ids_instance(step: &StepInstanceBundle) -> bool { + if step.mem_insts.is_empty() { + return false; + } + let mut ids = std::collections::BTreeSet::::new(); + for inst in step.mem_insts.iter() { + ids.insert(inst.mem_id); + } + let required = std::collections::BTreeSet::from([PROG_ID.0, REG_ID.0]); + let allowed = std::collections::BTreeSet::from([PROG_ID.0, REG_ID.0, RAM_ID.0]); + required.is_subset(&ids) && ids.is_subset(&allowed) +} + +#[inline] +fn has_rv32_trace_twist_mem_ids_witness(step: &StepWitnessBundle) -> bool { + if step.mem_instances.is_empty() { + return false; + } + let mut ids = std::collections::BTreeSet::::new(); + for (inst, _) in step.mem_instances.iter() { + ids.insert(inst.mem_id); + } + let required = std::collections::BTreeSet::from([PROG_ID.0, REG_ID.0]); + let allowed = std::collections::BTreeSet::from([PROG_ID.0, REG_ID.0, RAM_ID.0]); + required.is_subset(&ids) && ids.is_subset(&allowed) +} + +#[inline] +fn has_rv32_trace_shout_specs_instance(step: &StepInstanceBundle) -> bool { + step.lut_insts.iter().any(|inst| { + matches!( + inst.table_spec, + Some( + LutTableSpec::RiscvOpcode { xlen: 32, .. } + | LutTableSpec::RiscvOpcodePacked { xlen: 32, .. } + | LutTableSpec::RiscvOpcodeEventTablePacked { xlen: 32, .. 
} + ) + ) + }) +} + +#[inline] +fn has_rv32_trace_shout_specs_witness(step: &StepWitnessBundle) -> bool { + step.lut_instances.iter().any(|(inst, _)| { + matches!( + inst.table_spec, + Some( + LutTableSpec::RiscvOpcode { xlen: 32, .. } + | LutTableSpec::RiscvOpcodePacked { xlen: 32, .. } + | LutTableSpec::RiscvOpcodeEventTablePacked { xlen: 32, .. } + ) + ) + }) +} + #[inline] pub(crate) fn wb_wp_required_for_step_instance(step: &StepInstanceBundle) -> bool { // Stage gating is keyed by lookup-family presence instead of fixed `m_in`/`mem_id` // assumptions so adapter-side routing can evolve without hardcoding RV32 shapes here. + // + // No-shared RV32 trace mode also needs CPU trace openings for: + // - packed/event-table Shout linkage (spec-tagged), and + // - Twist(PROG/REG[/RAM]) linkage (mem-id keyed). has_trace_lookup_families_instance(step) + || (step.mcs_inst.m_in == 5 + && (has_rv32_trace_shout_specs_instance(step) || has_rv32_trace_twist_mem_ids_instance(step))) } #[inline] pub(crate) fn wb_wp_required_for_step_witness(step: &StepWitnessBundle) -> bool { has_trace_lookup_families_witness(step) + || (step.mcs.0.m_in == 5 + && (has_rv32_trace_shout_specs_witness(step) || has_rv32_trace_twist_mem_ids_witness(step))) } pub(crate) fn build_bus_layout_for_step_witness( @@ -5536,7 +5600,7 @@ pub(crate) fn build_bus_layout_for_step_witness( ell_addr: inst.d * inst.ell, lanes: inst.lanes.max(1), n_vals: 1usize, - addr_group: rv32_trace_lookup_addr_group_for_table_id(inst.table_id).map(|v| v as u64), + addr_group: rv32_trace_lookup_addr_group_for_table_shape(inst.table_id, inst.d * inst.ell).map(|v| v as u64), selector_group: rv32_trace_lookup_selector_group_for_table_id(inst.table_id).map(|v| v as u64), }) .collect(); diff --git a/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs index 7292ea5b..fc8e4840 100644 --- 
a/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs +++ b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs @@ -176,17 +176,17 @@ fn opening_surface_from_shard_proof(proof: &ShardProof) -> OpeningSurfaceBuckets } fn opening_surface_from_rv32_b1_run(run: &neo_fold::riscv_shard::Rv32B1Run) -> OpeningSurfaceBuckets { - let mut buckets = opening_surface_from_shard_proof(&run.proof().main); - buckets.sidecars += sum_y_scalars(&run.proof().decode_plumbing.me_out); - buckets.sidecars += sum_y_scalars(&run.proof().semantics.me_out); - if let Some(rv32m) = &run.proof().rv32m { + let mut buckets = opening_surface_from_shard_proof(run.proof()); + buckets.sidecars += sum_y_scalars(&run.proof_bundle().decode_plumbing.me_out); + buckets.sidecars += sum_y_scalars(&run.proof_bundle().semantics.me_out); + if let Some(rv32m) = &run.proof_bundle().rv32m { for chunk in rv32m { buckets.sidecars += sum_y_scalars(&chunk.me_out); buckets.pcs_open += chunk.me_out.len(); } } - buckets.pcs_open += run.proof().decode_plumbing.me_out.len(); - buckets.pcs_open += run.proof().semantics.me_out.len(); + buckets.pcs_open += run.proof_bundle().decode_plumbing.me_out.len(); + buckets.pcs_open += run.proof_bundle().semantics.me_out.len(); buckets } @@ -502,7 +502,7 @@ fn report_track_a_w0_w1_snapshot() { let mut other_claims = Vec::new(); for i in 0..bt.labels.len() { - let label = std::str::from_utf8(bt.labels[i]).unwrap_or(""); + let label = std::str::from_utf8(&bt.labels[i]).unwrap_or(""); let deg = bt.degree_bounds[i]; let entry = (label.to_string(), deg); if label.starts_with("ccs/") { diff --git a/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs b/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs index e9823b35..cdce8588 100644 --- a/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs +++ b/crates/neo-fold/tests/suites/redteam/riscv_verifier_gaps.rs @@ -154,7 +154,7 @@ fn swap_decode_plumbing_for_trivial_ccs(run: &Rv32B1Run, 
bundle: &mut Rv32B1Proo fn redteam_output_claim_path_should_not_accept_without_sidecar_enforcement() { let run = prove_output_run(); - let mut bad_bundle = run.proof().clone(); + let mut bad_bundle = run.proof_bundle().clone(); bad_bundle.semantics.me_out.clear(); assert!( run.verify_proof_bundle(&bad_bundle).is_err(), @@ -172,7 +172,7 @@ fn redteam_output_claim_path_should_not_accept_without_sidecar_enforcement() { fn redteam_output_claim_variants_should_not_accept_without_sidecar_enforcement() { let run = prove_output_run(); - let mut bad_bundle = run.proof().clone(); + let mut bad_bundle = run.proof_bundle().clone(); bad_bundle.semantics.me_out.clear(); assert!( run.verify_proof_bundle(&bad_bundle).is_err(), @@ -198,7 +198,7 @@ fn redteam_verifier_should_reject_prover_selected_decode_ccs() { let mut run = prove_basic_run(); run.verify().expect("baseline verify"); - let mut bad_bundle = run.proof().clone(); + let mut bad_bundle = run.proof_bundle().clone(); swap_decode_plumbing_for_trivial_ccs(&run, &mut bad_bundle); assert!( @@ -212,7 +212,7 @@ fn redteam_legacy_main_only_verifier_should_not_accept_without_sidecars() { let mut run = prove_basic_run(); run.verify().expect("baseline verify"); - let mut bad_bundle = run.proof().clone(); + let mut bad_bundle = run.proof_bundle().clone(); bad_bundle.semantics.me_out.clear(); assert!( run.verify_proof_bundle(&bad_bundle).is_err(), diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs index dd2f2df3..5213675f 100644 --- a/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_bus_binding_redteam.rs @@ -13,6 +13,9 @@ fn prove_run(program: Vec, max_steps: usize) -> Rv32B1Run { .chunk_size(1) .max_steps(max_steps) .ram_bytes(0x200) + // Keep this fixture explicit: these tests rely on XOR lookups in tiny programs, + // and we don't want them coupled 
to shout auto-inference details. + .shout_ops([RiscvOpcode::Add, RiscvOpcode::Xor]) .prove() .expect("prove"); run.verify().expect("baseline verify"); @@ -66,14 +69,14 @@ fn rv32_b1_cpu_vs_bus_twist_rv_mismatch_must_fail() { #[test] fn rv32_b1_cpu_vs_bus_shout_val_mismatch_must_fail() { - // Program: XORI x1, x0, 1; HALT (forces a Shout XOR lookup). + // Program: XOR x1, x0, x0; HALT (forces a Shout XOR lookup). let run = prove_run( vec![ - RiscvInstruction::IAlu { + RiscvInstruction::RAlu { op: RiscvOpcode::Xor, rd: 1, rs1: 0, - imm: 1, + rs2: 0, }, RiscvInstruction::Halt, ], diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs index 8ffa9842..6d89f869 100644 --- a/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_main_proof_redteam.rs @@ -56,7 +56,7 @@ fn rv32_b1_main_proof_truncated_steps_must_fail() { let sess_ok = verifier_only_session_for_steps(&run, steps_ok); assert_eq!( sess_ok - .verify_collected(run.ccs(), &run.proof().main) + .verify_collected(run.ccs(), run.proof()) .expect("main proof verify"), true ); @@ -64,7 +64,7 @@ fn rv32_b1_main_proof_truncated_steps_must_fail() { // Truncate steps (verifier-side) and reuse the original proof. 
let steps_bad: Vec = run.steps_witness().iter().cloned().take(1).collect(); let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!(matches!(res, Err(_) | Ok(false)), "truncated steps must not verify"); } @@ -86,7 +86,7 @@ fn rv32_b1_main_proof_tamper_prog_init_must_fail() { steps_bad[0].mem_instances[prog_idx].0.init = MemInit::Zero; let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "tampering PROG Twist init in public input must fail verification" @@ -113,7 +113,7 @@ fn rv32_b1_main_proof_tamper_reg_init_must_fail() { steps_bad[0].mem_instances[reg_idx].0.init = MemInit::Zero; let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "tampering REG Twist init in public input must fail verification" @@ -136,7 +136,7 @@ fn rv32_b1_main_proof_step_reordering_must_fail() { steps_bad.swap(0, 1); let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "reordering shard steps must not verify" @@ -166,7 +166,7 @@ fn rv32_b1_main_proof_splicing_across_runs_must_fail() { // Attempt to verify run A's main proof against run B's public step bundles. 
let steps_bad: Vec = run_b.steps_witness().to_vec(); let sess_bad = verifier_only_session_for_steps(&run_a, steps_bad); - let res = sess_bad.verify_collected(run_a.ccs(), &run_a.proof().main); + let res = sess_bad.verify_collected(run_a.ccs(), run_a.proof()); assert!( matches!(res, Err(_) | Ok(false)), "splicing main proof across runs must not verify" diff --git a/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs b/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs index 4a9ed794..b84a3bc2 100644 --- a/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs +++ b/crates/neo-fold/tests/suites/redteam_riscv/riscv_twist_shout_redteam.rs @@ -43,6 +43,7 @@ fn rv32_b1_twist_instances_reordered_must_fail() { .chunk_size(1) .max_steps(2) .ram_bytes(0x200) + .shout_ops([RiscvOpcode::Add]) .prove() .expect("prove"); run.verify().expect("baseline verify"); @@ -54,7 +55,7 @@ fn rv32_b1_twist_instances_reordered_must_fail() { } let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "reordering Twist instances must not verify" @@ -78,6 +79,7 @@ fn rv32_b1_shout_table_spec_tamper_must_fail() { .chunk_size(1) .max_steps(2) .ram_bytes(0x200) + .shout_ops([RiscvOpcode::Add]) .prove() .expect("prove"); run.verify().expect("baseline verify"); @@ -97,7 +99,7 @@ fn rv32_b1_shout_table_spec_tamper_must_fail() { } let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "tampering Shout table_spec must not verify" @@ -106,13 +108,13 @@ fn rv32_b1_shout_table_spec_tamper_must_fail() { #[test] fn rv32_b1_shout_instances_reordered_must_fail() { - // Ensure we have at 
least two Shout tables by including XORI (plus implicit ADD table). + // Ensure we have at least two Shout tables by including XOR (plus implicit ADD table). let program = vec![ - RiscvInstruction::IAlu { + RiscvInstruction::RAlu { op: RiscvOpcode::Xor, rd: 1, rs1: 0, - imm: 1, + rs2: 0, }, RiscvInstruction::Halt, ]; @@ -130,13 +132,13 @@ fn rv32_b1_shout_instances_reordered_must_fail() { for step in &mut steps_bad { assert!( step.lut_instances.len() >= 2, - "expected at least 2 Shout instances for XORI program" + "expected at least 2 Shout instances for XOR program" ); step.lut_instances.swap(0, 1); } let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "reordering Shout instances must not verify" @@ -175,7 +177,7 @@ fn rv32_b1_ram_init_statement_tamper_must_fail() { steps_bad[0].mem_instances[ram_idx].0.init = MemInit::Zero; let sess_bad = verifier_only_session_for_steps(&run, steps_bad); - let res = sess_bad.verify_collected(run.ccs(), &run.proof().main); + let res = sess_bad.verify_collected(run.ccs(), run.proof()); assert!( matches!(res, Err(_) | Ok(false)), "tampering RAM Twist init in public input must fail verification" diff --git a/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs b/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs index 112235b2..f07e1137 100644 --- a/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs +++ b/crates/neo-fold/tests/suites/rv32m/rv32m_sidecar_sparse_steps.rs @@ -23,7 +23,7 @@ fn rv32m_sidecar_is_skipped_for_non_m_programs() { run.verify().expect("verify"); assert!( - run.proof().rv32m.is_none(), + run.proof_bundle().rv32m.is_none(), "expected no RV32M sidecar proof for a non-M program" ); } @@ -51,7 +51,7 @@ fn rv32m_sidecar_is_sparse_over_time() { run.verify().expect("verify"); let rv32m = run - 
.proof() + .proof_bundle() .rv32m .as_ref() .expect("rv32m sidecar proof present"); @@ -108,7 +108,7 @@ fn rv32m_sidecar_selects_only_m_lanes_within_chunks() { run.verify().expect("verify"); let rv32m = run - .proof() + .proof_bundle() .rv32m .as_ref() .expect("rv32m sidecar proof present"); diff --git a/crates/neo-memory/src/cpu/constraints.rs b/crates/neo-memory/src/cpu/constraints.rs index 8d069484..6d260a8e 100644 --- a/crates/neo-memory/src/cpu/constraints.rs +++ b/crates/neo-memory/src/cpu/constraints.rs @@ -44,7 +44,7 @@ use crate::cpu::bus_layout::{ build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutCols, ShoutInstanceShape, TwistCols, }; -use crate::riscv::trace::{rv32_trace_lookup_addr_group_for_table_id, rv32_trace_lookup_selector_group_for_table_id}; +use crate::riscv::trace::{rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id}; use crate::witness::{LutInstance, MemInstance}; /// CPU column layout for binding to the bus. 
@@ -1054,7 +1054,7 @@ pub fn extend_ccs_with_shared_cpu_bus_constraints_optional_shout< ell_addr: inst.d * inst.ell, lanes: inst.lanes.max(1), n_vals: 1usize, - addr_group: rv32_trace_lookup_addr_group_for_table_id(inst.table_id).map(|v| v as u64), + addr_group: rv32_trace_lookup_addr_group_for_table_shape(inst.table_id, inst.d * inst.ell).map(|v| v as u64), selector_group: rv32_trace_lookup_selector_group_for_table_id(inst.table_id).map(|v| v as u64), }), mem_insts diff --git a/crates/neo-memory/src/cpu/r1cs_adapter.rs b/crates/neo-memory/src/cpu/r1cs_adapter.rs index 3675dc4c..d8ff86c4 100644 --- a/crates/neo-memory/src/cpu/r1cs_adapter.rs +++ b/crates/neo-memory/src/cpu/r1cs_adapter.rs @@ -17,7 +17,7 @@ use crate::mem_init::MemInit; use crate::plain::LutTable; use crate::plain::PlainMemLayout; use crate::riscv::trace::{ - rv32_trace_lookup_addr_group_for_table_id, rv32_trace_lookup_selector_group_for_table_id, + rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id, }; use crate::witness::{LutInstance, LutTableSpec, MemInstance}; use neo_ajtai::{decomp_b, DecompStyle}; @@ -197,7 +197,7 @@ where ell_addr, lanes, n_vals: 1usize, - addr_group: rv32_trace_lookup_addr_group_for_table_id(*table_id).map(|v| v as u64), + addr_group: rv32_trace_lookup_addr_group_for_table_shape(*table_id, ell_addr).map(|v| v as u64), selector_group: rv32_trace_lookup_selector_group_for_table_id(*table_id).map(|v| v as u64), }); } diff --git a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs index 86b3ba94..9e07fa15 100644 --- a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs +++ b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs @@ -9,7 +9,8 @@ use crate::plain::PlainMemLayout; use crate::riscv::lookups::{PROG_ID, RAM_ID, REG_ID}; use crate::riscv::trace::{ rv32_decode_lookup_table_id_for_col, rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, - rv32_trace_lookup_addr_group_for_table_id, 
rv32_trace_lookup_selector_group_for_table_id, Rv32DecodeSidecarLayout, + rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id, + Rv32DecodeSidecarLayout, }; use super::config::{derive_mem_ids_and_ell_addrs, derive_shout_ids_and_ell_addrs}; @@ -198,8 +199,8 @@ fn validate_trace_shout_table_id(table_id: u32) -> Result<(), String> { } #[inline] -fn trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { - rv32_trace_lookup_addr_group_for_table_id(table_id) +fn trace_lookup_addr_group_for_table_shape(table_id: u32, ell_addr: usize) -> Option { + rv32_trace_lookup_addr_group_for_table_shape(table_id, ell_addr) } #[inline] @@ -230,7 +231,7 @@ fn derive_trace_shout_shapes( table_id, ell_addr: 2 * RV32_XLEN, n_vals: 1usize, - addr_group: trace_lookup_addr_group_for_table_id(table_id), + addr_group: trace_lookup_addr_group_for_table_shape(table_id, 2 * RV32_XLEN), selector_group: trace_lookup_selector_group_for_table_id(table_id), }, ); @@ -262,7 +263,7 @@ fn derive_trace_shout_shapes( spec.table_id, prev.n_vals, spec.n_vals )); } - let inferred_group = trace_lookup_addr_group_for_table_id(spec.table_id); + let inferred_group = trace_lookup_addr_group_for_table_shape(spec.table_id, spec.ell_addr); if prev.addr_group != inferred_group { return Err(format!( "RV32 trace shared bus: conflicting addr_group for table_id={} (base/spec mismatch: {:?} vs {:?})", @@ -283,7 +284,7 @@ fn derive_trace_shout_shapes( table_id: spec.table_id, ell_addr: spec.ell_addr, n_vals: spec.n_vals, - addr_group: trace_lookup_addr_group_for_table_id(spec.table_id), + addr_group: trace_lookup_addr_group_for_table_shape(spec.table_id, spec.ell_addr), selector_group: trace_lookup_selector_group_for_table_id(spec.table_id), }, ); diff --git a/crates/neo-memory/src/riscv/ccs/layout.rs b/crates/neo-memory/src/riscv/ccs/layout.rs index 67304c35..695eb3a5 100644 --- a/crates/neo-memory/src/riscv/ccs/layout.rs +++ b/crates/neo-memory/src/riscv/ccs/layout.rs @@ 
-1,8 +1,9 @@ use std::collections::HashMap; -use crate::cpu::bus_layout::BusLayout; +use crate::cpu::bus_layout::{build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes, BusLayout, ShoutInstanceShape}; use crate::plain::PlainMemLayout; use crate::riscv::lookups::{PROG_ID, RAM_ID, REG_ID}; +use crate::riscv::trace::{rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id}; use super::config::{derive_mem_ids_and_ell_addrs, derive_shout_ids_and_ell_addrs}; @@ -1127,11 +1128,21 @@ pub(super) fn build_layout_with_m( (*ell_addr, lanes) }) .collect(); - let bus = crate::cpu::bus_layout::build_bus_layout_for_instances_with_twist_lanes( + let shout_shapes = table_ids + .iter() + .zip(shout_ell_addrs.iter()) + .map(|(table_id, ell_addr)| ShoutInstanceShape { + ell_addr: *ell_addr, + lanes: 1usize, + n_vals: 1usize, + addr_group: rv32_trace_lookup_addr_group_for_table_shape(*table_id, *ell_addr).map(|v| v as u64), + selector_group: rv32_trace_lookup_selector_group_for_table_id(*table_id).map(|v| v as u64), + }); + let bus = build_bus_layout_for_instances_with_shout_shapes_and_twist_lanes( m, m_in, chunk_size, - shout_ell_addrs, + shout_shapes, twist_ell_addrs_and_lanes, )?; if cpu_cols_used > bus.bus_base { diff --git a/crates/neo-memory/src/riscv/ccs/witness.rs b/crates/neo-memory/src/riscv/ccs/witness.rs index 3c5890ab..fe521b02 100644 --- a/crates/neo-memory/src/riscv/ccs/witness.rs +++ b/crates/neo-memory/src/riscv/ccs/witness.rs @@ -1223,11 +1223,11 @@ fn rv32_b1_chunk_to_witness_internal( let add_b0 = z[layout.bus.bus_cell(add_lane.addr_bits.start + 1, j)]; z[layout.add_a0b0(j)] = add_a0 * add_b0; } else if let Some(ev) = shout_ev { - if ev.shout_id.0 == ADD_TABLE_ID { - let a0 = if (ev.key & 1) == 1 { F::ONE } else { F::ZERO }; - let b0 = if ((ev.key >> 1) & 1) == 1 { F::ONE } else { F::ZERO }; - z[layout.add_a0b0(j)] = a0 * b0; - } + // In shared-bus mode ADD addr bits may be grouped with other Shout tables, so this + // 
helper must follow the active Shout key even when the opcode is not ADD. + let a0 = if (ev.key & 1) == 1 { F::ONE } else { F::ZERO }; + let b0 = if ((ev.key >> 1) & 1) == 1 { F::ONE } else { F::ZERO }; + z[layout.add_a0b0(j)] = a0 * b0; } let rs1_i32 = rs1_u32 as i32; diff --git a/crates/neo-memory/src/riscv/trace/mod.rs b/crates/neo-memory/src/riscv/trace/mod.rs index 45c00e55..e9c5c866 100644 --- a/crates/neo-memory/src/riscv/trace/mod.rs +++ b/crates/neo-memory/src/riscv/trace/mod.rs @@ -43,6 +43,20 @@ pub fn rv32_trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { } } +/// Shape-aware address-group hint for shared-bus Shout lanes. +/// +/// This guards against accidental grouping when callers use low numeric `table_id`s for +/// non-RV32 opcode tables (common in generic tests/fixtures). RV32 opcode tables (id 0..=19) +/// are grouped only when their key shape matches the canonical interleaved width. +#[inline] +pub fn rv32_trace_lookup_addr_group_for_table_shape(table_id: u32, ell_addr: usize) -> Option { + let group = rv32_trace_lookup_addr_group_for_table_id(table_id)?; + if table_id <= 19 && ell_addr != 64 { + return None; + } + Some(group) +} + #[inline] pub fn rv32_trace_lookup_selector_group_for_table_id(table_id: u32) -> Option { if rv32_is_decode_lookup_table_id(table_id) { diff --git a/crates/neo-memory/tests/riscv_ccs_tests.rs b/crates/neo-memory/tests/riscv_ccs_tests.rs index 0d37e81f..39d769ef 100644 --- a/crates/neo-memory/tests/riscv_ccs_tests.rs +++ b/crates/neo-memory/tests/riscv_ccs_tests.rs @@ -14,7 +14,7 @@ use neo_memory::riscv::ccs::{ }; use neo_memory::riscv::lookups::{ decode_instruction, encode_program, BranchCondition, RiscvCpu, RiscvInstruction, RiscvMemOp, RiscvMemory, - RiscvOpcode, RiscvShoutTables, PROG_ID, RAM_ID, REG_ID, + RiscvOpcode, RiscvShoutTables, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, PROG_ID, RAM_ID, REG_ID, }; use neo_memory::riscv::rom_init::prog_init_words; use neo_memory::witness::LutTableSpec; @@ 
-431,7 +431,7 @@ fn rv32_b1_ccs_happy_path_poseidon2_ecall() { let (k_prog, d_prog) = pow2_ceil_k(program_bytes.len()); let (k_ram, d_ram) = pow2_ceil_k(0x40); - let mem_layouts = HashMap::from([ + let mem_layouts = with_reg_layout(HashMap::from([ ( 0u32, PlainMemLayout { @@ -450,7 +450,7 @@ fn rv32_b1_ccs_happy_path_poseidon2_ecall() { lanes: 1, }, ), - ]); + ])); let initial_mem = prog_init_words(PROG_ID, 0, &program_bytes); @@ -469,6 +469,7 @@ fn rv32_b1_ccs_happy_path_poseidon2_ecall() { &table_specs, rv32_b1_chunk_to_witness(layout.clone()), ) + .expect("R1csCpu::new") .with_shared_cpu_bus( rv32_b1_shared_cpu_bus_config(&layout, &shout_table_ids, mem_layouts, initial_mem).expect("cfg"), 1, From 92a228fbb51c08fca1ee3390ab32bba294b40ae9 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Wed, 18 Feb 2026 15:42:15 -0300 Subject: [PATCH 10/13] Add Poseidon2 ECALL performance tests. --- crates/neo-fold/tests/suites/perf/mod.rs | 1 + .../suites/perf/riscv_poseidon2_ecall_perf.rs | 289 ++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs diff --git a/crates/neo-fold/tests/suites/perf/mod.rs b/crates/neo-fold/tests/suites/perf/mod.rs index 1bed6535..5d12d097 100644 --- a/crates/neo-fold/tests/suites/perf/mod.rs +++ b/crates/neo-fold/tests/suites/perf/mod.rs @@ -1,5 +1,6 @@ mod memory_adversarial_tests; mod prefix_scaling; mod riscv_b1_ab_perf; +mod riscv_poseidon2_ecall_perf; mod riscv_trace_wiring_output_binding_perf; mod single_addi_metrics_nightstream; diff --git a/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs b/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs new file mode 100644 index 00000000..e0561840 --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs @@ -0,0 +1,289 @@ +#![allow(non_snake_case)] + +use std::time::Duration; + +use neo_fold::riscv_shard::Rv32B1; +use neo_memory::riscv::lookups::{ + encode_program, 
RiscvInstruction, RiscvOpcode, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; + +fn load_u32_imm(rd: u8, value: u32) -> Vec<RiscvInstruction> { + let upper = ((value as i64 + 0x800) >> 12) as i32; + let lower = (value as i32) - (upper << 12); + vec![ + RiscvInstruction::Lui { rd, imm: upper }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd, + rs1: rd, + imm: lower, + }, + ] +} + +fn poseidon2_ecall_program(n_hashes: usize) -> Vec<RiscvInstruction> { + let mut program = Vec::new(); + + for _ in 0..n_hashes { + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + + // a0 = POSEIDON2_ECALL_NUM -> compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. + for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + } + + // Clear a0 -> final halt.
+ program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + program +} + +fn run_once( + program_bytes: &[u8], + max_steps: usize, +) -> Result { + Rv32B1::from_rom(0, program_bytes) + .chunk_size(max_steps) + .max_steps(max_steps) + .prove() +} + +#[derive(Clone, Copy, Debug)] +struct Stats { + min: Duration, + median: Duration, + mean: Duration, + max: Duration, +} + +fn summarize(samples: &[Duration]) -> Stats { + assert!(!samples.is_empty()); + let mut v = samples.to_vec(); + v.sort_unstable(); + let min = v[0]; + let max = v[v.len() - 1]; + let median = v[v.len() / 2]; + let mean_secs = v.iter().map(Duration::as_secs_f64).sum::<f64>() / v.len() as f64; + let mean = Duration::from_secs_f64(mean_secs); + Stats { + min, + median, + mean, + max, + } +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + +fn env_usize(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(v) => v.parse::<usize>().unwrap_or(default), + Err(_) => default, + } +} + +#[test] +#[ignore = "perf-style test: run with `P2_HASHES=1 cargo test -p neo-fold --release --test perf -- --ignored --nocapture rv32_b1_poseidon2_ecall_perf`"] +fn rv32_b1_poseidon2_ecall_perf() { + let n_hashes = env_usize("P2_HASHES", 1); + let warmups = env_usize("P2_WARMUPS", 1); + let samples = env_usize("P2_SAMPLES", 5); + assert!(n_hashes > 0, "P2_HASHES must be > 0"); + assert!(samples > 0, "P2_SAMPLES must be > 0"); + + let program = poseidon2_ecall_program(n_hashes); + let program_bytes = encode_program(&program); + let max_steps = program.len() + 64; + + for _ in 0..warmups { + let mut run = run_once(&program_bytes, max_steps).expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut prove_times = Vec::with_capacity(samples); + let mut verify_times =
Vec::with_capacity(samples); + let mut end_to_end_times = Vec::with_capacity(samples); + let mut ccs_n = 0; + let mut ccs_m = 0; + let mut fold_count = 0; + let mut trace_len = 0; + + for _ in 0..samples { + let total_start = std::time::Instant::now(); + let mut run = run_once(&program_bytes, max_steps).expect("prove"); + run.verify().expect("verify"); + end_to_end_times.push(total_start.elapsed()); + prove_times.push(run.prove_duration()); + verify_times.push(run.verify_duration().unwrap_or(Duration::ZERO)); + + ccs_n = run.ccs_num_constraints(); + ccs_m = run.ccs_num_variables(); + fold_count = run.fold_count(); + trace_len = run.riscv_trace_len().unwrap_or(0); + } + + let prove = summarize(&prove_times); + let verify = summarize(&verify_times); + let e2e = summarize(&end_to_end_times); + + let sep = "=".repeat(96); + let thin = "-".repeat(96); + + println!(); + println!("{sep}"); + println!("POSEIDON2 ECALL BENCHMARK (RV32 B1)"); + println!("{sep}"); + println!( + "config: hashes={} instructions={} warmups={} samples={}", + n_hashes, + program.len(), + warmups, + samples + ); + println!( + "CCS: n={} (pow2={}) m={} (pow2={}) folds={} trace_len={}", + ccs_n, + ccs_n.next_power_of_two(), + ccs_m, + ccs_m.next_power_of_two(), + fold_count, + trace_len + ); + println!("{thin}"); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "phase", "min", "median", "mean", "max" + ); + println!("{thin}"); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "prove", + fmt_duration(prove.min), + fmt_duration(prove.median), + fmt_duration(prove.mean), + fmt_duration(prove.max), + ); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "verify", + fmt_duration(verify.min), + fmt_duration(verify.median), + fmt_duration(verify.mean), + fmt_duration(verify.max), + ); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "end-to-end", + fmt_duration(e2e.min), + fmt_duration(e2e.median), + fmt_duration(e2e.mean), + fmt_duration(e2e.max), + ); + println!("{thin}"); + if n_hashes > 
0 { + let per_hash_prove = Duration::from_secs_f64(prove.median.as_secs_f64() / n_hashes as f64); + let per_hash_e2e = Duration::from_secs_f64(e2e.median.as_secs_f64() / n_hashes as f64); + println!( + "per-hash (median): prove={} end-to-end={}", + fmt_duration(per_hash_prove), + fmt_duration(per_hash_e2e), + ); + } + println!("{sep}"); + println!(); +} + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --release --test perf -- --ignored --nocapture rv32_b1_poseidon2_ecall_scaling`"] +fn rv32_b1_poseidon2_ecall_scaling() { + let samples = env_usize("P2_SAMPLES", 3); + let warmups = env_usize("P2_WARMUPS", 1); + let hash_counts = [1, 2, 4, 8]; + + let sep = "=".repeat(110); + let thin = "-".repeat(110); + + println!(); + println!("{sep}"); + println!("POSEIDON2 ECALL SCALING (RV32 B1) — warmups={} samples={}", warmups, samples); + println!("{sep}"); + println!( + "{:>8} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10} {:>12}", + "hashes", "instrs", "ccs_n", "ccs_m", "prove_med", "verify_med", "e2e_med", "per_hash", "throughput" + ); + println!("{thin}"); + + for &n_hashes in &hash_counts { + let program = poseidon2_ecall_program(n_hashes); + let program_bytes = encode_program(&program); + let max_steps = program.len() + 64; + + for _ in 0..warmups { + let mut run = run_once(&program_bytes, max_steps).expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut prove_times = Vec::with_capacity(samples); + let mut verify_times = Vec::with_capacity(samples); + let mut e2e_times = Vec::with_capacity(samples); + let mut ccs_n = 0; + let mut ccs_m = 0; + + for _ in 0..samples { + let total_start = std::time::Instant::now(); + let mut run = run_once(&program_bytes, max_steps).expect("prove"); + run.verify().expect("verify"); + e2e_times.push(total_start.elapsed()); + prove_times.push(run.prove_duration()); + verify_times.push(run.verify_duration().unwrap_or(Duration::ZERO)); + ccs_n = run.ccs_num_constraints(); + ccs_m = 
run.ccs_num_variables(); + } + + let prove_med = summarize(&prove_times).median; + let verify_med = summarize(&verify_times).median; + let e2e_med = summarize(&e2e_times).median; + let per_hash = Duration::from_secs_f64(e2e_med.as_secs_f64() / n_hashes as f64); + let hashes_per_sec = n_hashes as f64 / e2e_med.as_secs_f64(); + + println!( + "{:>8} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10} {:>10.1} h/s", + n_hashes, + program.len(), + ccs_n, + ccs_m, + fmt_duration(prove_med), + fmt_duration(verify_med), + fmt_duration(e2e_med), + fmt_duration(per_hash), + hashes_per_sec, + ); + } + + println!("{sep}"); + println!(); +} From b832e2ed9b4902c168df03c24095ffb43a4c8847 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Wed, 18 Feb 2026 15:42:26 -0300 Subject: [PATCH 11/13] Enhance RV32 trace wiring with new verification context and additional fields - Added `ram_d` and `width_lookup_addr_d` fields to `Rv32TraceWiringRun` for improved trace handling. - Introduced `Rv32TraceWiringVerifier` struct to facilitate proof verification without re-executing the program. - Implemented methods for accessing RAM address width and width-lookup address bits during proving runs. - Updated `build_verifier` method to utilize new sizing values for verification context construction. 
--- crates/neo-fold/src/riscv_trace_shard.rs | 353 ++++++++++++++++++++++- 1 file changed, 352 insertions(+), 1 deletion(-) diff --git a/crates/neo-fold/src/riscv_trace_shard.rs b/crates/neo-fold/src/riscv_trace_shard.rs index 7c4f42fb..b5a34953 100644 --- a/crates/neo-fold/src/riscv_trace_shard.rs +++ b/crates/neo-fold/src/riscv_trace_shard.rs @@ -19,7 +19,7 @@ use crate::pi_ccs::FoldingMode; use crate::session::FoldingSession; use crate::shard::{ShardProof, StepLinkingConfig}; use crate::PiCcsError; -use neo_ajtai::AjtaiSModule; +use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; use neo_ccs::relations::{McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; use neo_ccs::CcsStructure; @@ -1429,6 +1429,8 @@ impl Rv32TraceWiring { used_mem_ids, used_shout_table_ids, output_binding_cfg, + ram_d, + width_lookup_addr_d, prove_duration, prove_phase_durations, verify_duration: None, @@ -1446,6 +1448,8 @@ pub struct Rv32TraceWiringRun { used_mem_ids: Vec, used_shout_table_ids: Vec, output_binding_cfg: Option, + ram_d: usize, + width_lookup_addr_d: usize, prove_duration: Duration, prove_phase_durations: Rv32TraceProvePhaseDurations, verify_duration: Option, @@ -1539,4 +1543,351 @@ impl Rv32TraceWiringRun { pub fn steps_public(&self) -> Vec> { self.session.steps_public() } + + /// Rows per trace step (= layout.t), needed by [`Rv32TraceWiring::build_verifier`]. + pub fn step_rows(&self) -> usize { + self.layout.t + } + + /// RAM address width (bits) used during this proving run. + pub fn ram_d(&self) -> usize { + self.ram_d + } + + /// Width-lookup address bits used during this proving run (0 when width lookup is absent). + pub fn width_lookup_addr_d(&self) -> usize { + self.width_lookup_addr_d + } +} + +/// Verification context for RV32 trace-wiring proofs. +/// +/// Created by [`Rv32TraceWiring::build_verifier`]. Contains the CCS and session +/// state required to verify proofs without re-running the guest. 
+pub struct Rv32TraceWiringVerifier { + session: FoldingSession, + ccs: CcsStructure, + _layout: Rv32TraceCcsLayout, + mem_layouts: HashMap, + statement_initial_mem: HashMap<(u32, u64), F>, + output_binding_cfg: Option, +} + +impl Rv32TraceWiringVerifier { + /// Verify a `ShardProof` against externally supplied public step instances. + pub fn verify( + &self, + proof: &ShardProof, + steps_public: &[neo_memory::witness::StepInstanceBundle], + ) -> Result { + crate::riscv_shard::rv32_b1_enforce_chunk0_mem_init_matches_statement( + &self.mem_layouts, + &self.statement_initial_mem, + steps_public, + )?; + + if let Some(ob_cfg) = &self.output_binding_cfg { + self.session + .verify_with_external_steps_and_output_binding(&self.ccs, steps_public, proof, ob_cfg) + } else { + self.session + .verify_with_external_steps(&self.ccs, steps_public, proof) + } + } +} + +impl Rv32TraceWiring { + /// Build only the verification context (CCS + session) without executing the + /// program or proving. + /// + /// The caller must supply execution-dependent sizing values that were recorded + /// during the proving run: + /// - `step_rows`: rows per trace step (= [`Rv32TraceWiringRun::step_rows`]). + /// - `ram_d`: RAM address width in bits (= [`Rv32TraceWiringRun::ram_d`]). + /// - `width_lookup_addr_d`: width-lookup address bits + /// (= [`Rv32TraceWiringRun::width_lookup_addr_d`]; pass 0 when width lookup + /// is absent). 
+ pub fn build_verifier( + self, + step_rows: usize, + ram_d: usize, + width_lookup_addr_d: usize, + ) -> Result { + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 trace wiring requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RVC is not supported)".into(), + )); + } + if step_rows == 0 { + return Err(PiCcsError::InvalidInput("step_rows must be non-zero".into())); + } + + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, prog_init_words) = + prog_rom_layout_and_init_words::(PROG_ID, /*base_addr=*/ 0, &self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + + let ram_init_map = self.ram_init.clone(); + let reg_init_map = self.reg_init.clone(); + let output_claims = self.output_claims.clone(); + let output_target = self.output_target; + + let wants_ram_output = matches!(output_target, OutputTarget::Ram) && !output_claims.is_empty(); + let include_ram_sidecar = + program_requires_ram_sidecar(&program) || !ram_init_map.is_empty() || wants_ram_output; + + let ram_k = 1usize + .checked_shl(ram_d as u32) + .ok_or_else(|| PiCcsError::InvalidInput(format!("RAM address width too large: d={ram_d}")))?; + + let mut mem_layouts: HashMap = HashMap::from([ + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + (PROG_ID.0, prog_layout.clone()), + ]); + if include_ram_sidecar { + mem_layouts.insert( + RAM_ID.0, + PlainMemLayout { + k: ram_k, + d: ram_d, + n_side: 2, + lanes: 1, + }, + ); + } + + let inferred_shout_ops = infer_required_trace_shout_opcodes(&program); + let shout_ops = match &self.shout_ops { + 
Some(override_ops) => { + let missing: HashSet = inferred_shout_ops + .difference(override_ops) + .copied() + .collect(); + if !missing.is_empty() { + let mut missing_names: Vec = missing.into_iter().map(|op| format!("{op:?}")).collect(); + missing_names.sort_unstable(); + return Err(PiCcsError::InvalidInput(format!( + "trace shout_ops override must be a superset of required opcodes; missing [{}]", + missing_names.join(", ") + ))); + } + override_ops.clone() + } + None => inferred_shout_ops, + }; + + let mut table_specs = rv32_trace_table_specs(&shout_ops); + let mut base_shout_table_ids: Vec = table_specs.keys().copied().collect(); + base_shout_table_ids.sort_unstable(); + for (&table_id, spec) in &self.extra_lut_table_specs { + if table_specs.contains_key(&table_id) { + return Err(PiCcsError::InvalidInput(format!( + "extra_lut_table_spec collides with existing table_id={table_id}" + ))); + } + table_specs.insert(table_id, spec.clone()); + } + + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_lookup_bus_specs: Vec = { + let decode_lookup_cols = rv32_decode_lookup_backed_cols(&decode_layout); + decode_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col_id), + ell_addr: prog_layout.d, + n_vals: 1usize, + }) + .collect() + }; + + let include_width_lookup = program_requires_width_lookup(&program); + let width_layout = Rv32WidthSidecarLayout::new(); + let width_lookup_bus_specs: Vec = if include_width_lookup && width_lookup_addr_d > 0 { + let width_lookup_cols = rv32_width_lookup_backed_cols(&width_layout); + width_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_width_lookup_table_id_for_col(col_id), + ell_addr: width_lookup_addr_d, + n_vals: 1usize, + }) + .collect() + } else { + Vec::new() + }; + + let mut all_extra_shout_specs = self.extra_shout_bus_specs.clone(); + all_extra_shout_specs.extend(decode_lookup_bus_specs); + 
all_extra_shout_specs.extend(width_lookup_bus_specs); + + let mut layout = Rv32TraceCcsLayout::new(step_rows) + .map_err(|e| PiCcsError::InvalidInput(format!("Rv32TraceCcsLayout::new failed: {e}")))?; + + let ccs_reserved_rows; + { + let (bus_region_len, reserved_rows) = rv32_trace_shared_bus_requirements_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + &mem_layouts, + ) + .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_bus_requirements_with_specs failed: {e}")) + })?; + layout.m = layout + .m + .checked_add(bus_region_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace layout m overflow after bus tail reservation".into()))?; + ccs_reserved_rows = reserved_rows; + } + + let ccs_base = if ccs_reserved_rows == 0 { + build_rv32_trace_wiring_ccs(&layout) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs failed: {e}")))? + } else { + build_rv32_trace_wiring_ccs_with_reserved_rows(&layout, ccs_reserved_rows) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs_with_reserved_rows failed: {e}")))? 
+ }; + + let session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + + let decode_lookup_tables = build_rv32_decode_lookup_tables(&prog_layout, &prog_init_words); + let mut lut_tables = decode_lookup_tables; + if include_width_lookup && width_lookup_addr_d > 0 { + let width_k = 1usize.checked_shl(width_lookup_addr_d as u32).ok_or_else(|| { + PiCcsError::InvalidInput(format!("width lookup addr width too large: d={width_lookup_addr_d}")) + })?; + let width_lookup_cols = rv32_width_lookup_backed_cols(&width_layout); + for &col_id in width_lookup_cols.iter() { + let table_id = rv32_width_lookup_table_id_for_col(col_id); + lut_tables.insert( + table_id, + LutTable { + table_id, + k: width_k, + d: width_lookup_addr_d, + n_side: 2, + content: vec![F::ZERO; width_k], + }, + ); + } + } + + let mut cpu = R1csCpu::new( + ccs_base, + session.params().clone(), + session.committer().clone(), + layout.m_in, + &lut_tables, + &table_specs, + rv32_trace_chunk_to_witness(layout.clone()), + ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))?; + + let mut prog_init_pairs: Vec<(u64, F)> = prog_init_words + .into_iter() + .filter_map(|((mem_id, addr), value)| (mem_id == PROG_ID.0 && value != F::ZERO).then_some((addr, value))) + .collect(); + prog_init_pairs.sort_by_key(|(addr, _)| *addr); + let mut initial_mem: HashMap<(u32, u64), F> = HashMap::new(); + for &(addr, value) in &prog_init_pairs { + if value != F::ZERO { + initial_mem.insert((PROG_ID.0, addr), value); + } + } + for (®, &value) in ®_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((REG_ID.0, reg), v); + } + } + for (&addr, &value) in &ram_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((RAM_ID.0, addr), v); + } + } + + cpu = cpu + .with_shared_cpu_bus( + rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + mem_layouts.clone(), + 
initial_mem.clone(), + ) + .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_cpu_bus_config_with_specs failed: {e}")) + })?, + layout.t, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + let ccs = cpu.ccs.clone(); + + let mut session = FoldingSession::::new_ajtai(self.mode, &ccs)?; + session.set_step_linking(StepLinkingConfig::new(vec![(layout.pc_final, layout.pc0)])); + + let output_binding_cfg = if output_claims.is_empty() { + None + } else { + let out_mem_id = match output_target { + OutputTarget::Ram => RAM_ID.0, + OutputTarget::Reg => REG_ID.0, + }; + let out_layout = mem_layouts.get(&out_mem_id).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "output binding: missing PlainMemLayout for mem_id={out_mem_id}" + )) + })?; + let expected_k = 1usize + .checked_shl(out_layout.d as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^d overflow".into()))?; + if out_layout.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: mem_id={out_mem_id} has k={}, but expected 2^d={} (d={})", + out_layout.k, expected_k, out_layout.d + ))); + } + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + let mem_idx = mem_ids + .iter() + .position(|&id| id == out_mem_id) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: mem_id not in mem_layouts".into()))?; + Some(OutputBindingConfig::new(out_layout.d, output_claims).with_mem_idx(mem_idx)) + }; + + Ok(Rv32TraceWiringVerifier { + session, + ccs, + _layout: layout, + mem_layouts, + statement_initial_mem: initial_mem, + output_binding_cfg, + }) + } } From c7d9b3128c8d6fd627727e01bed2a12a44318476 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Wed, 18 Feb 2026 20:03:56 -0300 Subject: [PATCH 12/13] Enhance RV32 trace with pc_carry support and related adjustments - Added `pc_carry` field to `Rv32TraceLayout` and updated relevant functions to incorporate it. 
- Modified control residual calculations to account for `pc_carry` in branch and jump operations. - Updated trace width assertions in tests to reflect the addition of the new field. - Adjusted various functions to ensure compatibility with the new `pc_carry` logic. --- .../memory/route_a_claim_builders.rs | 46 ++++++++++++------- .../memory/route_a_terminal_checks.rs | 2 + .../memory/transcript_and_common.rs | 31 +++++++++---- crates/neo-memory/src/riscv/trace/air.rs | 1 + crates/neo-memory/src/riscv/trace/layout.rs | 5 +- crates/neo-memory/src/riscv/trace/witness.rs | 39 +++++++++++++++- .../tests/riscv_trace_wiring_ccs.rs | 6 +-- 7 files changed, 100 insertions(+), 30 deletions(-) diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs index e57012f1..8f4aa62c 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs @@ -422,6 +422,7 @@ pub(crate) fn build_route_a_control_time_claims( trace.rd_val, trace.shout_val, trace.jalr_drop_bit, + trace.pc_carry, ]; let decode_col_ids = vec![ decode.op_lui, @@ -575,29 +576,42 @@ pub(crate) fn build_route_a_control_time_claims( ); let control_sparse = vec![ - main_col(trace.active)?, - main_col(trace.pc_before)?, - main_col(trace.pc_after)?, - main_col(trace.rs1_val)?, - main_col(trace.jalr_drop_bit)?, - main_col(trace.shout_val)?, - decode_col(decode.funct3_bit[0])?, - decode_col(decode.op_jal)?, - decode_col(decode.op_jalr)?, - decode_col(decode.op_branch)?, - decode_col(decode.imm_i)?, - decode_col(decode.imm_b)?, - decode_col(decode.imm_j)?, + main_col(trace.active)?, // 0 + main_col(trace.pc_before)?, // 1 + main_col(trace.pc_after)?, // 2 + main_col(trace.rs1_val)?, // 3 + main_col(trace.jalr_drop_bit)?, // 4 + main_col(trace.pc_carry)?, // 5 + main_col(trace.shout_val)?, // 6 + decode_col(decode.funct3_bit[0])?,// 7 + 
decode_col(decode.op_jal)?, // 8 + decode_col(decode.op_jalr)?, // 9 + decode_col(decode.op_branch)?, // 10 + decode_col(decode.imm_i)?, // 11 + decode_col(decode.imm_b)?, // 12 + decode_col(decode.imm_j)?, // 13 ]; - let control_weights = control_next_pc_control_weight_vector(r_cycle, 5); + let control_weights = control_next_pc_control_weight_vector(r_cycle, 7); let control_oracle = FormulaOracleSparseTime::new( control_sparse, 5, r_cycle, Box::new(move |vals: &[K]| { let residuals = control_next_pc_control_residuals( - vals[0], vals[1], vals[2], vals[3], vals[4], vals[10], vals[11], vals[12], vals[7], vals[8], vals[9], - vals[5], vals[6], + vals[0], // active + vals[1], // pc_before + vals[2], // pc_after + vals[3], // rs1_val + vals[4], // jalr_drop_bit + vals[5], // pc_carry + vals[11], // imm_i + vals[12], // imm_b + vals[13], // imm_j + vals[8], // op_jal + vals[9], // op_jalr + vals[10], // op_branch + vals[6], // shout_val + vals[7], // funct3_bit0 ); let mut weighted = K::ZERO; for (r, w) in residuals.iter().zip(control_weights.iter()) { diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs index b7051636..d1fb3b9b 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs @@ -663,6 +663,7 @@ pub(crate) fn verify_route_a_control_terminals( let rs1_val = wp_open_col(trace.rs1_val)?; let rd_val = wp_open_col(trace.rd_val)?; let jalr_drop_bit = wp_open_col(trace.jalr_drop_bit)?; + let pc_carry = wp_open_col(trace.pc_carry)?; let shout_val = wp_open_col(trace.shout_val)?; let funct3_bits = [ decode_open_col(decode.funct3_bit[0])?, @@ -756,6 +757,7 @@ pub(crate) fn verify_route_a_control_terminals( pc_after, rs1_val, jalr_drop_bit, + pc_carry, imm_i, imm_b, imm_j, diff --git a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs 
b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs index 1c1cef55..67662635 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs @@ -824,7 +824,7 @@ pub(crate) fn rv32_trace_wb_columns(layout: &Rv32TraceLayout) -> Vec { vec![layout.active, layout.halted, layout.shout_has_lookup] } -pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 70; +pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 76; pub(crate) const W2_IMM_RESIDUAL_COUNT: usize = 4; #[inline] @@ -929,12 +929,14 @@ pub(crate) fn w2_alu_branch_lookup_residuals( rs2_decode: K, imm_i: K, imm_s: K, -) -> [K; 42] { +) -> [K; 48] { let op_lui = opcode_flags[0]; let op_auipc = opcode_flags[1]; let op_jal = opcode_flags[2]; let op_jalr = opcode_flags[3]; let op_branch = opcode_flags[4]; + let op_load = opcode_flags[5]; + let op_store = opcode_flags[6]; let op_alu_imm = opcode_flags[7]; let op_alu_reg = opcode_flags[8]; let op_misc_mem = opcode_flags[9]; @@ -966,7 +968,7 @@ pub(crate) fn w2_alu_branch_lookup_residuals( op_alu_reg * (shout_has_lookup - K::ONE), op_branch * (shout_has_lookup - K::ONE), (K::ONE - shout_has_lookup) * shout_table_id, - (op_alu_imm + op_alu_reg + op_branch) * (shout_lhs - rs1_val), + (op_alu_imm + op_alu_reg + op_branch + op_load + op_store) * (shout_lhs - rs1_val), alu_imm_shift_rhs_delta - shift_selector * (rs2_decode - imm_i), op_alu_imm * (shout_rhs - imm_i - alu_imm_shift_rhs_delta), op_alu_reg * (shout_rhs - rs2_val), @@ -1002,8 +1004,14 @@ pub(crate) fn w2_alu_branch_lookup_residuals( non_mem_ops * ram_has_read, non_mem_ops * ram_has_write, non_mem_ops * ram_addr, - opcode_flags[5] * (ram_addr - rs1_val - imm_i), - opcode_flags[6] * (ram_addr - rs1_val - imm_s), + op_load * (ram_addr - shout_val), + op_store * (ram_addr - shout_val), + op_load * (shout_has_lookup - K::ONE), + op_store * (shout_has_lookup - K::ONE), + op_load * (shout_rhs - imm_i), + op_store * 
(shout_rhs - imm_s), + op_load * (shout_table_id - K::from(F::from_u64(3))), + op_store * (shout_table_id - K::from(F::from_u64(3))), ] } @@ -1256,6 +1264,7 @@ pub(crate) fn control_next_pc_control_residuals( pc_after: K, rs1_val: K, jalr_drop_bit: K, + pc_carry: K, imm_i: K, imm_b: K, imm_j: K, @@ -1264,15 +1273,18 @@ pub(crate) fn control_next_pc_control_residuals( op_branch: K, shout_val: K, funct3_bit0: K, -) -> [K; 5] { +) -> [K; 7] { let four = K::from(F::from_u64(4)); + let two32 = K::from(F::from_u64(1u64 << 32)); let taken = control_branch_taken_from_bits(shout_val, funct3_bit0); [ - op_jal * (pc_after - pc_before - imm_j), - op_jalr * (pc_after - rs1_val - imm_i + jalr_drop_bit), - op_branch * (pc_after - pc_before - four - taken * (imm_b - four)), + op_jal * (pc_after + pc_carry * two32 - pc_before - imm_j), + op_jalr * (pc_after + pc_carry * two32 - rs1_val - imm_i + jalr_drop_bit), + op_branch * (pc_after + pc_carry * two32 - pc_before - four - taken * (imm_b - four)), op_jalr * jalr_drop_bit * (jalr_drop_bit - K::ONE), (active - op_jalr) * jalr_drop_bit, + pc_carry * (pc_carry - K::ONE), + (active - op_jal - op_jalr - op_branch) * pc_carry, ] } @@ -1328,6 +1340,7 @@ pub(crate) fn rv32_trace_wp_columns(layout: &Rv32TraceLayout) -> Vec { layout.shout_lhs, layout.shout_rhs, layout.jalr_drop_bit, + layout.pc_carry, ] } diff --git a/crates/neo-memory/src/riscv/trace/air.rs b/crates/neo-memory/src/riscv/trace/air.rs index 373954fe..3db95148 100644 --- a/crates/neo-memory/src/riscv/trace/air.rs +++ b/crates/neo-memory/src/riscv/trace/air.rs @@ -91,6 +91,7 @@ impl Rv32TraceAir { ("shout_lhs", l.shout_lhs), ("shout_rhs", l.shout_rhs), ("jalr_drop_bit", l.jalr_drop_bit), + ("pc_carry", l.pc_carry), ] { let e = Self::gated_zero(inv_active, col(c, i)); if !Self::is_zero(e) { diff --git a/crates/neo-memory/src/riscv/trace/layout.rs b/crates/neo-memory/src/riscv/trace/layout.rs index c3f52ca4..bbaf4fa1 100644 --- a/crates/neo-memory/src/riscv/trace/layout.rs +++ 
b/crates/neo-memory/src/riscv/trace/layout.rs @@ -30,6 +30,7 @@ pub struct Rv32TraceLayout { pub shout_lhs: usize, pub shout_rhs: usize, pub jalr_drop_bit: usize, + pub pc_carry: usize, } impl Rv32TraceLayout { @@ -65,8 +66,9 @@ impl Rv32TraceLayout { let shout_lhs = take(); let shout_rhs = take(); let jalr_drop_bit = take(); + let pc_carry = take(); - debug_assert_eq!(next, 21, "RV32 trace width drift after decode-helper offload"); + debug_assert_eq!(next, 22, "RV32 trace width drift after decode-helper offload"); Self { cols: next, @@ -91,6 +93,7 @@ impl Rv32TraceLayout { shout_lhs, shout_rhs, jalr_drop_bit, + pc_carry, } } } diff --git a/crates/neo-memory/src/riscv/trace/witness.rs b/crates/neo-memory/src/riscv/trace/witness.rs index 4d3ac949..dd9ec750 100644 --- a/crates/neo-memory/src/riscv/trace/witness.rs +++ b/crates/neo-memory/src/riscv/trace/witness.rs @@ -18,6 +18,26 @@ fn imm_i_from_word(instr_word: u32) -> u32 { sign_extend_to_u32((instr_word >> 20) & 0x0fff, 12) } +#[inline] +fn imm_j_from_word(instr_word: u32) -> u32 { + let imm20 = (instr_word >> 31) & 1; + let imm10_1 = (instr_word >> 21) & 0x3FF; + let imm11 = (instr_word >> 20) & 1; + let imm19_12 = (instr_word >> 12) & 0xFF; + let raw = (imm20 << 20) | (imm19_12 << 12) | (imm11 << 11) | (imm10_1 << 1); + sign_extend_to_u32(raw, 21) +} + +#[inline] +fn imm_b_from_word(instr_word: u32) -> u32 { + let imm12 = (instr_word >> 31) & 1; + let imm10_5 = (instr_word >> 25) & 0x3F; + let imm4_1 = (instr_word >> 8) & 0xF; + let imm11 = (instr_word >> 7) & 1; + let raw = (imm12 << 12) | (imm11 << 11) | (imm10_5 << 5) | (imm4_1 << 1); + sign_extend_to_u32(raw, 13) +} + #[derive(Clone, Debug)] pub struct Rv32TraceWitness { pub t: usize, @@ -63,11 +83,28 @@ impl Rv32TraceWitness { // this address is don't-care for bus semantics. 
wit.cols[layout.rd_addr][i] = F::from_u64(cols.rd[i] as u64); wit.cols[layout.rd_val][i] = F::from_u64(cols.rd_val[i]); - if cols.opcode[i] == 0x67 { + let opcode = cols.opcode[i]; + if opcode == 0x67 { + // JALR: pc = (rs1 + imm_i) & ~1 let rs1 = cols.rs1_val[i] as u32; let imm_i = imm_i_from_word(cols.instr_word[i]); let drop = rs1.wrapping_add(imm_i) & 1; wit.cols[layout.jalr_drop_bit][i] = F::from_u64(drop as u64); + let sum = (cols.rs1_val[i]) + (imm_i as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } else if opcode == 0x6F { + // JAL: pc = pc + imm_j + let imm_j = imm_j_from_word(cols.instr_word[i]); + let sum = (cols.pc_before[i]) + (imm_j as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } else if opcode == 0x63 { + // BRANCH: pc = taken ? pc + imm_b : pc + 4 + let taken = cols.pc_after[i] != cols.pc_before[i].wrapping_add(4); + if taken { + let imm_b = imm_b_from_word(cols.instr_word[i]); + let sum = (cols.pc_before[i]) + (imm_b as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } } } diff --git a/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs index 02b19274..c2e98079 100644 --- a/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs +++ b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs @@ -15,12 +15,12 @@ fn rv32_trace_layout_removes_fixed_shout_table_selector_lanes() { let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); assert_eq!( - layout.trace.cols, 21, - "trace width regression: expected 21 columns after shout_lhs/jalr_drop_bit hardening" + layout.trace.cols, 22, + "trace width regression: expected 22 columns after pc_carry addition" ); assert_eq!( layout.trace.cols, - layout.trace.jalr_drop_bit + 1, + layout.trace.pc_carry + 1, "trace layout should remain densely packed" ); } From f63c84e587e525bcdae7551c02f464959a4d3458 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Wed, 18 Feb 2026 21:07:44 -0300 Subject: [PATCH 
13/13] Implement Goldilocks field ECALLs in RISC-V CPU - Added support for Goldilocks field operations (addition, subtraction, multiplication, and reading results) via ECALLs in the RiscvCpu struct. - Introduced new ECALL identifiers for Goldilocks operations in the mod.rs file. - Updated the Goldilocks arithmetic implementation to utilize ECALLs on RISC-V targets, while maintaining software fallbacks for other architectures. - Enhanced the handling of poseidon2 ECALLs to accommodate the new Goldilocks operations. --- .../memory/transcript_and_common.rs | 2 +- crates/neo-memory/src/riscv/lookups/cpu.rs | 84 +++++++++++- crates/neo-memory/src/riscv/lookups/mod.rs | 25 ++++ crates/neo-vm-trace/src/lib.rs | 15 +++ crates/nightstream-sdk/src/goldilocks.rs | 123 ++++++++++++++---- 5 files changed, 216 insertions(+), 33 deletions(-) diff --git a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs index 67662635..1d5c087e 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs @@ -998,7 +998,7 @@ pub(crate) fn w2_alu_branch_lookup_residuals( opcode_flags[6] * rd_has_write, op_misc_mem * rd_has_write, op_system * rd_has_write, - active * (halted - op_system), + active * halted * (K::ONE - op_system), opcode_flags[5] * (ram_has_read - K::ONE), opcode_flags[6] * (ram_has_write - K::ONE), non_mem_ops * ram_has_read, diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index 94632327..3752483b 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -7,7 +7,10 @@ use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; -use super::{POSEIDON2_ECALL_NUM, 
POSEIDON2_READ_ECALL_NUM}; +use super::{ + GL_ADD_ECALL_NUM, GL_MUL_ECALL_NUM, GL_READ_ECALL_NUM, GL_SUB_ECALL_NUM, + POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. /// @@ -31,6 +34,10 @@ pub struct RiscvCpu { poseidon2_pending: Option<[u32; 8]>, /// Index into `poseidon2_pending` for the next read ECALL. poseidon2_read_idx: usize, + /// Pending Goldilocks field operation result (2 × u32: lo, hi). + gl_pending: Option<[u32; 2]>, + /// Index into `gl_pending` for the next GL read ECALL. + gl_read_idx: usize, } impl RiscvCpu { @@ -46,6 +53,8 @@ impl RiscvCpu { program_base: 0, poseidon2_pending: None, poseidon2_read_idx: 0, + gl_pending: None, + gl_read_idx: 0, } } @@ -144,16 +153,70 @@ impl RiscvCpu { /// /// Called 8 times by the guest to retrieve the full 4-element digest. /// Each call returns the next word and advances the internal read index. - fn handle_poseidon2_read_ecall(&mut self) { + /// Uses `store_untraced` so the next instruction sees the updated a0 + /// without generating a CCS-visible register write event. + fn handle_poseidon2_read_ecall>(&mut self, twist: &mut T) { let words = self .poseidon2_pending .expect("poseidon2 read ECALL called without a pending digest"); let idx = self.poseidon2_read_idx; assert!(idx < 8, "poseidon2 read ECALL: all 8 words already consumed"); - let word = words[idx] as u64; - self.set_reg(10, word); // a0 + let word = self.mask_value(words[idx] as u64); + self.regs[10] = word; + twist.store_untraced(super::REG_ID, 10, word); self.poseidon2_read_idx = idx + 1; } + + /// Goldilocks field multiply ECALL handler. + fn handle_gl_mul_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) * Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Goldilocks field add ECALL handler. 
+ fn handle_gl_add_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) + Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Goldilocks field subtract ECALL handler. + fn handle_gl_sub_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) - Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Read two 64-bit Goldilocks operands from registers a1..a4. + fn read_gl_operands(&self) -> (u64, u64) { + let a_lo = self.get_reg(11) as u32; // a1 + let a_hi = self.get_reg(12) as u32; // a2 + let b_lo = self.get_reg(13) as u32; // a3 + let b_hi = self.get_reg(14) as u32; // a4 + let a = (a_lo as u64) | ((a_hi as u64) << 32); + let b = (b_lo as u64) | ((b_hi as u64) << 32); + (a, b) + } + + /// Store a 64-bit GL result as two u32 words in CPU internal state. + fn store_gl_result(&mut self, val: u64) { + self.gl_pending = Some([val as u32, (val >> 32) as u32]); + self.gl_read_idx = 0; + } + + /// GL read ECALL handler: return next u32 word of the pending result in a0. 
+ fn handle_gl_read_ecall>(&mut self, twist: &mut T) { + let words = self + .gl_pending + .expect("GL read ECALL called without a pending result"); + let idx = self.gl_read_idx; + assert!(idx < 2, "GL read ECALL: both words already consumed"); + let word = self.mask_value(words[idx] as u64); + self.regs[10] = word; + twist.store_untraced(super::REG_ID, 10, word); + self.gl_read_idx = idx + 1; + } } impl neo_vm_trace::VmCpu for RiscvCpu { @@ -653,9 +716,18 @@ impl neo_vm_trace::VmCpu for RiscvCpu { if call_id == POSEIDON2_ECALL_NUM { self.handle_poseidon2_ecall(twist); } else if call_id == POSEIDON2_READ_ECALL_NUM { - self.handle_poseidon2_read_ecall(); + self.handle_poseidon2_read_ecall(twist); + } else if call_id == GL_MUL_ECALL_NUM { + self.handle_gl_mul_ecall(); + } else if call_id == GL_ADD_ECALL_NUM { + self.handle_gl_add_ecall(); + } else if call_id == GL_SUB_ECALL_NUM { + self.handle_gl_sub_ecall(); + } else if call_id == GL_READ_ECALL_NUM { + self.handle_gl_read_ecall(twist); + } else { + self.handle_ecall(); } - self.handle_ecall(); } RiscvInstruction::Ebreak => { diff --git a/crates/neo-memory/src/riscv/lookups/mod.rs b/crates/neo-memory/src/riscv/lookups/mod.rs index 6d599496..e0d5d71f 100644 --- a/crates/neo-memory/src/riscv/lookups/mod.rs +++ b/crates/neo-memory/src/riscv/lookups/mod.rs @@ -122,6 +122,31 @@ pub const POSEIDON2_ECALL_NUM: u32 = 0x504F53; /// to retrieve the full digest. pub const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; +/// Goldilocks field multiply ECALL identifier ("GLM"). +/// +/// ABI: a0 = GL_MUL_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. +/// Computes (a * b) mod p and stores the 64-bit result in CPU state. +/// Retrieve via GL_READ_ECALL_NUM (2 calls for lo/hi words). +pub const GL_MUL_ECALL_NUM: u32 = 0x474C4D; + +/// Goldilocks field add ECALL identifier ("GLA"). +/// +/// ABI: a0 = GL_ADD_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. 
+/// Computes (a + b) mod p and stores the 64-bit result in CPU state. +pub const GL_ADD_ECALL_NUM: u32 = 0x474C41; + +/// Goldilocks field subtract ECALL identifier ("GLS"). +/// +/// ABI: a0 = GL_SUB_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. +/// Computes (a - b) mod p and stores the 64-bit result in CPU state. +pub const GL_SUB_ECALL_NUM: u32 = 0x474C53; + +/// Goldilocks field operation read ECALL identifier (bit 31 set on "GLR"). +/// +/// ABI: a0 = GL_READ_ECALL_NUM. Returns the next u32 word of the +/// pending field operation result in register a0. Call 2 times (lo/hi). +pub const GL_READ_ECALL_NUM: u32 = 0x80474C52; + pub use alu::{compute_op, lookup_entry}; pub use bits::{interleave_bits, uninterleave_bits}; pub use cpu::RiscvCpu; diff --git a/crates/neo-vm-trace/src/lib.rs b/crates/neo-vm-trace/src/lib.rs index e7750d77..b1161d73 100644 --- a/crates/neo-vm-trace/src/lib.rs +++ b/crates/neo-vm-trace/src/lib.rs @@ -350,6 +350,17 @@ pub trait Twist { fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { self.load(twist_id, addr) } + + /// Store a value to memory without recording a Twist event. + /// + /// Used by ECALL precompiles that write results back to the register file + /// without generating per-step Twist bus entries. The CCS treats the ECALL + /// row as having no `rd` write; subsequent instructions see the updated + /// value through normal Twist reads. + #[inline] + fn store_untraced(&mut self, twist_id: TwistId, addr: Addr, value: Word) { + self.store(twist_id, addr, value); + } } /// A tracing wrapper around any `Twist` implementation. 
@@ -445,6 +456,10 @@ where fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { self.inner.load(twist_id, addr) } + + fn store_untraced(&mut self, twist_id: TwistId, addr: Addr, value: Word) { + self.inner.store(twist_id, addr, value); + } } // ============================================================================ diff --git a/crates/nightstream-sdk/src/goldilocks.rs b/crates/nightstream-sdk/src/goldilocks.rs index a25193c6..61f6d2c7 100644 --- a/crates/nightstream-sdk/src/goldilocks.rs +++ b/crates/nightstream-sdk/src/goldilocks.rs @@ -1,8 +1,10 @@ -//! Software Goldilocks field arithmetic for RISC-V guests. +//! Goldilocks field arithmetic for Nightstream guests. //! //! Field: p = 2^64 - 2^32 + 1 = 0xFFFF_FFFF_0000_0001 //! -//! All values are canonical: in [0, p). +//! On RISC-V targets, `gl_mul`, `gl_add`, and `gl_sub` use ECALL precompiles +//! (3 ECALLs each: 1 compute + 2 reads for the 64-bit result). On other targets +//! the software implementations are used. #![allow(dead_code)] @@ -12,21 +14,75 @@ pub const GL_P: u64 = 0xFFFF_FFFF_0000_0001; /// A Goldilocks field element digest: 4 elements (32 bytes). pub type GlDigest = [u64; 4]; -/// Reduce a u128 value mod p using the identity 2^64 ≡ 2^32 - 1 (mod p). -/// Avoids 128-bit division. 
+// --------------------------------------------------------------------------- +// ECALL-based implementations (RISC-V only) +// --------------------------------------------------------------------------- + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_MUL_ECALL_NUM: u32 = 0x474C4D; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_ADD_ECALL_NUM: u32 = 0x474C41; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_SUB_ECALL_NUM: u32 = 0x474C53; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_READ_ECALL_NUM: u32 = 0x80474C52; + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_compute(ecall_id: u32, a: u64, b: u64) { + let a_lo = a as u32; + let a_hi = (a >> 32) as u32; + let b_lo = b as u32; + let b_hi = (b >> 32) as u32; + unsafe { + core::arch::asm!( + "ecall", + in("a0") ecall_id, + in("a1") a_lo, + in("a2") a_hi, + in("a3") b_lo, + in("a4") b_hi, + options(nostack), + ); + } +} + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_read() -> u32 { + let result: u32; + unsafe { + core::arch::asm!( + "ecall", + inout("a0") GL_READ_ECALL_NUM => result, + options(nostack), + ); + } + result +} + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_op(ecall_id: u32, a: u64, b: u64) -> u64 { + gl_ecall_compute(ecall_id, a, b); + let lo = gl_ecall_read() as u64; + let hi = gl_ecall_read() as u64; + lo | (hi << 32) +} + +// --------------------------------------------------------------------------- +// Software fallbacks (non-RISC-V) +// --------------------------------------------------------------------------- + +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] #[inline] fn reduce128(x: u128) -> u64 { let lo = x as u64; let hi = (x >> 64) as u64; - - // x ≡ lo + hi * (2^32 - 1) (mod p) let hi_times_correction = (hi as u128) * 0xFFFF_FFFFu128; let sum = 
lo as u128 + hi_times_correction; - let lo2 = sum as u64; let hi2 = (sum >> 64) as u64; - - // hi2 is 0 or 1; do one more pass let (mut result, overflow) = lo2.overflowing_add(hi2.wrapping_mul(0xFFFF_FFFF)); if overflow || result >= GL_P { result = result.wrapping_sub(GL_P); @@ -43,26 +99,36 @@ pub const GL_ONE: u64 = 1; /// Field addition: (a + b) mod p. #[inline] pub fn gl_add(a: u64, b: u64) -> u64 { - let sum = a as u128 + b as u128; - let s = sum as u64; - let carry = (sum >> 64) as u64; - // carry is 0 or 1 - let (mut r, overflow) = s.overflowing_add(carry.wrapping_mul(0xFFFF_FFFF)); - if overflow || r >= GL_P { - r = r.wrapping_sub(GL_P); + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_ADD_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + let sum = a as u128 + b as u128; + let s = sum as u64; + let carry = (sum >> 64) as u64; + let (mut r, overflow) = s.overflowing_add(carry.wrapping_mul(0xFFFF_FFFF)); + if overflow || r >= GL_P { + r = r.wrapping_sub(GL_P); + } + r } - r } /// Field subtraction: (a - b) mod p. #[inline] pub fn gl_sub(a: u64, b: u64) -> u64 { - if a >= b { - let diff = a - b; - if diff >= GL_P { diff - GL_P } else { diff } - } else { - // a < b, so result = p - (b - a) - GL_P - (b - a) + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_SUB_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + if a >= b { + let diff = a - b; + if diff >= GL_P { diff - GL_P } else { diff } + } else { + GL_P - (b - a) + } } } @@ -75,8 +141,14 @@ pub fn gl_neg(a: u64) -> u64 { /// Field multiplication: (a * b) mod p. 
#[inline] pub fn gl_mul(a: u64, b: u64) -> u64 { - let prod = (a as u128) * (b as u128); - reduce128(prod) + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_MUL_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + let prod = (a as u128) * (b as u128); + reduce128(prod) + } } /// Field squaring: (a * a) mod p. @@ -103,7 +175,6 @@ pub fn gl_pow(mut base: u64, mut exp: u64) -> u64 { /// Panics (halts) if a == 0. pub fn gl_inv(a: u64) -> u64 { debug_assert!(a != 0, "inverse of zero"); - // p - 2 = 0xFFFF_FFFE_FFFF_FFFF gl_pow(a, GL_P - 2) }