diff --git a/crates/neo-ajtai/src/types.rs b/crates/neo-ajtai/src/types.rs index 9f0b91e3..07527f3e 100644 --- a/crates/neo-ajtai/src/types.rs +++ b/crates/neo-ajtai/src/types.rs @@ -1,6 +1,7 @@ use p3_field::PrimeCharacteristicRing; use p3_goldilocks::Goldilocks as Fq; -use serde::{Deserialize, Serialize}; +use serde::de::Error as _; +use serde::{Deserialize, Deserializer, Serialize}; /// Public parameters for Ajtai: M ∈ R_q^{κ×m}, stored row-major. #[derive(Clone, Debug, Serialize, Deserialize)] @@ -13,7 +14,7 @@ pub struct PP { } /// Commitment c ∈ F_q^{d×κ}, stored as column-major flat matrix (κ columns, each length d). -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)] pub struct Commitment { pub d: usize, pub kappa: usize, @@ -22,6 +23,26 @@ } impl Commitment { + #[inline] + fn validate_shape(d: usize, kappa: usize, data_len: usize) -> Result<(), String> { + let expected_d = neo_math::ring::D; + if d != expected_d { + return Err(format!("invalid Commitment.d: expected {expected_d}, got {d}")); + } + + let expected_len = d + .checked_mul(kappa) + .ok_or_else(|| format!("invalid Commitment shape: d*kappa overflow (d={d}, kappa={kappa})"))?; + if data_len != expected_len { + return Err(format!( + "invalid Commitment shape: data.len()={} but d*kappa={expected_len}", + data_len + )); + } + + Ok(()) + } + pub fn zeros(d: usize, kappa: usize) -> Self { Self { d, @@ -48,3 +69,26 @@ } } } + +impl<'de> Deserialize<'de> for Commitment { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + struct CommitmentWire { + d: usize, + kappa: usize, + data: Vec<Fq>, + } + + let wire = CommitmentWire::deserialize(deserializer)?; + Commitment::validate_shape(wire.d, wire.kappa, wire.data.len()).map_err(D::Error::custom)?; + + Ok(Self { + d: wire.d, + kappa: wire.kappa, + data: wire.data, + }) + } +} diff --git a/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs new file mode 100644 index 00000000..9b743ec2 --- /dev/null +++ b/crates/neo-ajtai/tests/commitment_deserialize_invariants.rs @@ -0,0 +1,41 @@ +use neo_ajtai::Commitment; +use neo_math::ring::D; +use serde_json::Value; + +fn valid_commitment_json(kappa: usize) -> Value { + let c = Commitment::zeros(D, kappa); + serde_json::to_value(&c).expect("serialize valid commitment") +} + +#[test] +fn commitment_deserialize_rejects_wrong_d() { + let mut value = valid_commitment_json(2); + value["d"] = serde_json::json!(D - 1); + + let err = serde_json::from_value::<Commitment>(value).expect_err("wrong d must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("invalid Commitment.d"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_rejects_data_len_mismatch() { + let mut value = valid_commitment_json(2); + let data = value + .get_mut("data") + .and_then(Value::as_array_mut) + .expect("commitment.data array"); + data.pop().expect("non-empty data"); + + let err = serde_json::from_value::<Commitment>(value).expect_err("shape mismatch must be rejected"); + let msg = err.to_string(); + assert!(msg.contains("data.len()"), "unexpected error message: {msg}"); +} + +#[test] +fn commitment_deserialize_accepts_valid_shape() { + let value = valid_commitment_json(3); + let c: Commitment = serde_json::from_value(value).expect("valid shape should deserialize"); + assert_eq!(c.d, D); + assert_eq!(c.kappa, 3); + assert_eq!(c.data.len(), D * 3); +} diff
--git a/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs index e57012f1..8f4aa62c 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_claim_builders.rs @@ -422,6 +422,7 @@ pub(crate) fn build_route_a_control_time_claims( trace.rd_val, trace.shout_val, trace.jalr_drop_bit, + trace.pc_carry, ]; let decode_col_ids = vec![ decode.op_lui, @@ -575,29 +576,42 @@ pub(crate) fn build_route_a_control_time_claims( ); let control_sparse = vec![ - main_col(trace.active)?, - main_col(trace.pc_before)?, - main_col(trace.pc_after)?, - main_col(trace.rs1_val)?, - main_col(trace.jalr_drop_bit)?, - main_col(trace.shout_val)?, - decode_col(decode.funct3_bit[0])?, - decode_col(decode.op_jal)?, - decode_col(decode.op_jalr)?, - decode_col(decode.op_branch)?, - decode_col(decode.imm_i)?, - decode_col(decode.imm_b)?, - decode_col(decode.imm_j)?, + main_col(trace.active)?, // 0 + main_col(trace.pc_before)?, // 1 + main_col(trace.pc_after)?, // 2 + main_col(trace.rs1_val)?, // 3 + main_col(trace.jalr_drop_bit)?, // 4 + main_col(trace.pc_carry)?, // 5 + main_col(trace.shout_val)?, // 6 + decode_col(decode.funct3_bit[0])?,// 7 + decode_col(decode.op_jal)?, // 8 + decode_col(decode.op_jalr)?, // 9 + decode_col(decode.op_branch)?, // 10 + decode_col(decode.imm_i)?, // 11 + decode_col(decode.imm_b)?, // 12 + decode_col(decode.imm_j)?, // 13 ]; - let control_weights = control_next_pc_control_weight_vector(r_cycle, 5); + let control_weights = control_next_pc_control_weight_vector(r_cycle, 7); let control_oracle = FormulaOracleSparseTime::new( control_sparse, 5, r_cycle, Box::new(move |vals: &[K]| { let residuals = control_next_pc_control_residuals( - vals[0], vals[1], vals[2], vals[3], vals[4], vals[10], vals[11], vals[12], vals[7], vals[8], vals[9], - vals[5], vals[6], + vals[0], // active + vals[1], // pc_before + vals[2], // pc_after + vals[3], // rs1_val + vals[4], // jalr_drop_bit + vals[5], // pc_carry + vals[11], // imm_i + vals[12], // imm_b + vals[13], // imm_j + vals[8], // op_jal + vals[9], // op_jalr + vals[10], // op_branch + vals[6], // shout_val + vals[7], // funct3_bit0 ); let mut weighted = K::ZERO; for (r, w) in residuals.iter().zip(control_weights.iter()) { diff --git a/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs index b7051636..d1fb3b9b 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/route_a_terminal_checks.rs @@ -663,6 +663,7 @@ pub(crate) fn verify_route_a_control_terminals( let rs1_val = wp_open_col(trace.rs1_val)?; let rd_val = wp_open_col(trace.rd_val)?; let jalr_drop_bit = wp_open_col(trace.jalr_drop_bit)?; + let pc_carry = wp_open_col(trace.pc_carry)?; let shout_val = wp_open_col(trace.shout_val)?; let funct3_bits = [ decode_open_col(decode.funct3_bit[0])?, @@ -756,6 +757,7 @@ pub(crate) fn verify_route_a_control_terminals( pc_after, rs1_val, jalr_drop_bit, + pc_carry, imm_i, imm_b, imm_j, diff --git a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs index 1c1cef55..1d5c087e 100644 --- a/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs +++ b/crates/neo-fold/src/memory_sidecar/memory/transcript_and_common.rs @@ -824,7 +824,7 @@ pub(crate) fn 
rv32_trace_wb_columns(layout: &Rv32TraceLayout) -> Vec { vec![layout.active, layout.halted, layout.shout_has_lookup] } -pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 70; +pub(crate) const W2_FIELDS_RESIDUAL_COUNT: usize = 76; pub(crate) const W2_IMM_RESIDUAL_COUNT: usize = 4; #[inline] @@ -929,12 +929,14 @@ pub(crate) fn w2_alu_branch_lookup_residuals( rs2_decode: K, imm_i: K, imm_s: K, -) -> [K; 42] { +) -> [K; 48] { let op_lui = opcode_flags[0]; let op_auipc = opcode_flags[1]; let op_jal = opcode_flags[2]; let op_jalr = opcode_flags[3]; let op_branch = opcode_flags[4]; + let op_load = opcode_flags[5]; + let op_store = opcode_flags[6]; let op_alu_imm = opcode_flags[7]; let op_alu_reg = opcode_flags[8]; let op_misc_mem = opcode_flags[9]; @@ -966,7 +968,7 @@ pub(crate) fn w2_alu_branch_lookup_residuals( op_alu_reg * (shout_has_lookup - K::ONE), op_branch * (shout_has_lookup - K::ONE), (K::ONE - shout_has_lookup) * shout_table_id, - (op_alu_imm + op_alu_reg + op_branch) * (shout_lhs - rs1_val), + (op_alu_imm + op_alu_reg + op_branch + op_load + op_store) * (shout_lhs - rs1_val), alu_imm_shift_rhs_delta - shift_selector * (rs2_decode - imm_i), op_alu_imm * (shout_rhs - imm_i - alu_imm_shift_rhs_delta), op_alu_reg * (shout_rhs - rs2_val), @@ -996,14 +998,20 @@ pub(crate) fn w2_alu_branch_lookup_residuals( opcode_flags[6] * rd_has_write, op_misc_mem * rd_has_write, op_system * rd_has_write, - active * (halted - op_system), + active * halted * (K::ONE - op_system), opcode_flags[5] * (ram_has_read - K::ONE), opcode_flags[6] * (ram_has_write - K::ONE), non_mem_ops * ram_has_read, non_mem_ops * ram_has_write, non_mem_ops * ram_addr, - opcode_flags[5] * (ram_addr - rs1_val - imm_i), - opcode_flags[6] * (ram_addr - rs1_val - imm_s), + op_load * (ram_addr - shout_val), + op_store * (ram_addr - shout_val), + op_load * (shout_has_lookup - K::ONE), + op_store * (shout_has_lookup - K::ONE), + op_load * (shout_rhs - imm_i), + op_store * (shout_rhs - imm_s), + op_load * (shout_table_id - K::from(F::from_u64(3))), + op_store * (shout_table_id - K::from(F::from_u64(3))), ] } @@ -1256,6 +1264,7 @@ pub(crate) fn control_next_pc_control_residuals( pc_after: K, rs1_val: K, jalr_drop_bit: K, + pc_carry: K, imm_i: K, imm_b: K, imm_j: K, @@ -1264,15 +1273,18 @@ pub(crate) fn control_next_pc_control_residuals( op_branch: K, shout_val: K, funct3_bit0: K, -) -> [K; 5] { +) -> [K; 7] { let four = K::from(F::from_u64(4)); + let two32 = K::from(F::from_u64(1u64 << 32)); let taken = control_branch_taken_from_bits(shout_val, funct3_bit0); [ - op_jal * (pc_after - pc_before - imm_j), - op_jalr * (pc_after - rs1_val - imm_i + jalr_drop_bit), - op_branch * (pc_after - pc_before - four - taken * (imm_b - four)), + op_jal * (pc_after + pc_carry * two32 - pc_before - imm_j), + op_jalr * (pc_after + pc_carry * two32 - rs1_val - imm_i + jalr_drop_bit), + op_branch * (pc_after + pc_carry * two32 - pc_before - four - taken * (imm_b - four)), op_jalr * jalr_drop_bit * (jalr_drop_bit - K::ONE), (active - op_jalr) * jalr_drop_bit, + pc_carry * (pc_carry - K::ONE), + (active - op_jal - op_jalr - op_branch) * pc_carry, ] } @@ -1328,6 +1340,7 @@ pub(crate) fn rv32_trace_wp_columns(layout: &Rv32TraceLayout) -> Vec { layout.shout_lhs, layout.shout_rhs, layout.jalr_drop_bit, + layout.pc_carry, ] } diff --git a/crates/neo-fold/src/memory_sidecar/route_a_time.rs b/crates/neo-fold/src/memory_sidecar/route_a_time.rs index c5c450a1..fcf7a85d 100644 --- a/crates/neo-fold/src/memory_sidecar/route_a_time.rs +++ 
b/crates/neo-fold/src/memory_sidecar/route_a_time.rs @@ -349,7 +349,7 @@ pub fn prove_route_a_batched_time( let proof = BatchedTimeProof { claimed_sums: claimed_sums.clone(), degree_bounds: degree_bounds.clone(), - labels: labels.clone(), + labels: labels.iter().map(|label| label.to_vec()).collect(), round_polys: per_claim_results .iter() .map(|r| r.round_polys.clone()) @@ -449,7 +449,7 @@ pub fn verify_route_a_batched_time( ))); } for (i, (got, exp)) in proof.labels.iter().zip(expected_labels.iter()).enumerate() { - if (*got as &[u8]) != *exp { + if got.as_slice() != *exp { return Err(PiCcsError::ProtocolError(format!( "step {}: batched_time label mismatch at claim {}", step_idx, i diff --git a/crates/neo-fold/src/riscv_trace_shard.rs b/crates/neo-fold/src/riscv_trace_shard.rs index 75ec868f..66324a2b 100644 --- a/crates/neo-fold/src/riscv_trace_shard.rs +++ b/crates/neo-fold/src/riscv_trace_shard.rs @@ -19,7 +19,7 @@ use crate::pi_ccs::FoldingMode; use crate::session::FoldingSession; use crate::shard::{ShardProof, StepLinkingConfig}; use crate::PiCcsError; -use neo_ajtai::AjtaiSModule; +use neo_ajtai::{AjtaiSModule, Commitment as Cmt}; use neo_ccs::relations::{McsInstance, McsWitness}; use neo_ccs::traits::SModuleHomomorphism; use neo_ccs::CcsStructure; @@ -43,7 +43,7 @@ use neo_memory::riscv::trace::{ rv32_width_sidecar_witness_from_exec_table, Rv32DecodeSidecarLayout, Rv32WidthSidecarLayout, TwistLaneOverTime, }; use neo_memory::witness::{LutInstance, LutWitness, MemInstance, MemWitness, StepWitnessBundle}; -use neo_memory::{LutTableSpec, MemInit, R1csCpu}; +use neo_memory::{mem_init_from_initial_mem, LutTableSpec, MemInit, R1csCpu}; use neo_params::NeoParams; use neo_vm_trace::{ShoutEvent, ShoutId, StepTrace, Twist as _, TwistOpKind, VmTrace}; use p3_field::PrimeCharacteristicRing; @@ -1429,6 +1429,8 @@ impl Rv32TraceWiring { used_mem_ids, used_shout_table_ids, output_binding_cfg, + ram_d, + width_lookup_addr_d, prove_duration, prove_phase_durations, verify_duration: None, @@ -1446,6 +1448,8 @@ pub struct Rv32TraceWiringRun { used_mem_ids: Vec, used_shout_table_ids: Vec, output_binding_cfg: Option, + ram_d: usize, + width_lookup_addr_d: usize, prove_duration: Duration, prove_phase_durations: Rv32TraceProvePhaseDurations, verify_duration: Option, @@ -1539,4 +1543,384 @@ impl Rv32TraceWiringRun { pub fn steps_public(&self) -> Vec> { self.session.steps_public() } + + /// Rows per trace step (= layout.t), needed by [`Rv32TraceWiring::build_verifier`]. + pub fn step_rows(&self) -> usize { + self.layout.t + } + + /// RAM address width (bits) used during this proving run. + pub fn ram_d(&self) -> usize { + self.ram_d + } + + /// Width-lookup address bits used during this proving run (0 when width lookup is absent). + pub fn width_lookup_addr_d(&self) -> usize { + self.width_lookup_addr_d + } +} + +/// Enforce that the public statement initial memory matches chunk 0's `MemInstance.init`. 
+fn enforce_chunk0_mem_init_matches_statement( + mem_layouts: &HashMap, + statement_initial_mem: &HashMap<(u32, u64), F>, + steps: &[neo_memory::witness::StepInstanceBundle], +) -> Result<(), PiCcsError> { + let chunk0 = steps + .first() + .ok_or_else(|| PiCcsError::InvalidInput("no steps provided".into()))?; + + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + + if chunk0.mem_insts.len() != mem_ids.len() { + return Err(PiCcsError::InvalidInput(format!( + "mem instance count mismatch: chunk0 has {}, but mem_layouts has {}", + chunk0.mem_insts.len(), + mem_ids.len() + ))); + } + + for (idx, mem_id) in mem_ids.into_iter().enumerate() { + let layout = mem_layouts + .get(&mem_id) + .ok_or_else(|| PiCcsError::InvalidInput(format!("missing PlainMemLayout for mem_id={mem_id}")))?; + let expected = mem_init_from_initial_mem(mem_id, layout.k, statement_initial_mem)?; + let got = &chunk0.mem_insts[idx].init; + if *got != expected { + return Err(PiCcsError::InvalidInput(format!( + "chunk0 MemInstance.init mismatch for mem_id={mem_id}" + ))); + } + } + + Ok(()) +} + +/// Verification context for RV32 trace-wiring proofs. +/// +/// Created by [`Rv32TraceWiring::build_verifier`]. Contains the CCS and session +/// state required to verify proofs without re-running the guest. +pub struct Rv32TraceWiringVerifier { + session: FoldingSession, + ccs: CcsStructure, + _layout: Rv32TraceCcsLayout, + mem_layouts: HashMap, + statement_initial_mem: HashMap<(u32, u64), F>, + output_binding_cfg: Option, +} + +impl Rv32TraceWiringVerifier { + /// Verify a `ShardProof` against externally supplied public step instances. + pub fn verify( + &self, + proof: &ShardProof, + steps_public: &[neo_memory::witness::StepInstanceBundle], + ) -> Result { + enforce_chunk0_mem_init_matches_statement(&self.mem_layouts, &self.statement_initial_mem, steps_public)?; + + if let Some(ob_cfg) = &self.output_binding_cfg { + self.session + .verify_with_external_steps_and_output_binding(&self.ccs, steps_public, proof, ob_cfg) + } else { + self.session + .verify_with_external_steps(&self.ccs, steps_public, proof) + } + } +} + +impl Rv32TraceWiring { + /// Build only the verification context (CCS + session) without executing the + /// program or proving. + /// + /// The caller must supply execution-dependent sizing values that were recorded + /// during the proving run: + /// - `step_rows`: rows per trace step (= [`Rv32TraceWiringRun::step_rows`]). + /// - `ram_d`: RAM address width in bits (= [`Rv32TraceWiringRun::ram_d`]). + /// - `width_lookup_addr_d`: width-lookup address bits + /// (= [`Rv32TraceWiringRun::width_lookup_addr_d`]; pass 0 when width lookup + /// is absent). 
+ pub fn build_verifier( + self, + step_rows: usize, + ram_d: usize, + width_lookup_addr_d: usize, + ) -> Result { + if self.xlen != 32 { + return Err(PiCcsError::InvalidInput(format!( + "RV32 trace wiring requires xlen == 32 (got {})", + self.xlen + ))); + } + if self.program_bytes.is_empty() { + return Err(PiCcsError::InvalidInput("program_bytes must be non-empty".into())); + } + if self.program_bytes.len() % 4 != 0 { + return Err(PiCcsError::InvalidInput( + "program_bytes must be 4-byte aligned (RVC is not supported)".into(), + )); + } + if step_rows == 0 { + return Err(PiCcsError::InvalidInput("step_rows must be non-zero".into())); + } + + let program = decode_program(&self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("decode_program failed: {e}")))?; + + let (prog_layout, prog_init_words) = + prog_rom_layout_and_init_words::(PROG_ID, /*base_addr=*/ 0, &self.program_bytes) + .map_err(|e| PiCcsError::InvalidInput(format!("prog_rom_layout_and_init_words failed: {e}")))?; + + let ram_init_map = self.ram_init.clone(); + let reg_init_map = self.reg_init.clone(); + let output_claims = self.output_claims.clone(); + let output_target = self.output_target; + + let wants_ram_output = matches!(output_target, OutputTarget::Ram) && !output_claims.is_empty(); + let include_ram_sidecar = + program_requires_ram_sidecar(&program) || !ram_init_map.is_empty() || wants_ram_output; + + let ram_k = 1usize + .checked_shl(ram_d as u32) + .ok_or_else(|| PiCcsError::InvalidInput(format!("RAM address width too large: d={ram_d}")))?; + + let mut mem_layouts: HashMap = HashMap::from([ + ( + REG_ID.0, + PlainMemLayout { + k: 32, + d: 5, + n_side: 2, + lanes: 2, + }, + ), + (PROG_ID.0, prog_layout.clone()), + ]); + if include_ram_sidecar { + mem_layouts.insert( + RAM_ID.0, + PlainMemLayout { + k: ram_k, + d: ram_d, + n_side: 2, + lanes: 1, + }, + ); + } + + let inferred_shout_ops = infer_required_trace_shout_opcodes(&program); + let shout_ops = match &self.shout_ops { + Some(override_ops) => { + let missing: HashSet = inferred_shout_ops + .difference(override_ops) + .copied() + .collect(); + if !missing.is_empty() { + let mut missing_names: Vec = missing.into_iter().map(|op| format!("{op:?}")).collect(); + missing_names.sort_unstable(); + return Err(PiCcsError::InvalidInput(format!( + "trace shout_ops override must be a superset of required opcodes; missing [{}]", + missing_names.join(", ") + ))); + } + override_ops.clone() + } + None => inferred_shout_ops, + }; + + let mut table_specs = rv32_trace_table_specs(&shout_ops); + let mut base_shout_table_ids: Vec = table_specs.keys().copied().collect(); + base_shout_table_ids.sort_unstable(); + for (&table_id, spec) in &self.extra_lut_table_specs { + if table_specs.contains_key(&table_id) { + return Err(PiCcsError::InvalidInput(format!( + "extra_lut_table_spec collides with existing table_id={table_id}" + ))); + } + table_specs.insert(table_id, spec.clone()); + } + + let decode_layout = Rv32DecodeSidecarLayout::new(); + let decode_lookup_bus_specs: Vec = { + let decode_lookup_cols = rv32_decode_lookup_backed_cols(&decode_layout); + decode_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_decode_lookup_table_id_for_col(col_id), + ell_addr: prog_layout.d, + n_vals: 1usize, + }) + .collect() + }; + + let include_width_lookup = program_requires_width_lookup(&program); + let width_layout = Rv32WidthSidecarLayout::new(); + let width_lookup_bus_specs: Vec = if include_width_lookup && width_lookup_addr_d > 0 { + let 
width_lookup_cols = rv32_width_lookup_backed_cols(&width_layout); + width_lookup_cols + .iter() + .copied() + .map(|col_id| TraceShoutBusSpec { + table_id: rv32_width_lookup_table_id_for_col(col_id), + ell_addr: width_lookup_addr_d, + n_vals: 1usize, + }) + .collect() + } else { + Vec::new() + }; + + let mut all_extra_shout_specs = self.extra_shout_bus_specs.clone(); + all_extra_shout_specs.extend(decode_lookup_bus_specs); + all_extra_shout_specs.extend(width_lookup_bus_specs); + + let mut layout = Rv32TraceCcsLayout::new(step_rows) + .map_err(|e| PiCcsError::InvalidInput(format!("Rv32TraceCcsLayout::new failed: {e}")))?; + + let ccs_reserved_rows; + { + let (bus_region_len, reserved_rows) = rv32_trace_shared_bus_requirements_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + &mem_layouts, + ) + .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_bus_requirements_with_specs failed: {e}")) + })?; + layout.m = layout + .m + .checked_add(bus_region_len) + .ok_or_else(|| PiCcsError::InvalidInput("trace layout m overflow after bus tail reservation".into()))?; + ccs_reserved_rows = reserved_rows; + } + + let ccs_base = if ccs_reserved_rows == 0 { + build_rv32_trace_wiring_ccs(&layout) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs failed: {e}")))? + } else { + build_rv32_trace_wiring_ccs_with_reserved_rows(&layout, ccs_reserved_rows) + .map_err(|e| PiCcsError::InvalidInput(format!("build_rv32_trace_wiring_ccs_with_reserved_rows failed: {e}")))? + }; + + let session = FoldingSession::::new_ajtai(self.mode.clone(), &ccs_base)?; + + let decode_lookup_tables = build_rv32_decode_lookup_tables(&prog_layout, &prog_init_words); + let mut lut_tables = decode_lookup_tables; + if include_width_lookup && width_lookup_addr_d > 0 { + let width_k = 1usize.checked_shl(width_lookup_addr_d as u32).ok_or_else(|| { + PiCcsError::InvalidInput(format!("width lookup addr width too large: d={width_lookup_addr_d}")) + })?; + let width_lookup_cols = rv32_width_lookup_backed_cols(&width_layout); + for &col_id in width_lookup_cols.iter() { + let table_id = rv32_width_lookup_table_id_for_col(col_id); + lut_tables.insert( + table_id, + LutTable { + table_id, + k: width_k, + d: width_lookup_addr_d, + n_side: 2, + content: vec![F::ZERO; width_k], + }, + ); + } + } + + let mut cpu = R1csCpu::new( + ccs_base, + session.params().clone(), + session.committer().clone(), + layout.m_in, + &lut_tables, + &table_specs, + rv32_trace_chunk_to_witness(layout.clone()), + ) + .map_err(|e| PiCcsError::InvalidInput(format!("R1csCpu::new failed: {e}")))?; + + let mut prog_init_pairs: Vec<(u64, F)> = prog_init_words + .into_iter() + .filter_map(|((mem_id, addr), value)| (mem_id == PROG_ID.0 && value != F::ZERO).then_some((addr, value))) + .collect(); + prog_init_pairs.sort_by_key(|(addr, _)| *addr); + let mut initial_mem: HashMap<(u32, u64), F> = HashMap::new(); + for &(addr, value) in &prog_init_pairs { + if value != F::ZERO { + initial_mem.insert((PROG_ID.0, addr), value); + } + } + for (®, &value) in ®_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((REG_ID.0, reg), v); + } + } + for (&addr, &value) in &ram_init_map { + let v = F::from_u64(value as u32 as u64); + if v != F::ZERO { + initial_mem.insert((RAM_ID.0, addr), v); + } + } + + cpu = cpu + .with_shared_cpu_bus( + rv32_trace_shared_cpu_bus_config_with_specs( + &layout, + &base_shout_table_ids, + &all_extra_shout_specs, + mem_layouts.clone(), + initial_mem.clone(), + ) 
+ .map_err(|e| { + PiCcsError::InvalidInput(format!("rv32_trace_shared_cpu_bus_config_with_specs failed: {e}")) + })?, + layout.t, + ) + .map_err(|e| PiCcsError::InvalidInput(format!("shared bus inject failed: {e}")))?; + + let ccs = cpu.ccs.clone(); + + let mut session = FoldingSession::::new_ajtai(self.mode, &ccs)?; + session.set_step_linking(StepLinkingConfig::new(vec![(layout.pc_final, layout.pc0)])); + + let output_binding_cfg = if output_claims.is_empty() { + None + } else { + let out_mem_id = match output_target { + OutputTarget::Ram => RAM_ID.0, + OutputTarget::Reg => REG_ID.0, + }; + let out_layout = mem_layouts.get(&out_mem_id).ok_or_else(|| { + PiCcsError::InvalidInput(format!( + "output binding: missing PlainMemLayout for mem_id={out_mem_id}" + )) + })?; + let expected_k = 1usize + .checked_shl(out_layout.d as u32) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: 2^d overflow".into()))?; + if out_layout.k != expected_k { + return Err(PiCcsError::InvalidInput(format!( + "output binding: mem_id={out_mem_id} has k={}, but expected 2^d={} (d={})", + out_layout.k, expected_k, out_layout.d + ))); + } + let mut mem_ids: Vec = mem_layouts.keys().copied().collect(); + mem_ids.sort_unstable(); + let mem_idx = mem_ids + .iter() + .position(|&id| id == out_mem_id) + .ok_or_else(|| PiCcsError::InvalidInput("output binding: mem_id not in mem_layouts".into()))?; + Some(OutputBindingConfig::new(out_layout.d, output_claims).with_mem_idx(mem_idx)) + }; + + Ok(Rv32TraceWiringVerifier { + session, + ccs, + _layout: layout, + mem_layouts, + statement_initial_mem: initial_mem, + output_binding_cfg, + }) + } } diff --git a/crates/neo-fold/src/session.rs b/crates/neo-fold/src/session.rs index 9f7af36a..71b1d9f5 100644 --- a/crates/neo-fold/src/session.rs +++ b/crates/neo-fold/src/session.rs @@ -596,7 +596,7 @@ where mode: FoldingMode, params: NeoParams, l: L, - commit_m: Option, + pub(crate) commit_m: Option, mixers: CommitMixers], &[Cmt]) -> Cmt, fn(&[Cmt], u32) -> Cmt>, // Cached CCS preprocessing for proving (best-effort reuse). @@ -1253,6 +1253,30 @@ where Ok(()) } + /// Preload prover-side CCS preprocessing with a precomputed matrix digest, skipping the + /// expensive `digest_ccs_matrices` call (~1.5s for large circuits). + pub fn preload_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.prover_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Preload verifier-side CCS preprocessing (sparse-cache + matrix-digest). /// /// This does **not** affect proving. It exists so benchmarks can model a verifier that has @@ -1266,6 +1290,29 @@ where Ok(()) } + /// Preload verifier-side CCS preprocessing with a precomputed matrix digest. 
+ pub fn preload_verifier_ccs_sparse_cache_with_digest( + &mut self, + s: &CcsStructure, + ccs_sparse_cache: Arc>, + ccs_mat_digest: Vec, + ) -> Result<(), PiCcsError> { + let src_ptr = (s as *const CcsStructure) as usize; + if ccs_sparse_cache.len() != s.t() { + return Err(PiCcsError::InvalidInput(format!( + "SparseCache matrix count mismatch: cache has {}, CCS has {}", + ccs_sparse_cache.len(), + s.t() + ))); + } + let ctx = ShardProverContext { + ccs_mat_digest, + ccs_sparse_cache: Some(ccs_sparse_cache), + }; + self.verifier_ctx = Some(SessionCcsCache { src_ptr, ctx }); + Ok(()) + } + /// Fold and prove: run folding over all collected steps and return a `FoldRun`. /// This is where the actual cryptographic work happens (Π_CCS → RLC → DEC for each step). /// This method manages the transcript internally for ease of use. @@ -1746,6 +1793,327 @@ where Ok(true) } + /// Verify a proof using externally supplied step instance bundles. + /// + /// Unlike [`verify`] and [`verify_collected`] which use the session's internal + /// collected steps (populated by execution), this method takes the step bundles + /// directly. This enables **verify-only** flows where the verifier never + /// executes the program -- the steps come from the proof package. + pub fn verify_with_external_steps( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_transcript(&mut tr, s, steps_public, run) + } + + /// Like [`verify_with_external_steps`] but with output binding. + pub fn verify_with_external_steps_and_output_binding( + &self, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let mut tr = Poseidon2Transcript::new(b"neo.fold/session"); + self.verify_with_external_steps_and_output_binding_transcript(&mut tr, s, steps_public, run, ob_cfg) + } + + /// Internal: verify with external steps + transcript. 
+ fn verify_with_external_steps_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_step_linking_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + }, + None => { + let mut msg = + "multi-step verification requires step linking; call FoldingSession::set_step_linking(...)" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ctx, + )?, + None => shard::fold_shard_verify( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + )?, + } + }; + + // Detect twist/shout from the provided steps (self.steps is empty for verify-only). + let has_twist_or_shout = steps_public.iter().any(|s| !s.mem_insts.is_empty()) + || steps_public.iter().any(|s| !s.lut_insts.is_empty()); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + + /// Internal: verify with external steps + output binding + transcript. 
+ fn verify_with_external_steps_and_output_binding_transcript( + &self, + tr: &mut Poseidon2Transcript, + s: &CcsStructure, + steps_public: &[StepInstanceBundle], + run: &FoldRun, + ob_cfg: &crate::output_binding::OutputBindingConfig, + ) -> Result { + let src_ptr = (s as *const CcsStructure) as usize; + let verifier_cache = self + .verifier_ctx + .as_ref() + .filter(|cache| cache.src_ptr == src_ptr); + let verifier_ctx = verifier_cache.map(|cache| &cache.ctx); + + let m_in_steps = steps_public + .first() + .map(|inst| inst.mcs_inst.m_in) + .unwrap_or(0); + if !steps_public + .iter() + .all(|inst| inst.mcs_inst.m_in == m_in_steps) + { + return Err(PiCcsError::InvalidInput("all steps must share the same m_in".into())); + } + let s_prepared = self.prepared_ccs_for_accumulator(s)?; + + let seed_me: &[MeInstance] = match &self.acc0 { + Some(acc) => { + acc.check(&self.params, s_prepared)?; + let acc_m_in = acc.me.first().map(|m| m.m_in).unwrap_or(m_in_steps); + if acc_m_in != m_in_steps { + return Err(PiCcsError::InvalidInput( + "initial Accumulator.m_in must match steps' m_in".into(), + )); + } + &acc.me + } + None => &[], + }; + + let step_linking = self + .step_linking + .as_ref() + .filter(|cfg| !cfg.prev_next_equalities.is_empty()); + + let outputs = if steps_public.len() > 1 { + match step_linking { + Some(cfg) => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_and_step_linking_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding_and_step_linking( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + cfg, + )?, + }, + None if self.allow_unlinked_steps => match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + )?, + }, + None => { + let mut msg = + "multi-step verification with output binding requires step linking" + .to_string(); + if let Some(diag) = &self.auto_step_linking_error { + msg.push_str(&format!(" (auto step-linking from StepSpec failed: {diag})")); + } + return Err(PiCcsError::InvalidInput(msg)); + } + } + } else { + match verifier_ctx { + Some(ctx) => shard::fold_shard_verify_with_output_binding_with_context( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + ctx, + )?, + None => shard::fold_shard_verify_with_output_binding( + self.mode.clone(), + tr, + &self.params, + s, + steps_public, + seed_me, + run, + self.mixers, + ob_cfg, + )?, + } + }; + + let has_twist_or_shout = !steps_public.is_empty() + && (steps_public.iter().any(|s| !s.lut_insts.is_empty()) + || steps_public.iter().any(|s| !s.mem_insts.is_empty())); + if !has_twist_or_shout && !outputs.obligations.val.is_empty() { + return Err(PiCcsError::ProtocolError( + "CCS-only session verification produced unexpected val-lane obligations".into(), + )); + } + Ok(true) + } + /// Verify with output binding, managing the transcript internally. 
pub fn verify_with_output_binding_simple( &self, diff --git a/crates/neo-fold/src/shard/verifier_and_api.rs b/crates/neo-fold/src/shard/verifier_and_api.rs index e2e37eee..261a8f9a 100644 --- a/crates/neo-fold/src/shard/verifier_and_api.rs +++ b/crates/neo-fold/src/shard/verifier_and_api.rs @@ -872,7 +872,9 @@ where .len() .checked_sub(1) .ok_or_else(|| PiCcsError::ProtocolError("missing inc_total claim".into()))?; - if step_proof.batched_time.labels.get(inc_idx).copied() != Some(crate::output_binding::OB_INC_TOTAL_LABEL) { + if step_proof.batched_time.labels.get(inc_idx).map(|label| label.as_slice()) + != Some(crate::output_binding::OB_INC_TOTAL_LABEL) + { return Err(PiCcsError::ProtocolError("output binding claim not last".into())); } diff --git a/crates/neo-fold/src/shard_proof_types.rs b/crates/neo-fold/src/shard_proof_types.rs index 261b7cb8..789b6c5a 100644 --- a/crates/neo-fold/src/shard_proof_types.rs +++ b/crates/neo-fold/src/shard_proof_types.rs @@ -19,7 +19,7 @@ pub type ShoutProofK = neo_memory::shout::ShoutProof; /// Within each group, when a Shout lane is provably inactive for a step (no lookups), we can /// skip its address-domain sumcheck entirely. We still bind all `claimed_sums` to the transcript, /// but we include sumcheck rounds only for the active subset. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreProof { /// Claimed sums per Shout lane. /// @@ -32,7 +32,7 @@ pub groups: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShoutAddrPreGroupProof { /// Address-bit width (sumcheck round count) for this group. pub ell_addr: u32, @@ -58,7 +58,7 @@ impl Default for ShoutAddrPreProof { } /// One fold step’s artifacts (Π_CCS → Π_RLC → Π_DEC). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct FoldStep { /// Π_CCS outputs (k ME(b,L) instances) pub ccs_out: Vec>, @@ -125,13 +125,17 @@ pub struct ShardFoldWitnesses { pub val_lane_wits: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub enum MemOrLutProof { Twist(TwistProofK), Shout(ShoutProofK), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound( + serialize = "C: serde::Serialize, FF: serde::Serialize, KK: serde::Serialize", + deserialize = "C: serde::de::DeserializeOwned, FF: serde::de::DeserializeOwned, KK: serde::de::DeserializeOwned + Default" +))] pub struct MemSidecarProof { /// ME claims evaluated at `r_val` (Twist val-eval terminal point). /// @@ -150,20 +154,20 @@ /// /// This batches CCS (row/time rounds) with Twist/Shout time-domain oracles so all /// protocols share the same transcript-derived `r` (enabling Π_RLC folding). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct BatchedTimeProof { /// Claimed sums per participating oracle (in the same order as `round_polys`). pub claimed_sums: Vec, /// Degree bounds per participating oracle. pub degree_bounds: Vec, /// Domain-separation labels per participating oracle. - pub labels: Vec<&'static [u8]>, + pub labels: Vec<Vec<u8>>, /// Per-claim sum-check messages: `round_polys[claim][round] = coeffs`. pub round_polys: Vec>>, } /// Proof data for a standalone Π_RLC → Π_DEC lane (no Π_CCS).
-#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct RlcDecProof { /// RLC mixing matrices ρ_i ∈ S ⊆ F^{D×D} pub rlc_rhos: Vec>, @@ -173,7 +177,7 @@ pub struct RlcDecProof { pub dec_children: Vec>, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct StepProof { pub fold: FoldStep, pub mem: MemSidecarProof, @@ -188,7 +192,7 @@ pub struct StepProof { pub wp_fold: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct ShardProof { pub steps: Vec, /// Optional output binding proof (proves final memory matches claimed outputs). diff --git a/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs new file mode 100644 index 00000000..23efdb43 --- /dev/null +++ b/crates/neo-fold/tests/riscv_build_verifier_statement_memory.rs @@ -0,0 +1,147 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_fold::PiCcsError; +use neo_math::F; +use neo_memory::riscv::lookups::{encode_program, RiscvInstruction, RiscvOpcode}; +use p3_field::PrimeCharacteristicRing; + +fn program_bytes_with_seed(seed: i32) -> Vec { + let program = vec![ + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 1, + rs1: 0, + imm: seed, + }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 2, + rs1: 1, + imm: 1, + }, + RiscvInstruction::Halt, + ]; + encode_program(&program) +} + +fn assert_m_in_mismatch_rejected(err: PiCcsError) { + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("all steps must share the same m_in"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error kind: {other:?}"), + } +} + +#[test] +fn rv32_trace_build_verifier_binds_statement_memory_to_chunk0_mem_init() { + let program_a = program_bytes_with_seed(7); + let program_b = program_bytes_with_seed(9); + + let mut run_a = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_a) + .chunk_rows(1) + .min_trace_len(1) + .prove() + .expect("prove statement A"); + run_a.verify().expect("self-verify statement A"); + + let proof_a = run_a.proof().clone(); + let steps_public_a = run_a.steps_public(); + let step_rows = run_a.step_rows(); + let ram_d = run_a.ram_d(); + let width_lookup_addr_d = run_a.width_lookup_addr_d(); + + let verifier_a = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_a) + .chunk_rows(1) + .min_trace_len(1) + .build_verifier(step_rows, ram_d, width_lookup_addr_d) + .expect("build verifier for statement A"); + let ok = verifier_a + .verify(&proof_a, &steps_public_a) + .expect("matching statement should verify"); + assert!(ok, "matching statement should verify"); + + let verifier_b = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program_b) + .chunk_rows(1) + .min_trace_len(1) + .build_verifier(step_rows, ram_d, width_lookup_addr_d) + .expect("build verifier for statement B"); + let err = verifier_b + .verify(&proof_a, &steps_public_a) + .expect_err("different statement memory must be rejected"); + + match err { + PiCcsError::InvalidInput(msg) => { + assert!( + msg.contains("chunk0 MemInstance.init mismatch"), + "unexpected invalid-input error: {msg}" + ); + } + other => panic!("unexpected error kind: {other:?}"), + } +} + +#[test] +fn rv32_trace_build_verifier_rejects_external_steps_with_non_uniform_m_in() { + let program = program_bytes_with_seed(11); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program) + .chunk_rows(1) + 
.min_trace_len(1) .prove() .expect("prove statement"); + run.verify().expect("self-verify statement"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program) + .chunk_rows(1) + .min_trace_len(1) + .build_verifier(run.step_rows(), run.ram_d(), run.width_lookup_addr_d()) + .expect("build verifier"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} + +#[test] +fn rv32_trace_build_verifier_output_binding_rejects_external_steps_with_non_uniform_m_in() { + let program = program_bytes_with_seed(13); + + let mut run = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program) + .chunk_rows(1) + .min_trace_len(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .prove() + .expect("prove statement with output binding"); + run.verify() + .expect("self-verify statement with output binding"); + + let proof = run.proof().clone(); + let mut steps_public = run.steps_public(); + assert!( + steps_public.len() > 1, + "test needs at least two steps to create m_in mismatch" + ); + steps_public[1].mcs_inst.m_in += 1; + + let verifier = Rv32TraceWiring::from_rom(/*program_base=*/ 0, &program) + .chunk_rows(1) + .min_trace_len(1) + .output(/*output_addr=*/ 0, /*expected_output=*/ F::ZERO) + .build_verifier(run.step_rows(), run.ram_d(), run.width_lookup_addr_d()) + .expect("build verifier with output binding"); + let err = verifier + .verify(&proof, &steps_public) + .expect_err("non-uniform m_in must be rejected"); + assert_m_in_mismatch_rejected(err); +} diff --git a/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs new file mode 100644 index 00000000..1f325bc9 --- /dev/null +++ b/crates/neo-fold/tests/riscv_poseidon2_ecall_prove_verify.rs @@ -0,0 +1,93 @@ +use neo_fold::riscv_trace_shard::Rv32TraceWiring; +use neo_memory::riscv::lookups::{ + encode_program, RiscvInstruction, RiscvOpcode, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; + +fn load_u32_imm(rd: u8, value: u32) -> Vec<RiscvInstruction> { + let upper = ((value as i64 + 0x800) >> 12) as i32; + let lower = (value as i32) - (upper << 12); + vec![ + RiscvInstruction::Lui { rd, imm: upper }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd, + rs1: rd, + imm: lower, + }, + ] +} + +fn poseidon2_ecall_program() -> Vec<RiscvInstruction> { + let mut program = Vec::new(); + + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + + // a0 = POSEIDON2_ECALL_NUM -> compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. + for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + + // Clear a0 -> final halt.
+ program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + program +} + +#[test] +fn rv32_trace_prove_verify_poseidon2_ecall_chunk1() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .chunk_rows(1) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=1"); +} + +#[test] +fn rv32_trace_prove_verify_poseidon2_ecall_chunk4() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .chunk_rows(4) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=4"); +} + +#[test] +fn rv32_trace_prove_verify_poseidon2_ecall_chunk32() { + let program = poseidon2_ecall_program(); + let program_bytes = encode_program(&program); + + let mut run = Rv32TraceWiring::from_rom(0, &program_bytes) + .chunk_rows(32) + .max_steps(program.len() + 64) + .prove() + .expect("prove should succeed"); + + run.verify().expect("verify should succeed for Poseidon2 ECALL with chunk_size=32"); +} diff --git a/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs b/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs new file mode 100644 index 00000000..ff2f4fe7 --- /dev/null +++ b/crates/neo-fold/tests/suites/perf/riscv_poseidon2_ecall_perf.rs @@ -0,0 +1,292 @@ +#![allow(non_snake_case)] + +use std::time::Duration; + +use neo_fold::riscv_trace_shard::{Rv32TraceWiring, Rv32TraceWiringRun}; +use neo_memory::riscv::lookups::{ + encode_program, RiscvInstruction, RiscvOpcode, POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; + +fn load_u32_imm(rd: u8, value: u32) -> Vec { + let upper = ((value as i64 + 0x800) >> 12) as i32; + let lower = (value as i32) - (upper << 12); + vec![ + RiscvInstruction::Lui { rd, imm: upper }, + RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd, + rs1: rd, + imm: lower, + }, + ] +} + +fn poseidon2_ecall_program(n_hashes: usize) -> Vec { + let mut program = Vec::new(); + + for _ in 0..n_hashes { + // a1 = 0 (n_elements = 0 for empty-input Poseidon2 hash). + program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 11, + rs1: 0, + imm: 0, + }); + + // a0 = POSEIDON2_ECALL_NUM -> compute ECALL. + program.extend(load_u32_imm(10, POSEIDON2_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + + // Read all 8 digest words via read ECALLs. + for _ in 0..8 { + program.extend(load_u32_imm(10, POSEIDON2_READ_ECALL_NUM)); + program.push(RiscvInstruction::Halt); + } + } + + // Clear a0 -> final halt. 
+ program.push(RiscvInstruction::IAlu { + op: RiscvOpcode::Add, + rd: 10, + rs1: 0, + imm: 0, + }); + program.push(RiscvInstruction::Halt); + + program +} + +fn run_once( + program_bytes: &[u8], + max_steps: usize, +) -> Result { + Rv32TraceWiring::from_rom(0, program_bytes) + .chunk_rows(max_steps) + .max_steps(max_steps) + .prove() +} + +#[derive(Clone, Copy, Debug)] +struct Stats { + min: Duration, + median: Duration, + mean: Duration, + max: Duration, +} + +fn summarize(samples: &[Duration]) -> Stats { + assert!(!samples.is_empty()); + let mut v = samples.to_vec(); + v.sort_unstable(); + let min = v[0]; + let max = v[v.len() - 1]; + let median = v[v.len() / 2]; + let mean_secs = v.iter().map(Duration::as_secs_f64).sum::() / v.len() as f64; + let mean = Duration::from_secs_f64(mean_secs); + Stats { + min, + median, + mean, + max, + } +} + +fn fmt_duration(d: Duration) -> String { + if d.as_secs_f64() < 1.0 { + format!("{:.3}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.3}s", d.as_secs_f64()) + } +} + +fn env_usize(name: &str, default: usize) -> usize { + match std::env::var(name) { + Ok(v) => v.parse::().unwrap_or(default), + Err(_) => default, + } +} + +#[test] +#[ignore = "perf-style test: run with `P2_HASHES=1 cargo test -p neo-fold --release --test perf -- --ignored --nocapture rv32_trace_poseidon2_ecall_perf`"] +fn rv32_trace_poseidon2_ecall_perf() { + let n_hashes = env_usize("P2_HASHES", 1); + let warmups = env_usize("P2_WARMUPS", 1); + let samples = env_usize("P2_SAMPLES", 5); + assert!(n_hashes > 0, "P2_HASHES must be > 0"); + assert!(samples > 0, "P2_SAMPLES must be > 0"); + + let program = poseidon2_ecall_program(n_hashes); + let program_bytes = encode_program(&program); + let max_steps = program.len() + 64; + + for _ in 0..warmups { + let mut run = run_once(&program_bytes, max_steps).expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut prove_times = Vec::with_capacity(samples); + let mut verify_times = Vec::with_capacity(samples); + let mut end_to_end_times = Vec::with_capacity(samples); + let mut ccs_n = 0; + let mut ccs_m = 0; + let mut fold_count = 0; + let mut trace_len = 0; + + for _ in 0..samples { + let total_start = std::time::Instant::now(); + let mut run = run_once(&program_bytes, max_steps).expect("prove"); + run.verify().expect("verify"); + end_to_end_times.push(total_start.elapsed()); + prove_times.push(run.prove_duration()); + verify_times.push(run.verify_duration().unwrap_or(Duration::ZERO)); + + ccs_n = run.ccs_num_constraints(); + ccs_m = run.ccs_num_variables(); + fold_count = run.fold_count(); + trace_len = run.trace_len(); + } + + let prove = summarize(&prove_times); + let verify = summarize(&verify_times); + let e2e = summarize(&end_to_end_times); + + let sep = "=".repeat(96); + let thin = "-".repeat(96); + + println!(); + println!("{sep}"); + println!("POSEIDON2 ECALL BENCHMARK (RV32 TRACE)"); + println!("{sep}"); + println!( + "config: hashes={} instructions={} warmups={} samples={}", + n_hashes, + program.len(), + warmups, + samples + ); + println!( + "CCS: n={} (pow2={}) m={} (pow2={}) folds={} trace_len={}", + ccs_n, + ccs_n.next_power_of_two(), + ccs_m, + ccs_m.next_power_of_two(), + fold_count, + trace_len + ); + println!("{thin}"); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "phase", "min", "median", "mean", "max" + ); + println!("{thin}"); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "prove", + fmt_duration(prove.min), + fmt_duration(prove.median), + fmt_duration(prove.mean), + 
fmt_duration(prove.max), + ); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "verify", + fmt_duration(verify.min), + fmt_duration(verify.median), + fmt_duration(verify.mean), + fmt_duration(verify.max), + ); + println!( + "{:>12} {:>10} {:>10} {:>10} {:>10}", + "end-to-end", + fmt_duration(e2e.min), + fmt_duration(e2e.median), + fmt_duration(e2e.mean), + fmt_duration(e2e.max), + ); + println!("{thin}"); + if n_hashes > 0 { + let per_hash_prove = Duration::from_secs_f64(prove.median.as_secs_f64() / n_hashes as f64); + let per_hash_e2e = Duration::from_secs_f64(e2e.median.as_secs_f64() / n_hashes as f64); + println!( + "per-hash (median): prove={} end-to-end={}", + fmt_duration(per_hash_prove), + fmt_duration(per_hash_e2e), + ); + } + println!("{sep}"); + println!(); +} + +#[test] +#[ignore = "perf-style test: run with `cargo test -p neo-fold --release --test perf -- --ignored --nocapture rv32_trace_poseidon2_ecall_scaling`"] +fn rv32_trace_poseidon2_ecall_scaling() { + let samples = env_usize("P2_SAMPLES", 3); + let warmups = env_usize("P2_WARMUPS", 1); + let hash_counts = [1, 2, 4, 8]; + + let sep = "=".repeat(110); + let thin = "-".repeat(110); + + println!(); + println!("{sep}"); + println!( + "POSEIDON2 ECALL SCALING (RV32 TRACE) — warmups={} samples={}", + warmups, samples + ); + println!("{sep}"); + println!( + "{:>8} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10} {:>12}", + "hashes", "instrs", "ccs_n", "ccs_m", "prove_med", "verify_med", "e2e_med", "per_hash", "throughput" + ); + println!("{thin}"); + + for &n_hashes in &hash_counts { + let program = poseidon2_ecall_program(n_hashes); + let program_bytes = encode_program(&program); + let max_steps = program.len() + 64; + + for _ in 0..warmups { + let mut run = run_once(&program_bytes, max_steps).expect("warmup prove"); + run.verify().expect("warmup verify"); + } + + let mut prove_times = Vec::with_capacity(samples); + let mut verify_times = Vec::with_capacity(samples); + let mut e2e_times = Vec::with_capacity(samples); + let mut ccs_n = 0; + let mut ccs_m = 0; + + for _ in 0..samples { + let total_start = std::time::Instant::now(); + let mut run = run_once(&program_bytes, max_steps).expect("prove"); + run.verify().expect("verify"); + e2e_times.push(total_start.elapsed()); + prove_times.push(run.prove_duration()); + verify_times.push(run.verify_duration().unwrap_or(Duration::ZERO)); + ccs_n = run.ccs_num_constraints(); + ccs_m = run.ccs_num_variables(); + } + + let prove_med = summarize(&prove_times).median; + let verify_med = summarize(&verify_times).median; + let e2e_med = summarize(&e2e_times).median; + let per_hash = Duration::from_secs_f64(e2e_med.as_secs_f64() / n_hashes as f64); + let hashes_per_sec = n_hashes as f64 / e2e_med.as_secs_f64(); + + println!( + "{:>8} {:>8} {:>8} {:>8} {:>10} {:>10} {:>10} {:>10} {:>10.1} h/s", + n_hashes, + program.len(), + ccs_n, + ccs_m, + fmt_duration(prove_med), + fmt_duration(verify_med), + fmt_duration(e2e_med), + fmt_duration(per_hash), + hashes_per_sec, + ); + } + + println!("{sep}"); + println!(); +} diff --git a/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs index 4be0923e..ccbd01df 100644 --- a/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs +++ b/crates/neo-fold/tests/suites/perf/single_addi_metrics_nightstream.rs @@ -521,7 +521,7 @@ fn report_track_a_w0_w1_snapshot() { let mut other_claims = Vec::new(); for i in 0..bt.labels.len() { - let label = 
std::str::from_utf8(bt.labels[i]).unwrap_or(""); + let label = std::str::from_utf8(&bt.labels[i]).unwrap_or(""); let deg = bt.degree_bounds[i]; let entry = (label.to_string(), deg); if label.starts_with("ccs/") { diff --git a/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs b/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs index e02be54b..cfaa31db 100644 --- a/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs +++ b/crates/neo-fold/tests/suites/trace_twist/twist_shout_soundness.rs @@ -338,7 +338,7 @@ fn tamper_batched_time_label_fails() { ) .expect("prove should succeed"); - proof.steps[0].batched_time.labels[0] = b"bad/label".as_slice(); + proof.steps[0].batched_time.labels[0] = b"bad/label".to_vec(); let mut tr_verify = Poseidon2Transcript::new(b"soundness/tamper-batched-time-label"); let steps_public = [StepInstanceBundle::from(&step_bundle)]; diff --git a/crates/neo-memory/src/builder.rs b/crates/neo-memory/src/builder.rs index 7ebcc9e1..a93a3570 100644 --- a/crates/neo-memory/src/builder.rs +++ b/crates/neo-memory/src/builder.rs @@ -69,6 +69,10 @@ pub struct ShardWitnessAux { pub mem_ids: Vec, /// Final sparse memory states at the end of the shard: mem_id -> (addr -> value), with zero cells omitted. pub final_mem_states: HashMap>, + /// Wall-clock time for VM trace collection (`trace_program`). + pub vm_trace_duration: std::time::Duration, + /// Wall-clock time for CPU witness building (`build_ccs_chunks`), includes decomp + commit. + pub cpu_witness_duration: std::time::Duration, } fn ell_from_pow2_n_side(n_side: usize) -> Result { @@ -238,9 +242,11 @@ where table_ids.dedup(); // 3) CPU arithmetization chunks. + let t_witness_start = std::time::Instant::now(); let mcss = cpu_arith .build_ccs_chunks(trace, chunk_size) .map_err(|e| ShardBuildError::CcsError(e.to_string()))?; + let cpu_witness_duration = t_witness_start.elapsed(); if mcss.len() != chunks_len { return Err(ShardBuildError::CcsError(format!( "cpu arithmetization returned {} chunks, expected {} (steps={}, chunk_size={})", @@ -427,6 +433,8 @@ where chunk_size, mem_ids, final_mem_states: mem_states, + vm_trace_duration: std::time::Duration::ZERO, + cpu_witness_duration, }; Ok((step_bundles, aux)) } diff --git a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs index 28f60381..f86f22dc 100644 --- a/crates/neo-memory/src/riscv/ccs/bus_bindings.rs +++ b/crates/neo-memory/src/riscv/ccs/bus_bindings.rs @@ -13,7 +13,8 @@ use crate::plain::PlainMemLayout; use crate::riscv::lookups::{PROG_ID, RAM_ID, REG_ID}; use crate::riscv::trace::{ rv32_decode_lookup_table_id_for_col, rv32_is_decode_lookup_table_id, rv32_is_width_lookup_table_id, - rv32_trace_lookup_addr_group_for_table_id, rv32_trace_lookup_selector_group_for_table_id, Rv32DecodeSidecarLayout, + rv32_trace_lookup_addr_group_for_table_shape, rv32_trace_lookup_selector_group_for_table_id, + Rv32DecodeSidecarLayout, }; use super::constants::{ @@ -71,8 +72,8 @@ fn validate_trace_shout_table_id(table_id: u32) -> Result<(), String> { } #[inline] -fn trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { - rv32_trace_lookup_addr_group_for_table_id(table_id) +fn trace_lookup_addr_group_for_table_shape(table_id: u32, ell_addr: usize) -> Option { + rv32_trace_lookup_addr_group_for_table_shape(table_id, ell_addr) } #[inline] @@ -103,7 +104,7 @@ fn derive_trace_shout_shapes( table_id, ell_addr: 2 * RV32_XLEN, n_vals: 1usize, - addr_group: trace_lookup_addr_group_for_table_id(table_id), 
+ addr_group: trace_lookup_addr_group_for_table_shape(table_id, 2 * RV32_XLEN), selector_group: trace_lookup_selector_group_for_table_id(table_id), }, ); @@ -135,7 +136,7 @@ fn derive_trace_shout_shapes( spec.table_id, prev.n_vals, spec.n_vals )); } - let inferred_group = trace_lookup_addr_group_for_table_id(spec.table_id); + let inferred_group = trace_lookup_addr_group_for_table_shape(spec.table_id, spec.ell_addr); if prev.addr_group != inferred_group { return Err(format!( "RV32 trace shared bus: conflicting addr_group for table_id={} (base/spec mismatch: {:?} vs {:?})", @@ -156,7 +157,7 @@ fn derive_trace_shout_shapes( table_id: spec.table_id, ell_addr: spec.ell_addr, n_vals: spec.n_vals, - addr_group: trace_lookup_addr_group_for_table_id(spec.table_id), + addr_group: trace_lookup_addr_group_for_table_shape(spec.table_id, spec.ell_addr), selector_group: trace_lookup_selector_group_for_table_id(spec.table_id), }, ); diff --git a/crates/neo-memory/src/riscv/lookups/cpu.rs b/crates/neo-memory/src/riscv/lookups/cpu.rs index 6c315d7f..3752483b 100644 --- a/crates/neo-memory/src/riscv/lookups/cpu.rs +++ b/crates/neo-memory/src/riscv/lookups/cpu.rs @@ -1,10 +1,16 @@ use neo_vm_trace::{Shout, Twist}; +use p3_field::{PrimeCharacteristicRing, PrimeField64}; +use p3_goldilocks::Goldilocks; use super::bits::interleave_bits; use super::decode::decode_instruction; use super::encode::encode_instruction; use super::isa::{BranchCondition, RiscvInstruction, RiscvMemOp, RiscvOpcode}; use super::tables::RiscvShoutTables; +use super::{ + GL_ADD_ECALL_NUM, GL_MUL_ECALL_NUM, GL_READ_ECALL_NUM, GL_SUB_ECALL_NUM, + POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM, +}; /// A RISC-V CPU that can be traced using Neo's VmCpu trait. /// @@ -24,6 +30,14 @@ pub struct RiscvCpu { program: Vec, /// Base address of the program. program_base: u64, + /// Pending Poseidon2 digest words (8 × u32, set by the compute ECALL). + poseidon2_pending: Option<[u32; 8]>, + /// Index into `poseidon2_pending` for the next read ECALL. + poseidon2_read_idx: usize, + /// Pending Goldilocks field operation result (2 × u32: lo, hi). + gl_pending: Option<[u32; 2]>, + /// Index into `gl_pending` for the next GL read ECALL. + gl_read_idx: usize, } impl RiscvCpu { @@ -37,6 +51,10 @@ impl RiscvCpu { halted: false, program: Vec::new(), program_base: 0, + poseidon2_pending: None, + poseidon2_read_idx: 0, + gl_pending: None, + gl_read_idx: 0, } } @@ -95,6 +113,110 @@ impl RiscvCpu { twist.store_lane(super::REG_ID, reg as u64, masked, /*lane=*/ 0); self.regs[reg as usize] = masked; } + + /// Execute the Poseidon2 hash compute precompile. + /// + /// Reads `n_elements` Goldilocks field elements from RAM at `input_addr` + /// using `load_untraced` (no Twist bus events), computes the Poseidon2- + /// Goldilocks sponge hash, and stores the 4-element digest as 8 u32 words + /// in internal CPU state (`poseidon2_pending`). The guest retrieves output + /// words one at a time via `POSEIDON2_READ_ECALL_NUM`. 
+ fn handle_poseidon2_ecall>(&mut self, twist: &mut T) { + use neo_ccs::crypto::poseidon2_goldilocks::poseidon2_hash; + + let ram = super::RAM_ID; + let n_elements = self.get_reg(11) as u32; // a1 + let input_addr = self.get_reg(12); // a2 + + let mut inputs = Vec::with_capacity(n_elements as usize); + for i in 0..n_elements { + let addr = input_addr + (i as u64) * 8; + let lo = twist.load_untraced(ram, addr) & 0xFFFF_FFFF; + let hi = twist.load_untraced(ram, addr + 4) & 0xFFFF_FFFF; + let val = lo | (hi << 32); + inputs.push(Goldilocks::from_u64(val)); + } + + let digest = poseidon2_hash(&inputs); + + let mut words = [0u32; 8]; + for (i, &elem) in digest.iter().enumerate() { + let val = elem.as_canonical_u64(); + words[i * 2] = val as u32; + words[i * 2 + 1] = (val >> 32) as u32; + } + self.poseidon2_pending = Some(words); + self.poseidon2_read_idx = 0; + } + + /// Read one u32 word of the pending Poseidon2 digest into register a0. + /// + /// Called 8 times by the guest to retrieve the full 4-element digest. + /// Each call returns the next word and advances the internal read index. + /// Uses `store_untraced` so the next instruction sees the updated a0 + /// without generating a CCS-visible register write event. + fn handle_poseidon2_read_ecall>(&mut self, twist: &mut T) { + let words = self + .poseidon2_pending + .expect("poseidon2 read ECALL called without a pending digest"); + let idx = self.poseidon2_read_idx; + assert!(idx < 8, "poseidon2 read ECALL: all 8 words already consumed"); + let word = self.mask_value(words[idx] as u64); + self.regs[10] = word; + twist.store_untraced(super::REG_ID, 10, word); + self.poseidon2_read_idx = idx + 1; + } + + /// Goldilocks field multiply ECALL handler. + fn handle_gl_mul_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) * Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Goldilocks field add ECALL handler. + fn handle_gl_add_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) + Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Goldilocks field subtract ECALL handler. + fn handle_gl_sub_ecall(&mut self) { + let (a, b) = self.read_gl_operands(); + let result = Goldilocks::from_u64(a) - Goldilocks::from_u64(b); + self.store_gl_result(result.as_canonical_u64()); + } + + /// Read two 64-bit Goldilocks operands from registers a1..a4. + fn read_gl_operands(&self) -> (u64, u64) { + let a_lo = self.get_reg(11) as u32; // a1 + let a_hi = self.get_reg(12) as u32; // a2 + let b_lo = self.get_reg(13) as u32; // a3 + let b_hi = self.get_reg(14) as u32; // a4 + let a = (a_lo as u64) | ((a_hi as u64) << 32); + let b = (b_lo as u64) | ((b_hi as u64) << 32); + (a, b) + } + + /// Store a 64-bit GL result as two u32 words in CPU internal state. + fn store_gl_result(&mut self, val: u64) { + self.gl_pending = Some([val as u32, (val >> 32) as u32]); + self.gl_read_idx = 0; + } + + /// GL read ECALL handler: return next u32 word of the pending result in a0. 
+ fn handle_gl_read_ecall>(&mut self, twist: &mut T) { + let words = self + .gl_pending + .expect("GL read ECALL called without a pending result"); + let idx = self.gl_read_idx; + assert!(idx < 2, "GL read ECALL: both words already consumed"); + let word = self.mask_value(words[idx] as u64); + self.regs[10] = word; + twist.store_untraced(super::REG_ID, 10, word); + self.gl_read_idx = idx + 1; + } } impl neo_vm_trace::VmCpu for RiscvCpu { @@ -590,8 +712,22 @@ impl neo_vm_trace::VmCpu for RiscvCpu { // === System Instructions === RiscvInstruction::Ecall => { - // ECALL - environment call (syscall). - self.handle_ecall(); + let call_id = self.get_reg(10) as u32; + if call_id == POSEIDON2_ECALL_NUM { + self.handle_poseidon2_ecall(twist); + } else if call_id == POSEIDON2_READ_ECALL_NUM { + self.handle_poseidon2_read_ecall(twist); + } else if call_id == GL_MUL_ECALL_NUM { + self.handle_gl_mul_ecall(); + } else if call_id == GL_ADD_ECALL_NUM { + self.handle_gl_add_ecall(); + } else if call_id == GL_SUB_ECALL_NUM { + self.handle_gl_sub_ecall(); + } else if call_id == GL_READ_ECALL_NUM { + self.handle_gl_read_ecall(twist); + } else { + self.handle_ecall(); + } } RiscvInstruction::Ebreak => { diff --git a/crates/neo-memory/src/riscv/lookups/mod.rs b/crates/neo-memory/src/riscv/lookups/mod.rs index 65866a8b..e0d5d71f 100644 --- a/crates/neo-memory/src/riscv/lookups/mod.rs +++ b/crates/neo-memory/src/riscv/lookups/mod.rs @@ -106,6 +106,47 @@ pub const PROG_ID: TwistId = TwistId(1); /// This is used by the RV32 trace-wiring circuit in "regfile-as-Twist" mode. pub const REG_ID: TwistId = TwistId(2); +/// Poseidon2-Goldilocks hash compute ECALL identifier. +/// +/// ABI: a0 = POSEIDON2_ECALL_NUM, a1 = input element count, +/// a2 = input RAM address (elements as 2×u32 LE). +/// The host reads inputs via untraced loads, computes the hash, and stores the +/// 4-element digest in CPU-internal state. Use POSEIDON2_READ_ECALL_NUM to +/// retrieve output words one at a time via register a0. +pub const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2-Goldilocks digest read ECALL identifier (bit 31 set). +/// +/// ABI: a0 = POSEIDON2_READ_ECALL_NUM. Returns the next u32 word of the +/// pending Poseidon2 digest in register a0. Call 8 times (4 elements × 2 words) +/// to retrieve the full digest. +pub const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + +/// Goldilocks field multiply ECALL identifier ("GLM"). +/// +/// ABI: a0 = GL_MUL_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. +/// Computes (a * b) mod p and stores the 64-bit result in CPU state. +/// Retrieve via GL_READ_ECALL_NUM (2 calls for lo/hi words). +pub const GL_MUL_ECALL_NUM: u32 = 0x474C4D; + +/// Goldilocks field add ECALL identifier ("GLA"). +/// +/// ABI: a0 = GL_ADD_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. +/// Computes (a + b) mod p and stores the 64-bit result in CPU state. +pub const GL_ADD_ECALL_NUM: u32 = 0x474C41; + +/// Goldilocks field subtract ECALL identifier ("GLS"). +/// +/// ABI: a0 = GL_SUB_ECALL_NUM, a1 = a_lo, a2 = a_hi, a3 = b_lo, a4 = b_hi. +/// Computes (a - b) mod p and stores the 64-bit result in CPU state. +pub const GL_SUB_ECALL_NUM: u32 = 0x474C53; + +/// Goldilocks field operation read ECALL identifier (bit 31 set on "GLR"). +/// +/// ABI: a0 = GL_READ_ECALL_NUM. Returns the next u32 word of the +/// pending field operation result in register a0. Call 2 times (lo/hi). 
+pub const GL_READ_ECALL_NUM: u32 = 0x80474C52; + pub use alu::{compute_op, lookup_entry}; pub use bits::{interleave_bits, uninterleave_bits}; pub use cpu::RiscvCpu; diff --git a/crates/neo-memory/src/riscv/trace/air.rs b/crates/neo-memory/src/riscv/trace/air.rs index 373954fe..3db95148 100644 --- a/crates/neo-memory/src/riscv/trace/air.rs +++ b/crates/neo-memory/src/riscv/trace/air.rs @@ -91,6 +91,7 @@ impl Rv32TraceAir { ("shout_lhs", l.shout_lhs), ("shout_rhs", l.shout_rhs), ("jalr_drop_bit", l.jalr_drop_bit), + ("pc_carry", l.pc_carry), ] { let e = Self::gated_zero(inv_active, col(c, i)); if !Self::is_zero(e) { diff --git a/crates/neo-memory/src/riscv/trace/layout.rs b/crates/neo-memory/src/riscv/trace/layout.rs index c3f52ca4..bbaf4fa1 100644 --- a/crates/neo-memory/src/riscv/trace/layout.rs +++ b/crates/neo-memory/src/riscv/trace/layout.rs @@ -30,6 +30,7 @@ pub struct Rv32TraceLayout { pub shout_lhs: usize, pub shout_rhs: usize, pub jalr_drop_bit: usize, + pub pc_carry: usize, } impl Rv32TraceLayout { @@ -65,8 +66,9 @@ impl Rv32TraceLayout { let shout_lhs = take(); let shout_rhs = take(); let jalr_drop_bit = take(); + let pc_carry = take(); - debug_assert_eq!(next, 21, "RV32 trace width drift after decode-helper offload"); + debug_assert_eq!(next, 22, "RV32 trace width drift after decode-helper offload"); Self { cols: next, @@ -91,6 +93,7 @@ impl Rv32TraceLayout { shout_lhs, shout_rhs, jalr_drop_bit, + pc_carry, } } } diff --git a/crates/neo-memory/src/riscv/trace/mod.rs b/crates/neo-memory/src/riscv/trace/mod.rs index 45c00e55..e9c5c866 100644 --- a/crates/neo-memory/src/riscv/trace/mod.rs +++ b/crates/neo-memory/src/riscv/trace/mod.rs @@ -43,6 +43,20 @@ pub fn rv32_trace_lookup_addr_group_for_table_id(table_id: u32) -> Option { } } +/// Shape-aware address-group hint for shared-bus Shout lanes. +/// +/// This guards against accidental grouping when callers use low numeric `table_id`s for +/// non-RV32 opcode tables (common in generic tests/fixtures). RV32 opcode tables (id 0..=19) +/// are grouped only when their key shape matches the canonical interleaved width. 
+#[inline] +pub fn rv32_trace_lookup_addr_group_for_table_shape(table_id: u32, ell_addr: usize) -> Option { + let group = rv32_trace_lookup_addr_group_for_table_id(table_id)?; + if table_id <= 19 && ell_addr != 64 { + return None; + } + Some(group) +} + #[inline] pub fn rv32_trace_lookup_selector_group_for_table_id(table_id: u32) -> Option { if rv32_is_decode_lookup_table_id(table_id) { diff --git a/crates/neo-memory/src/riscv/trace/witness.rs b/crates/neo-memory/src/riscv/trace/witness.rs index 4d3ac949..dd9ec750 100644 --- a/crates/neo-memory/src/riscv/trace/witness.rs +++ b/crates/neo-memory/src/riscv/trace/witness.rs @@ -18,6 +18,26 @@ fn imm_i_from_word(instr_word: u32) -> u32 { sign_extend_to_u32((instr_word >> 20) & 0x0fff, 12) } +#[inline] +fn imm_j_from_word(instr_word: u32) -> u32 { + let imm20 = (instr_word >> 31) & 1; + let imm10_1 = (instr_word >> 21) & 0x3FF; + let imm11 = (instr_word >> 20) & 1; + let imm19_12 = (instr_word >> 12) & 0xFF; + let raw = (imm20 << 20) | (imm19_12 << 12) | (imm11 << 11) | (imm10_1 << 1); + sign_extend_to_u32(raw, 21) +} + +#[inline] +fn imm_b_from_word(instr_word: u32) -> u32 { + let imm12 = (instr_word >> 31) & 1; + let imm10_5 = (instr_word >> 25) & 0x3F; + let imm4_1 = (instr_word >> 8) & 0xF; + let imm11 = (instr_word >> 7) & 1; + let raw = (imm12 << 12) | (imm11 << 11) | (imm10_5 << 5) | (imm4_1 << 1); + sign_extend_to_u32(raw, 13) +} + #[derive(Clone, Debug)] pub struct Rv32TraceWitness { pub t: usize, @@ -63,11 +83,28 @@ impl Rv32TraceWitness { // this address is don't-care for bus semantics. wit.cols[layout.rd_addr][i] = F::from_u64(cols.rd[i] as u64); wit.cols[layout.rd_val][i] = F::from_u64(cols.rd_val[i]); - if cols.opcode[i] == 0x67 { + let opcode = cols.opcode[i]; + if opcode == 0x67 { + // JALR: pc = (rs1 + imm_i) & ~1 let rs1 = cols.rs1_val[i] as u32; let imm_i = imm_i_from_word(cols.instr_word[i]); let drop = rs1.wrapping_add(imm_i) & 1; wit.cols[layout.jalr_drop_bit][i] = F::from_u64(drop as u64); + let sum = (cols.rs1_val[i]) + (imm_i as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } else if opcode == 0x6F { + // JAL: pc = pc + imm_j + let imm_j = imm_j_from_word(cols.instr_word[i]); + let sum = (cols.pc_before[i]) + (imm_j as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } else if opcode == 0x63 { + // BRANCH: pc = taken ? 
pc + imm_b : pc + 4 + let taken = cols.pc_after[i] != cols.pc_before[i].wrapping_add(4); + if taken { + let imm_b = imm_b_from_word(cols.instr_word[i]); + let sum = (cols.pc_before[i]) + (imm_b as u64); + wit.cols[layout.pc_carry][i] = F::from_u64(sum >> 32); + } } } diff --git a/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs index 02b19274..c2e98079 100644 --- a/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs +++ b/crates/neo-memory/tests/riscv_trace_wiring_ccs.rs @@ -15,12 +15,12 @@ fn rv32_trace_layout_removes_fixed_shout_table_selector_lanes() { let layout = Rv32TraceCcsLayout::new(/*t=*/ 4).expect("trace CCS layout"); assert_eq!( - layout.trace.cols, 21, - "trace width regression: expected 21 columns after shout_lhs/jalr_drop_bit hardening" + layout.trace.cols, 22, + "trace width regression: expected 22 columns after pc_carry addition" ); assert_eq!( layout.trace.cols, - layout.trace.jalr_drop_bit + 1, + layout.trace.pc_carry + 1, "trace layout should remain densely packed" ); } diff --git a/crates/neo-reductions/Cargo.toml b/crates/neo-reductions/Cargo.toml index 4f2636e7..2c7691a0 100644 --- a/crates/neo-reductions/Cargo.toml +++ b/crates/neo-reductions/Cargo.toml @@ -21,7 +21,7 @@ fs-guard = ["neo-transcript/fs-guard"] # Enable verbose logging for debugging (no-op in this crate; used for cfg gating) neo-logs = [] # Enable constraint diagnostic system -prove-diagnostics = ["serde", "serde_json", "flate2", "hex", "chrono"] +prove-diagnostics = ["serde_json", "flate2", "hex", "chrono"] # Enable CBOR format for diagnostics (smaller but less readable) prove-diagnostics-cbor = ["prove-diagnostics", "serde_cbor"] wasm-threads = ["neo-ajtai/wasm-threads", "neo-ccs/wasm-threads"] @@ -46,8 +46,8 @@ thiserror = { workspace = true } bincode = { workspace = true } blake3 = "1.5" -# Diagnostic system dependencies (feature-gated) -serde = { workspace = true, optional = true } +# Serde for proof serialization (always available; also used by diagnostics) +serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true, optional = true } flate2 = { version = "1.0", optional = true } hex = { version = "0.4", optional = true } diff --git a/crates/neo-reductions/src/engines/optimized_engine/common.rs b/crates/neo-reductions/src/engines/optimized_engine/common.rs index ff925bc7..94b52d9e 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/common.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/common.rs @@ -17,7 +17,7 @@ use p3_field::{Field, PrimeCharacteristicRing}; use rayon::prelude::*; /// Challenges sampled in Step 1 of the protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Challenges { /// α ∈ K^{log d} - for Ajtai dimension pub alpha: Vec, diff --git a/crates/neo-reductions/src/engines/optimized_engine/mod.rs b/crates/neo-reductions/src/engines/optimized_engine/mod.rs index f4276945..92f75c70 100644 --- a/crates/neo-reductions/src/engines/optimized_engine/mod.rs +++ b/crates/neo-reductions/src/engines/optimized_engine/mod.rs @@ -20,7 +20,7 @@ pub mod verify; pub use common::Challenges; /// Proof format variant for Π_CCS. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum PiCcsProofVariant { /// Split-NC proof with two sumchecks: FE-only + NC-only. 
SplitNcV1, @@ -56,7 +56,7 @@ pub use common::{ }; /// Proof structure for the Π_CCS protocol -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PiCcsProof { /// Proof format variant. pub variant: PiCcsProofVariant, diff --git a/crates/neo-vm-trace/src/lib.rs b/crates/neo-vm-trace/src/lib.rs index f75f963d..b1161d73 100644 --- a/crates/neo-vm-trace/src/lib.rs +++ b/crates/neo-vm-trace/src/lib.rs @@ -340,6 +340,27 @@ pub trait Twist { self.store_lane(twist_id, addr, value, lane); } } + + /// Load a value from memory without recording a Twist event. + /// + /// Used by ECALL precompiles that need to read guest RAM without generating + /// per-step Twist bus entries (the ECALL computation is trusted host code). + /// The default implementation delegates to `load`. + #[inline] + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.load(twist_id, addr) + } + + /// Store a value to memory without recording a Twist event. + /// + /// Used by ECALL precompiles that write results back to the register file + /// without generating per-step Twist bus entries. The CCS treats the ECALL + /// row as having no `rd` write; subsequent instructions see the updated + /// value through normal Twist reads. + #[inline] + fn store_untraced(&mut self, twist_id: TwistId, addr: Addr, value: Word) { + self.store(twist_id, addr, value); + } } /// A tracing wrapper around any `Twist` implementation. @@ -431,6 +452,14 @@ where lane: Some(lane), }); } + + fn load_untraced(&mut self, twist_id: TwistId, addr: Addr) -> Word { + self.inner.load(twist_id, addr) + } + + fn store_untraced(&mut self, twist_id: TwistId, addr: Addr, value: Word) { + self.inner.store(twist_id, addr, value); + } } // ============================================================================ diff --git a/crates/nightstream-sdk/src/goldilocks.rs b/crates/nightstream-sdk/src/goldilocks.rs new file mode 100644 index 00000000..61f6d2c7 --- /dev/null +++ b/crates/nightstream-sdk/src/goldilocks.rs @@ -0,0 +1,187 @@ +//! Goldilocks field arithmetic for Nightstream guests. +//! +//! Field: p = 2^64 - 2^32 + 1 = 0xFFFF_FFFF_0000_0001 +//! +//! On RISC-V targets, `gl_mul`, `gl_add`, and `gl_sub` use ECALL precompiles +//! (3 ECALLs each: 1 compute + 2 reads for the 64-bit result). On other targets +//! the software implementations are used. + +#![allow(dead_code)] + +/// Goldilocks prime: p = 2^64 - 2^32 + 1. +pub const GL_P: u64 = 0xFFFF_FFFF_0000_0001; + +/// A Goldilocks field element digest: 4 elements (32 bytes). 
+pub type GlDigest = [u64; 4]; + +// --------------------------------------------------------------------------- +// ECALL-based implementations (RISC-V only) +// --------------------------------------------------------------------------- + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_MUL_ECALL_NUM: u32 = 0x474C4D; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_ADD_ECALL_NUM: u32 = 0x474C41; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_SUB_ECALL_NUM: u32 = 0x474C53; +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const GL_READ_ECALL_NUM: u32 = 0x80474C52; + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_compute(ecall_id: u32, a: u64, b: u64) { + let a_lo = a as u32; + let a_hi = (a >> 32) as u32; + let b_lo = b as u32; + let b_hi = (b >> 32) as u32; + unsafe { + core::arch::asm!( + "ecall", + in("a0") ecall_id, + in("a1") a_lo, + in("a2") a_hi, + in("a3") b_lo, + in("a4") b_hi, + options(nostack), + ); + } +} + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_read() -> u32 { + let result: u32; + unsafe { + core::arch::asm!( + "ecall", + inout("a0") GL_READ_ECALL_NUM => result, + options(nostack), + ); + } + result +} + +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +#[inline] +fn gl_ecall_op(ecall_id: u32, a: u64, b: u64) -> u64 { + gl_ecall_compute(ecall_id, a, b); + let lo = gl_ecall_read() as u64; + let hi = gl_ecall_read() as u64; + lo | (hi << 32) +} + +// --------------------------------------------------------------------------- +// Software fallbacks (non-RISC-V) +// --------------------------------------------------------------------------- + +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +#[inline] +fn reduce128(x: u128) -> u64 { + let lo = x as u64; + let hi = (x >> 64) as u64; + let hi_times_correction = (hi as u128) * 0xFFFF_FFFFu128; + let sum = lo as u128 + hi_times_correction; + let lo2 = sum as u64; + let hi2 = (sum >> 64) as u64; + let (mut result, overflow) = lo2.overflowing_add(hi2.wrapping_mul(0xFFFF_FFFF)); + if overflow || result >= GL_P { + result = result.wrapping_sub(GL_P); + } + result +} + +/// Additive identity. +pub const GL_ZERO: u64 = 0; + +/// Multiplicative identity. +pub const GL_ONE: u64 = 1; + +/// Field addition: (a + b) mod p. +#[inline] +pub fn gl_add(a: u64, b: u64) -> u64 { + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_ADD_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + let sum = a as u128 + b as u128; + let s = sum as u64; + let carry = (sum >> 64) as u64; + let (mut r, overflow) = s.overflowing_add(carry.wrapping_mul(0xFFFF_FFFF)); + if overflow || r >= GL_P { + r = r.wrapping_sub(GL_P); + } + r + } +} + +/// Field subtraction: (a - b) mod p. +#[inline] +pub fn gl_sub(a: u64, b: u64) -> u64 { + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_SUB_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + if a >= b { + let diff = a - b; + if diff >= GL_P { diff - GL_P } else { diff } + } else { + GL_P - (b - a) + } + } +} + +/// Field negation: (-a) mod p. +#[inline] +pub fn gl_neg(a: u64) -> u64 { + if a == 0 { 0 } else { GL_P - a } +} + +/// Field multiplication: (a * b) mod p. 
+#[inline] +pub fn gl_mul(a: u64, b: u64) -> u64 { + #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] + { gl_ecall_op(GL_MUL_ECALL_NUM, a, b) } + + #[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] + { + let prod = (a as u128) * (b as u128); + reduce128(prod) + } +} + +/// Field squaring: (a * a) mod p. +#[inline] +pub fn gl_sqr(a: u64) -> u64 { + gl_mul(a, a) +} + +/// Field exponentiation: a^exp mod p via square-and-multiply. +pub fn gl_pow(mut base: u64, mut exp: u64) -> u64 { + let mut result = GL_ONE; + while exp > 0 { + if exp & 1 == 1 { + result = gl_mul(result, base); + } + base = gl_sqr(base); + exp >>= 1; + } + result +} + +/// Field inverse: a^(p-2) mod p (Fermat's little theorem). +/// +/// Panics (halts) if a == 0. +pub fn gl_inv(a: u64) -> u64 { + debug_assert!(a != 0, "inverse of zero"); + gl_pow(a, GL_P - 2) +} + +/// Check equality of two digests. +pub fn digest_eq(a: &GlDigest, b: &GlDigest) -> bool { + a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3] +} + +/// Zero digest. +pub const ZERO_DIGEST: GlDigest = [0; 4]; diff --git a/crates/nightstream-sdk/src/lib.rs b/crates/nightstream-sdk/src/lib.rs index f0a043c8..4cd5b989 100644 --- a/crates/nightstream-sdk/src/lib.rs +++ b/crates/nightstream-sdk/src/lib.rs @@ -3,6 +3,8 @@ pub use nightstream_sdk_macros::{entry, provable, NeoAbi}; pub mod abi; +pub mod goldilocks; +pub mod poseidon2; #[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] mod riscv; diff --git a/crates/nightstream-sdk/src/poseidon2.rs b/crates/nightstream-sdk/src/poseidon2.rs new file mode 100644 index 00000000..1cf9b90d --- /dev/null +++ b/crates/nightstream-sdk/src/poseidon2.rs @@ -0,0 +1,111 @@ +//! Poseidon2-Goldilocks hash via ECALL precompile. +//! +//! Provides the guest-side API for computing Poseidon2 hashes over Goldilocks +//! field elements. On RISC-V targets, this issues ECALLs to the host: +//! 1. A "compute" ECALL that reads inputs (via untraced loads) and computes the +//! Poseidon2 hash, storing the digest in host-side CPU state. +//! 2. Eight "read" ECALLs that each return one u32 word of the digest in +//! register a0. +//! +//! On non-RISC-V targets, a stub panics. + +use crate::goldilocks::GlDigest; + +/// Poseidon2 compute ECALL number (must match neo-memory's POSEIDON2_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_ECALL_NUM: u32 = 0x504F53; + +/// Poseidon2 read ECALL number (must match neo-memory's POSEIDON2_READ_ECALL_NUM). +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +const POSEIDON2_READ_ECALL_NUM: u32 = 0x80504F53; + +/// Digest length in Goldilocks elements (matches neo-params DIGEST_LEN). +const DIGEST_LEN: usize = 4; + +/// Scratch buffer size: supports hashing up to 64 Goldilocks elements per call. +/// Each element occupies 2 u32 words (8 bytes). +const MAX_INPUT_ELEMENTS: usize = 64; + +/// Hash an arbitrary-length slice of Goldilocks field elements. +/// +/// Packs elements to a stack-allocated scratch buffer, issues the Poseidon2 +/// compute ECALL, then retrieves the 4-element digest via 8 read ECALLs +/// (each returning one u32 word in register a0). +/// +/// # Panics +/// +/// Panics if `input.len() > MAX_INPUT_ELEMENTS` (64). +pub fn poseidon2_hash(input: &[u64]) -> GlDigest { + assert!( + input.len() <= MAX_INPUT_ELEMENTS, + "poseidon2_hash: too many input elements" + ); + + // Scratch buffer for input (as u32 words). 
+ let mut input_buf: [u32; MAX_INPUT_ELEMENTS * 2] = [0; MAX_INPUT_ELEMENTS * 2]; + + // Pack Goldilocks elements as pairs of u32 words (little-endian). + for (i, &elem) in input.iter().enumerate() { + input_buf[i * 2] = elem as u32; + input_buf[i * 2 + 1] = (elem >> 32) as u32; + } + + // Compute ECALL: host reads inputs via untraced loads and stores digest in CPU state. + poseidon2_ecall_compute(input.len() as u32, input_buf.as_ptr() as u32); + + // Read 8 u32 words of the digest via register a0. + let d: [u32; 8] = core::array::from_fn(|_| poseidon2_ecall_read()); + + // Unpack into 4 Goldilocks elements. + let mut digest = [0u64; DIGEST_LEN]; + for i in 0..DIGEST_LEN { + digest[i] = (d[i * 2] as u64) | ((d[i * 2 + 1] as u64) << 32); + } + digest +} + +/// Issue the Poseidon2 compute ECALL. +/// +/// Registers: a0 = ECALL ID, a1 = element count, a2 = input addr. +/// No output via registers; digest is stored in host CPU state. +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_compute(n_elements: u32, input_addr: u32) { + unsafe { + core::arch::asm!( + "ecall", + in("a0") POSEIDON2_ECALL_NUM, + in("a1") n_elements, + in("a2") input_addr, + options(nostack), + ); + } +} + +/// Issue the Poseidon2 read ECALL and return the next digest word. +/// +/// The ECALL number is passed in a0, and the host returns the next u32 word +/// of the pending digest in a0. +#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))] +fn poseidon2_ecall_read() -> u32 { + let result: u32; + unsafe { + core::arch::asm!( + "ecall", + inout("a0") POSEIDON2_READ_ECALL_NUM => result, + options(nostack), + ); + } + result +} + +/// Stub for non-RISC-V targets (used in native tests via the reference impl). +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_compute(_n_elements: u32, _input_addr: u32) { + unimplemented!("poseidon2_ecall_compute is only available on RISC-V targets") +} + +/// Stub for non-RISC-V targets. +#[cfg(not(any(target_arch = "riscv32", target_arch = "riscv64")))] +fn poseidon2_ecall_read() -> u32 { + unimplemented!("poseidon2_ecall_read is only available on RISC-V targets") +}
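// ---------------------------------------------------------------------------
// Illustrative sketches appended by the editor. Everything below is hypothetical
// companion code for the changes above, not part of the patch itself; module
// paths, test locations, and helper names are assumptions unless they appear in
// the diff.
// ---------------------------------------------------------------------------

// Sketch 1: the shape-aware addr-group hint. rv32_trace_lookup_addr_group_for_table_shape
// only groups RV32 opcode tables (id 0..=19) when the key width matches the canonical
// interleaved shape (ell_addr = 2 * RV32_XLEN = 64); any other shape is left ungrouped, so
// generic fixtures that reuse low table ids keep independent lanes. Assumes both functions
// are reachable at neo_memory::riscv::trace from an integration test.
#[test]
fn shape_gate_only_groups_canonical_rv32_opcode_tables() {
    use neo_memory::riscv::trace::{
        rv32_trace_lookup_addr_group_for_table_id, rv32_trace_lookup_addr_group_for_table_shape,
    };

    // With the canonical 64-bit interleaved key, the shape-aware hint agrees with the id-based one.
    assert_eq!(
        rv32_trace_lookup_addr_group_for_table_shape(0, 64),
        rv32_trace_lookup_addr_group_for_table_id(0)
    );
    // A low table id with a non-canonical key width must not be grouped.
    assert_eq!(rv32_trace_lookup_addr_group_for_table_shape(0, 32), None);
}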
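// Sketch 2: the new ECALL identifiers are ASCII mnemonics ("POS", "GLM", "GLA", "GLS", "GLR"),
// with bit 31 marking the read-result variants. A small sanity check of that encoding; the
// import path from neo_memory::riscv::lookups is an assumption.
#[test]
fn ecall_ids_are_ascii_mnemonics() {
    use neo_memory::riscv::lookups::{
        GL_ADD_ECALL_NUM, GL_MUL_ECALL_NUM, GL_READ_ECALL_NUM, GL_SUB_ECALL_NUM,
        POSEIDON2_ECALL_NUM, POSEIDON2_READ_ECALL_NUM,
    };

    // Low three bytes of the identifier, big-endian, i.e. the mnemonic characters.
    let tail = |id: u32| -> [u8; 3] {
        let b = id.to_be_bytes();
        [b[1], b[2], b[3]]
    };
    assert_eq!(&tail(POSEIDON2_ECALL_NUM), b"POS");
    assert_eq!(&tail(GL_MUL_ECALL_NUM), b"GLM");
    assert_eq!(&tail(GL_ADD_ECALL_NUM), b"GLA");
    assert_eq!(&tail(GL_SUB_ECALL_NUM), b"GLS");
    // Read variants set bit 31 on top of the base mnemonic.
    assert_eq!(POSEIDON2_READ_ECALL_NUM, POSEIDON2_ECALL_NUM | 0x8000_0000);
    assert_eq!(&tail(GL_READ_ECALL_NUM & 0x7FFF_FFFF), b"GLR");
}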
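// Sketch 3: the new pc_carry column absorbs the 32-bit wrap of the PC update. For a
// control-flow instruction with a sign-extended 32-bit immediate `imm`, the intended
// relation is pc_before + imm = pc_after + pc_carry * 2^32 (all three PC values in
// [0, 2^32)), which matches the `sum >> 32` computation in witness.rs. A worked numeric
// example for a backward JAL; this is illustration only, not code from the patch.
#[test]
fn pc_carry_absorbs_the_32_bit_wrap() {
    // JAL backwards by 8 bytes: the J-immediate -8 sign-extends to the u32 pattern 0xFFFF_FFF8.
    let pc_before: u64 = 0x0000_1000;
    let imm_j: u32 = (-8i32) as u32;
    let sum = pc_before + imm_j as u64; // 0x1_0000_0FF8
    let pc_carry = sum >> 32;           // 1: the addition wrapped past 2^32
    let pc_after = sum & 0xFFFF_FFFF;   // 0x0000_0FF8 = pc_before - 8
    assert_eq!(pc_after, pc_before - 8);
    assert_eq!(pc_before + imm_j as u64, pc_after + (pc_carry << 32));
}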
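// Sketch 4: with serde now a non-optional dependency of neo-reductions, PiCcsProof and
// Challenges derive Serialize/Deserialize and can cross a process boundary. A hypothetical
// round-trip helper; using the crate's existing bincode dependency (1.x API assumed) as the
// wire format is an assumption, not something the patch prescribes.
use serde::{de::DeserializeOwned, Serialize};

fn serde_roundtrip<T: Serialize + DeserializeOwned>(value: &T) -> T {
    let bytes = bincode::serialize(value).expect("serialize");
    bincode::deserialize(&bytes).expect("deserialize")
}
// Usage: let proof2: PiCcsProof = serde_roundtrip(&proof); the verifier should accept proof2
// exactly as it accepts the original.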
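// Sketch 5: the software fallbacks in nightstream-sdk/src/goldilocks.rs can be cross-checked
// on a native target against plain u128 modular arithmetic (on native the ECALL path is never
// taken, so this exercises reduce128 and friends directly). A hypothetical integration test,
// e.g. crates/nightstream-sdk/tests/gl_reference.rs:
use nightstream_sdk::goldilocks::{gl_add, gl_inv, gl_mul, gl_sub, GL_ONE, GL_P};

// Reference implementations: exact arithmetic in u128, reduced with `%`.
fn add_ref(a: u64, b: u64) -> u64 { ((a as u128 + b as u128) % GL_P as u128) as u64 }
fn sub_ref(a: u64, b: u64) -> u64 { ((a as u128 + GL_P as u128 - b as u128) % GL_P as u128) as u64 }
fn mul_ref(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % GL_P as u128) as u64 }

#[test]
fn gl_fallbacks_match_u128_reference() {
    // Sample values already reduced mod p, covering both edge regions of the range.
    let samples = [0u64, 1, 2, 0xFFFF_FFFF, GL_P - 2, GL_P - 1];
    for &a in &samples {
        for &b in &samples {
            assert_eq!(gl_add(a, b), add_ref(a, b));
            assert_eq!(gl_sub(a, b), sub_ref(a, b));
            assert_eq!(gl_mul(a, b), mul_ref(a, b));
        }
    }
    // Fermat inverse round-trip: a * a^(p-2) == 1 for a != 0.
    assert_eq!(gl_mul(7, gl_inv(7)), GL_ONE);
}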
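// Sketch 6: guest-side use of the new SDK modules. Hashing n <= 64 already-reduced Goldilocks
// limbs costs one compute ECALL plus eight read ECALLs; on non-RISC-V targets poseidon2_hash
// currently hits the unimplemented!() stub, so this helper is guest-only. The function itself
// is hypothetical, not part of the patch.
use nightstream_sdk::goldilocks::{digest_eq, GlDigest};
use nightstream_sdk::poseidon2::poseidon2_hash;

/// Returns true iff `message` hashes to `expected` under the Poseidon2-Goldilocks precompile.
fn check_commitment(message: &[u64], expected: &GlDigest) -> bool {
    // poseidon2_hash asserts message.len() <= 64; longer inputs would need chunking or a
    // sponge driven from the guest side.
    let digest: GlDigest = poseidon2_hash(message);
    digest_eq(&digest, expected)
}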